From 2761a257aeedbc4806125675e3c000684ac24153 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Nguy=E1=BB=85n=20M=E1=BA=A1nh=20H=C3=B9ng?= Date: Thu, 11 Sep 2025 02:57:19 +0700 Subject: [PATCH 1/2] Update Code: replaced outdated Keras, added requirements.txt --- .gitignore | 11 +++++++++-- README.md | 7 ++----- neural.py | 46 ++++++++++++++++++++++++++++------------------ requirements.txt | 31 +++++++++++++++++++++++++++++++ train_network.py | 8 +++++--- 5 files changed, 75 insertions(+), 28 deletions(-) create mode 100644 requirements.txt diff --git a/.gitignore b/.gitignore index 184c2dd..c4bb625 100644 --- a/.gitignore +++ b/.gitignore @@ -84,11 +84,18 @@ celerybeat-schedule .env # virtualenv -venv/ -ENV/ +.venv* +venv* +ENV* +.env* +*.venv # Spyder project settings .spyderproject # Rope project settings .ropeproject + +# --- +*.h5 +load_model.py \ No newline at end of file diff --git a/README.md b/README.md index 4fa9040..c6278da 100644 --- a/README.md +++ b/README.md @@ -5,12 +5,9 @@ See: https://www.nature.com/articles/s41598-017-11266-1 Make your own decoder with: ``` -train_network.py 5 output.model \ - --onthefly 10000000 50000 \ - --Xstab --Zstab \ - --epochs 10 --prob 0.9 \ - --learningrate .000001 --normcenterstab --layers 4 4 4 4 4 4 4 +train_network.py 5 output.model --onthefly 10000000 50000 --Xstab --Zstab --epochs 10 --prob 0.9 --learningrate .000001 --normcenterstab --layers 4 4 4 4 4 4 4 ``` + Test a network by adding the `--eval` flag. See `train_network.py -h` for description of each flag. diff --git a/neural.py b/neural.py index cb37cdd..856ca95 100644 --- a/neural.py +++ b/neural.py @@ -5,11 +5,11 @@ from keras.models import Sequential from keras.layers import Dense, Dropout, Activation from keras.optimizers import Nadam -from keras.objectives import binary_crossentropy -from keras.layers.normalization import BatchNormalization +from keras.losses import binary_crossentropy +from keras.layers import BatchNormalization import tensorflow as tf -F = lambda _: K.cast(_, 'float32') # TODO XXX there must be a better way to calculate mean than this cast-first approach +F = lambda _: tf.cast(_, 'float32') # TODO XXX there must be a better way to calculate mean than this cast-first approach class CodeCosts: @@ -20,52 +20,62 @@ def __init__(self, L, code, Z, X, normcentererr_p=None): code = code(L) H = code.H(Z,X) E = code.E(Z,X) - self.H = K.variable(value=H) # TODO should be sparse - self.E = K.variable(value=E) # TODO should be sparse + self.H = tf.Variable(initial_value=H, trainable=False, dtype=tf.float32) # TODO should be sparse + self.E = tf.Variable(initial_value=E, trainable=False, dtype=tf.float32) # TODO should be sparse self.p = normcentererr_p def exact_reversal(self, y_true, y_pred): "Fraction exactly predicted qubit flips." if self.p: y_pred = undo_normcentererr(y_pred, self.p) y_true = undo_normcentererr(y_true, self.p) - return K.mean(F(K.all(K.equal(y_true, K.round(y_pred)), axis=-1))) + return tf.reduce_mean(F(tf.reduce_all(tf.equal(y_true, tf.round(y_pred)), axis=-1))) def non_triv_stab_expanded(self, y_true, y_pred): "Whether the stabilizer after correction is not trivial." if self.p: y_pred = undo_normcentererr(y_pred, self.p) y_true = undo_normcentererr(y_true, self.p) - return K.any(K.dot(self.H, K.transpose((K.round(y_pred)+y_true)%2))%2, axis=0) + # Cast to same dtype to avoid type mismatch + y_pred_rounded = tf.cast(tf.round(y_pred), tf.float32) + y_true_cast = tf.cast(y_true, tf.float32) + correction = tf.cast((y_pred_rounded + y_true_cast) % 2, tf.float32) + return tf.reduce_any(tf.cast(tf.matmul(self.H, tf.transpose(correction)) % 2, tf.bool), axis=0) def logic_error_expanded(self, y_true, y_pred): "Whether there is a logical error after correction." if self.p: y_pred = undo_normcentererr(y_pred, self.p) y_true = undo_normcentererr(y_true, self.p) - return K.any(K.dot(self.E, K.transpose((K.round(y_pred)+y_true)%2))%2, axis=0) + # Cast to same dtype to avoid type mismatch + y_pred_rounded = tf.cast(tf.round(y_pred), tf.float32) + y_true_cast = tf.cast(y_true, tf.float32) + correction = tf.cast((y_pred_rounded + y_true_cast) % 2, tf.float32) + return tf.reduce_any(tf.cast(tf.matmul(self.E, tf.transpose(correction)) % 2, tf.bool), axis=0) def triv_stab(self, y_true, y_pred): "Fraction trivial stabilizer after corrections." - return 1-K.mean(F(self.non_triv_stab_expanded(y_true, y_pred))) + return 1-tf.reduce_mean(F(self.non_triv_stab_expanded(y_true, y_pred))) def no_error(self, y_true, y_pred): "Fraction no logical errors after corrections." - return 1-K.mean(F(self.logic_error_expanded(y_true, y_pred))) + return 1-tf.reduce_mean(F(self.logic_error_expanded(y_true, y_pred))) def triv_no_error(self, y_true, y_pred): "Fraction with trivial stabilizer and no error." # TODO XXX Those casts (the F function) should not be there! This should be logical operations triv_stab = 1 - F(self.non_triv_stab_expanded(y_true, y_pred)) no_err = 1 - F(self.logic_error_expanded(y_true, y_pred)) - return K.mean(no_err*triv_stab) + return tf.reduce_mean(no_err*triv_stab) def e_binary_crossentropy(self, y_true, y_pred): if self.p: y_pred = undo_normcentererr(y_pred, self.p) y_true = undo_normcentererr(y_true, self.p) - return K.mean(K.binary_crossentropy(y_pred, y_true), axis=-1) + return tf.reduce_mean(tf.keras.losses.binary_crossentropy(y_true, y_pred), axis=-1) def s_binary_crossentropy(self, y_true, y_pred): if self.p: y_pred = undo_normcentererr(y_pred, self.p) y_true = undo_normcentererr(y_true, self.p) - s_true = K.dot(y_true, K.transpose(self.H))%2 + # Cast to avoid type mismatch + y_true_cast = tf.cast(y_true, tf.float32) + s_true = tf.cast(tf.matmul(y_true_cast, tf.transpose(self.H)) % 2, tf.float32) twopminusone = 2*y_pred-1 - s_pred = ( 1 - tf.real(K.exp(K.dot(K.log(tf.cast(twopminusone, tf.complex64)), tf.cast(K.transpose(self.H), tf.complex64)))) ) / 2 - return K.mean(K.binary_crossentropy(s_pred, s_true), axis=-1) + s_pred = ( 1 - tf.math.real(tf.exp(tf.matmul(tf.math.log(tf.cast(twopminusone, tf.complex64)), tf.cast(tf.transpose(self.H), tf.complex64)))) ) / 2 + return tf.reduce_mean(tf.keras.losses.binary_crossentropy(s_true, s_pred), axis=-1) def se_binary_crossentropy(self, y_true, y_pred): return 2./3.*self.e_binary_crossentropy(y_true, y_pred) + 1./3.*self.s_binary_crossentropy(y_true, y_pred) @@ -94,7 +104,7 @@ def create_model(L, hidden_sizes=[4], hidden_act='tanh', act='sigmoid', loss='bi 's_binary_crossentropy':c.s_binary_crossentropy, 'se_binary_crossentropy':c.se_binary_crossentropy} model.compile(loss=losses.get(loss,loss), - optimizer=Nadam(lr=learning_rate), + optimizer=Nadam(learning_rate=learning_rate), metrics=[c.triv_no_error, c.e_binary_crossentropy, c.s_binary_crossentropy] ) return model @@ -134,8 +144,8 @@ def data_generator(H, out_dimZ, out_dimX, in_dim, p, batch_size=512, size=None, flips = do_normcentererr(flips, p) yield (stabs, flips) c += 1 - if size and c==size: - raise StopIteration + if size and c >= size: + return def do_normcenterstab(stabs, p): avg = (1-p)*2/3 diff --git a/requirements.txt b/requirements.txt new file mode 100644 index 0000000..fab9b0d --- /dev/null +++ b/requirements.txt @@ -0,0 +1,31 @@ +# Neural Network Decoders for Quantum Error Correcting Codes +# Compatible with Python 3.10 or 3.11 (TensorFlow compatibility issues with 3.12) + +# Core ML/DL frameworks +tensorflow>=2.16.0,<2.21.0 +keras>=3.0.0 + +# Scientific computing +numpy>=1.21.0,<2.0.0 +scipy>=1.9.0 + +# Graph algorithms (for MWPM) +networkx>=2.8 + +# Progress bars +tqdm>=4.64.0 + +# Optional: Jupyter notebook support +jupyter>=1.0.0 +ipython>=8.0.0 + +# Optional: Plotting (mentioned in codes.py) +matplotlib>=3.5.0 + +# Optional: For better performance with large arrays +# numba>=0.56.0 + +# Development dependencies (optional) +# pytest>=7.0.0 +# black>=22.0.0 +# flake8>=5.0.0 diff --git a/train_network.py b/train_network.py index 9764131..bc97c88 100644 --- a/train_network.py +++ b/train_network.py @@ -110,9 +110,11 @@ normcenterstab=args.normcenterstab, normcentererr=args.normcentererr) val = data_generator(H, out_dimZ, out_dimX, in_dim, args.prob, args.batch, normcenterstab=args.normcenterstab, normcentererr=args.normcentererr) - hist = model.fit_generator(dat, args.onthefly[0]//args.batch, args.epochs, - validation_data=val, validation_steps=args.onthefly[1]//args.batch) - model.save_weights(args.out) + hist = model.fit(dat, steps_per_epoch=args.onthefly[0]//args.batch, epochs=args.epochs, + validation_data=val, validation_steps=args.onthefly[1]//args.batch) + # Add .weights.h5 extension for Keras 3.x compatibility + weights_filename = args.out + '.weights.h5' if not args.out.endswith('.weights.h5') else args.out + model.save_weights(weights_filename) with open(args.out+'.log', 'w') as f: f.write(str((hist.params, hist.history))) if args.eval: From baa19230e087a7353aa67e9c326dd6c7de98a99b Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Nguy=E1=BB=85n=20M=E1=BA=A1nh=20H=C3=B9ng?= Date: Thu, 11 Sep 2025 05:06:43 +0700 Subject: [PATCH 2/2] format code --- codes.py | 266 +++++++++++++++++++++----------------- find_threshold.py | 13 +- generate_training_data.py | 6 +- neural.py | 87 ++++++++----- train_network.py | 50 +++---- 5 files changed, 244 insertions(+), 178 deletions(-) diff --git a/codes.py b/codes.py index 43fb265..a05837e 100644 --- a/codes.py +++ b/codes.py @@ -40,34 +40,41 @@ class ToricCode: X10--Q20--X11--Q21--X12... . . . ''' + def __init__(self, L): '''Toric code of ``2 L**2`` physical qubits and distance ``L``.''' self.L = L - self.Xflips = np.zeros((2*L,L), dtype=np.dtype('b')) # qubits where an X error occured - self.Zflips = np.zeros((2*L,L), dtype=np.dtype('b')) # qubits where a Z error occured - self._Xstab = np.empty((L,L), dtype=np.dtype('b')) - self._Zstab = np.empty((L,L), dtype=np.dtype('b')) + # qubits where an X error occured + self.Xflips = np.zeros((2*L, L), dtype=np.dtype('b')) + # qubits where a Z error occured + self.Zflips = np.zeros((2*L, L), dtype=np.dtype('b')) + self._Xstab = np.empty((L, L), dtype=np.dtype('b')) + self._Zstab = np.empty((L, L), dtype=np.dtype('b')) @property def flatXflips2Zstab(self): L = self.L _flatXflips2Zstab = np.zeros((L**2, 2*L**2), dtype=np.dtype('b')) - for i, j in itertools.product(range(L),range(L)): - _flatXflips2Zstab[i*L+j, (2*i )%(2*L)*L+(j )%L] = 1 - _flatXflips2Zstab[i*L+j, (2*i+1)%(2*L)*L+(j )%L] = 1 - _flatXflips2Zstab[i*L+j, (2*i+2)%(2*L)*L+(j )%L] = 1 - _flatXflips2Zstab[i*L+j, (2*i+1)%(2*L)*L+(j+1)%L] = 1 + for i, j in itertools.product(range(L), range(L)): + _flatXflips2Zstab[i*L+j, (2*i) % (2*L)*L+(j) % L] = 1 + _flatXflips2Zstab[i*L+j, (2*i+1) % (2*L)*L+(j) % L] = 1 + _flatXflips2Zstab[i*L+j, (2*i+2) % (2*L)*L+(j) % L] = 1 + _flatXflips2Zstab[i*L+j, (2*i+1) % (2*L)*L+(j+1) % L] = 1 return _flatXflips2Zstab @property def flatZflips2Xstab(self): L = self.L _flatZflips2Xstab = np.zeros((L**2, 2*L**2), dtype=np.dtype('b')) - for i, j in itertools.product(range(L),range(L)): - _flatZflips2Xstab[(i+1)%L*L+(j+1)%L, (2*i+1)%(2*L)*L+(j+1)%L] = 1 - _flatZflips2Xstab[(i+1)%L*L+(j+1)%L, (2*i+2)%(2*L)*L+(j )%L] = 1 - _flatZflips2Xstab[(i+1)%L*L+(j+1)%L, (2*i+3)%(2*L)*L+(j+1)%L] = 1 - _flatZflips2Xstab[(i+1)%L*L+(j+1)%L, (2*i+2)%(2*L)*L+(j+1)%L] = 1 + for i, j in itertools.product(range(L), range(L)): + _flatZflips2Xstab[(i+1) % L*L+(j+1) % + L, (2*i+1) % (2*L)*L+(j+1) % L] = 1 + _flatZflips2Xstab[(i+1) % L*L+(j+1) % + L, (2*i+2) % (2*L)*L+(j) % L] = 1 + _flatZflips2Xstab[(i+1) % L*L+(j+1) % + L, (2*i+3) % (2*L)*L+(j+1) % L] = 1 + _flatZflips2Xstab[(i+1) % L*L+(j+1) % + L, (2*i+2) % (2*L)*L+(j+1) % L] = 1 return _flatZflips2Xstab @property @@ -75,8 +82,8 @@ def flatXflips2Zerr(self): L = self.L _flatXflips2Zerr = np.zeros((2, 2*L**2), dtype=np.dtype('b')) for k in range(L): - _flatXflips2Zerr[0, (2*k+1)%(2*L)*L+(0 )%L] = 1 - _flatXflips2Zerr[1, (2*0 )%(2*L)*L+(k )%L] = 1 + _flatXflips2Zerr[0, (2*k+1) % (2*L)*L+(0) % L] = 1 + _flatXflips2Zerr[1, (2*0) % (2*L)*L+(k) % L] = 1 return _flatXflips2Zerr @property @@ -84,8 +91,8 @@ def flatZflips2Xerr(self): L = self.L _flatZflips2Xerr = np.zeros((2, 2*L**2), dtype=np.dtype('b')) for k in range(L): - _flatZflips2Xerr[0, (2*0+1)%(2*L)*L+(k )%L] = 1 - _flatZflips2Xerr[1, (2*k )%(2*L)*L+(0 )%L] = 1 + _flatZflips2Xerr[0, (2*0+1) % (2*L)*L+(k) % L] = 1 + _flatZflips2Xerr[1, (2*k) % (2*L)*L+(0) % L] = 1 return _flatZflips2Xerr def H(self, Z=True, X=False): @@ -110,36 +117,40 @@ def Zstabilizer(self): '''Return all measurements of the Z stabilizer with ``true`` marking non-trivial.''' stab = self._Zstab X = self.Xflips - stab[0:-1,0:-1] = X[0:-2:2,0:-1:] ^ X[1:-1:2,0:-1:] ^ X[2::2,0:-1:] ^ X[1:-1:2,1::] - stab[ -1,0:-1] = X[ -2 ,0:-1:] ^ X[ -1 ,0:-1:] ^ X[ 0 ,0:-1:] ^ X[ -1 ,1::] - stab[0:-1, -1] = X[0:-2:2, -1 ] ^ X[1:-1:2, -1 ] ^ X[2::2, -1 ] ^ X[1:-1:2, 0] - stab[ -1, -1] = X[ -2 , -1 ] ^ X[ -1 , -1 ] ^ X[ 0 , -1 ] ^ X[ -1 , 0] + stab[0:-1, 0:-1] = X[0:-2:2, 0:-1:] ^ X[1:-1:2, + 0:-1:] ^ X[2::2, 0:-1:] ^ X[1:-1:2, 1::] + stab[-1, 0:-1] = X[-2, 0:-1:] ^ X[-1, 0:-1:] ^ X[0, 0:-1:] ^ X[-1, 1::] + stab[0:-1, -1] = X[0:-2:2, -1] ^ X[1:- + 1:2, -1] ^ X[2::2, -1] ^ X[1:-1:2, 0] + stab[-1, -1] = X[-2, -1] ^ X[-1, -1] ^ X[0, -1] ^ X[-1, 0] return stab def Xstabilizer(self): '''Return all measurements of the X stabilizer with ``true`` marking non-trivial.''' stab = self._Xstab Z = self.Zflips - stab[1:,1:] = Z[1:-2:2,1:] ^ Z[2:-1:2,0:-1] ^ Z[3::2,1:] ^ Z[2:-1:2,1:] - stab[0 ,1:] = Z[ -1 ,1:] ^ Z[ 0 ,0:-1] ^ Z[ 1 ,1:] ^ Z[ 0 ,1:] - stab[1:,0 ] = Z[1:-2:2,0 ] ^ Z[2:-1:2, -1] ^ Z[3::2,0 ] ^ Z[2:-1:2,0 ] - stab[0 ,0 ] = Z[ -1 ,0 ] ^ Z[ 0 , -1] ^ Z[ 1 ,0 ] ^ Z[ 0 ,0 ] + stab[1:, 1:] = Z[1:-2:2, 1:] ^ Z[2:-1:2, + 0:-1] ^ Z[3::2, 1:] ^ Z[2:-1:2, 1:] + stab[0, 1:] = Z[-1, 1:] ^ Z[0, 0:-1] ^ Z[1, 1:] ^ Z[0, 1:] + stab[1:, 0] = Z[1:-2:2, 0] ^ Z[2:-1:2, -1] ^ Z[3::2, 0] ^ Z[2:-1:2, 0] + stab[0, 0] = Z[-1, 0] ^ Z[0, -1] ^ Z[1, 0] ^ Z[0, 0] return stab def _plot_flips(self, s, flips_yx, label): '''Given an array of yx coordiante plot qubit flips on subplot ``s``.''' - if not len(flips_yx): return + if not len(flips_yx): + return y, x = flips_yx x = x.astype(float) - x[y%2==0] += 0.5 + x[y % 2 == 0] += 0.5 x = np.concatenate([x, x-self.L, x]) y = np.concatenate([y/2., y/2., y/2.-self.L]) - s.plot(x, y,'o', ms=50/self.L, label=label) + s.plot(x, y, 'o', ms=50/self.L, label=label) def plot(self, legend=True, stabs=True): '''Plot the state of the system (including stabilizers).''' - f = plt.figure(figsize=(5,5)) - s = f.add_subplot(1,1,1) + f = plt.figure(figsize=(5, 5)) + s = f.add_subplot(1, 1, 1) self._plot_legend = legend self._plot_flips(s, self.Xflips.nonzero(), label='X') @@ -150,15 +161,15 @@ def plot(self, legend=True, stabs=True): y, x = self.Zstabilizer().nonzero() x = np.concatenate([x+0.5, x+0.5-self.L, x+0.5, x+0.5-self.L]) y = np.concatenate([y+0.5, y+0.5, y+0.5-self.L, y+0.5-self.L]) - s.plot(x,y,'s', mew=0, ms=190/self.L, label='plaq') + s.plot(x, y, 's', mew=0, ms=190/self.L, label='plaq') y, x = self.Xstabilizer().nonzero() s.plot(x, y, '+', mew=100/self.L, ms=200/self.L, label='star') - s.set_xticks(range(0,self.L)) - s.set_yticks(range(0,self.L)) - s.set_xlim(-0.6,self.L-0.4) - s.set_ylim(-0.6,self.L-0.4) + s.set_xticks(range(0, self.L)) + s.set_yticks(range(0, self.L)) + s.set_xlim(-0.6, self.L-0.4) + s.set_ylim(-0.6, self.L-0.4) s.invert_yaxis() for tic in s.xaxis.get_major_ticks(): tic.tick1On = tic.tick2On = False @@ -167,7 +178,8 @@ def plot(self, legend=True, stabs=True): tic.tick1On = tic.tick2On = False tic.label1On = tic.label2On = False s.grid() - if legend: s.legend() + if legend: + s.legend() return f, s def _wgraph(self, operator): @@ -176,6 +188,7 @@ def _wgraph(self, operator): nodes = zip(*self.Zstabilizer().nonzero()) elif operator == 'X': nodes = zip(*self.Xstabilizer().nonzero()) + def dist(node1, node2): dy = abs(node1[0]-node2[0]) dy = min(self.L-dy, dy) @@ -183,7 +196,7 @@ def dist(node1, node2): dx = min(self.L-dx, dx) return dx+dy g.add_weighted_edges_from((node1, node2, -dist(node1, node2)) - for node1, node2 in itertools.combinations(nodes, 2)) + for node1, node2 in itertools.combinations(nodes, 2)) return g def Zwgraph(self): @@ -205,17 +218,18 @@ def Zcorrections(self): ym, yM = 2*min(y1, y2), 2*max(y1, y2) if yM-ym > L: ym, yM = yM, ym+2*L - horizontal = yM if (x2-x1)*(y2-y1)<0 else ym + horizontal = yM if (x2-x1)*(y2-y1) < 0 else ym else: - horizontal = ym if (x2-x1)*(y2-y1)<0 else yM + horizontal = ym if (x2-x1)*(y2-y1) < 0 else yM xm, xM = min(x1, x2), max(x1, x2) if xM-xm > L/2: xm, xM = xM, xm+L vertical = xM else: vertical = xm - qubits.update((horizontal%(2*L), _%L) for _ in range(xm, xM)) - qubits.update(((_+1)%(2*L), vertical%L) for _ in range(ym, yM, 2)) + qubits.update((horizontal % (2*L), _ % L) for _ in range(xm, xM)) + qubits.update(((_+1) % (2*L), vertical % L) + for _ in range(ym, yM, 2)) return matches, qubits def Xcorrections(self): @@ -229,70 +243,74 @@ def Xcorrections(self): ym, yM = 2*min(y1, y2), 2*max(y1, y2) if yM-ym > L: ym, yM = yM, ym+2*L - horizontal = yM if (x2-x1)*(y2-y1)<0 else ym + horizontal = yM if (x2-x1)*(y2-y1) < 0 else ym else: - horizontal = ym if (x2-x1)*(y2-y1)<0 else yM + horizontal = ym if (x2-x1)*(y2-y1) < 0 else yM xm, xM = min(x1, x2), max(x1, x2) if xM-xm > L/2: xm, xM = xM, xm+L vertical = xM else: vertical = xm - qubits.update(((horizontal+1)%(2*L), (_+1)%L) for _ in range(xm, xM)) - qubits.update(((_+2)%(2*L), vertical%L) for _ in range(ym, yM, 2)) + qubits.update(((horizontal+1) % (2*L), (_+1) % L) + for _ in range(xm, xM)) + qubits.update(((_+2) % (2*L), vertical % L) + for _ in range(ym, yM, 2)) return matches, qubits def plot_corrections(self, s, plot_matches=False): '''Add to subplot ``s`` the corrections that have to be performed according to min. weight matching.''' def stitch_torus(y1, y2): - if abs(y1-y2)>L/2: - return (y1+L, y2-L) if y1 L/2: + return (y1+L, y2-L) if y1 < y2 else (y1-L, y2+L) return y1, y2 - def shorten(y1,y2): - if y1==y2: + + def shorten(y1, y2): + if y1 == y2: return y1, y2 - return (y1+0.15, y2-0.15) if y1= size: return + def do_normcenterstab(stabs, p): avg = (1-p)*2/3 avg_stab = 4*avg*(1-avg)**3 + 4*avg**3*(1-avg) var_stab = avg_stab-avg_stab**2 return (stabs - avg_stab)/var_stab**0.5 + def undo_normcenterstab(stabs, p): avg = (1-p)*2/3 avg_stab = 4*avg*(1-avg)**3 + 4*avg**3*(1-avg) var_stab = avg_stab-avg_stab**2 return stabs*var_stab**0.5 + avg_stab + def do_normcentererr(flips, p): avg = (1-p)*2/3 var = avg-avg**2 return (flips-avg)/var**0.5 + def undo_normcentererr(flips, p): avg = (1-p)*2/3 var = avg-avg**2 return flips*var**0.5 + avg + def smart_sample(H, stab, pred, sample, giveup): '''Sample `pred` until `H@sample==stab`. @@ -179,10 +207,11 @@ def smart_sample(H, stab, pred, sample, giveup): npsum = np.sum npdot = np.dot attempts = 1 - mismatch = stab!=npdot(H,sample)%2 + mismatch = stab != npdot(H, sample) % 2 while npany(mismatch) and attempts < giveup: - propagated = npany(H[mismatch,:], axis=0) - sample[propagated] = pred[propagated]>nprandomuniform(size=npsum(propagated)) - mismatch = stab!=npdot(H,sample)%2 + propagated = npany(H[mismatch, :], axis=0) + sample[propagated] = pred[propagated] > nprandomuniform( + size=npsum(propagated)) + mismatch = stab != npdot(H, sample) % 2 attempts += 1 return attempts diff --git a/train_network.py b/train_network.py index bc97c88..66a96e3 100644 --- a/train_network.py +++ b/train_network.py @@ -1,3 +1,7 @@ +import tqdm +import numpy as np +from codes import ToricCode +from neural import create_model, data_generator, do_normcenterstab, undo_normcentererr, smart_sample import argparse parser = argparse.ArgumentParser(description='Train a neural network to decode a code.', @@ -48,11 +52,6 @@ args = parser.parse_args() print(args) -from neural import create_model, data_generator, do_normcenterstab, undo_normcentererr, smart_sample -from codes import ToricCode -import numpy as np -import tqdm - if args.trainset: f = np.load(args.trainset) @@ -76,20 +75,21 @@ learning_rate=args.learningrate, normcentererr_p=args.prob if args.normcentererr else None, batchnorm=args.batchnorm - ) + ) L = args.dist code = ToricCode(L) out_dimZ = 2*L**2 * args.Zstab out_dimX = 2*L**2 * args.Xstab in_dim = L**2 * (args.Xstab+args.Zstab) -H = code.H(args.Zstab,args.Xstab) +H = code.H(args.Zstab, args.Xstab) if args.load: model.load_weights(args.load) if args.epochs: if args.trainset: - raise NotImplementedError("This is still using the OLD keras API. Update it!") + raise NotImplementedError( + "This is still using the OLD keras API. Update it!") x_train = [] y_train = [] if args.Zstab: @@ -104,7 +104,7 @@ nb_epoch=args.epochs, batch_size=args.batch, validation_data=(x_test, y_test) - ) + ) else: dat = data_generator(H, out_dimZ, out_dimX, in_dim, args.prob, args.batch, normcenterstab=args.normcenterstab, normcentererr=args.normcentererr) @@ -113,12 +113,13 @@ hist = model.fit(dat, steps_per_epoch=args.onthefly[0]//args.batch, epochs=args.epochs, validation_data=val, validation_steps=args.onthefly[1]//args.batch) # Add .weights.h5 extension for Keras 3.x compatibility - weights_filename = args.out + '.weights.h5' if not args.out.endswith('.weights.h5') else args.out + weights_filename = args.out + \ + '.weights.h5' if not args.out.endswith('.weights.h5') else args.out model.save_weights(weights_filename) with open(args.out+'.log', 'w') as f: f.write(str((hist.params, hist.history))) if args.eval: - E = ToricCode(L).E(args.Zstab,args.Xstab) + E = ToricCode(L).E(args.Zstab, args.Xstab) both = args.Zstab and args.Xstab if both: Hz = ToricCode(L).H(True, False) @@ -131,7 +132,8 @@ size = len(y_test) else: size = args.onthefly[1] - stabflipgen = data_generator(H, out_dimZ, out_dimX, in_dim, args.prob, batch_size=1, size=size) + stabflipgen = data_generator( + H, out_dimZ, out_dimX, in_dim, args.prob, batch_size=1, size=size) full_log = np.zeros((size, E.shape[0]+args.Zstab+args.Xstab), dtype=int) for i, (stab, flips) in tqdm.tqdm(enumerate(stabflipgen), total=size): if args.normcenterstab: @@ -143,28 +145,30 @@ pred = undo_normcentererr(pred, args.prob) stab = stab.ravel() flips = flips.ravel() - sample = pred>np.random.uniform(size=outlen) + sample = pred > np.random.uniform(size=outlen) if both: - attemptsZ = smart_sample(Hz, stab[:inlen//2], pred[:outlen//2], sample[:outlen//2], args.giveup) - attemptsX = smart_sample(Hx, stab[inlen//2:], pred[outlen//2:], sample[outlen//2:], args.giveup) + attemptsZ = smart_sample( + Hz, stab[:inlen//2], pred[:outlen//2], sample[:outlen//2], args.giveup) + attemptsX = smart_sample( + Hx, stab[inlen//2:], pred[outlen//2:], sample[outlen//2:], args.giveup) else: attempts = smart_sample(H, stab, pred, sample, args.giveup) - errors = E.dot((sample+flips)%2)%2 - if np.any(errors) or np.any(stab!=H.dot(sample)%2): + errors = E.dot((sample+flips) % 2) % 2 + if np.any(errors) or np.any(stab != H.dot(sample) % 2): c += 1 if both: cz += np.any(errors[:len(errors)//2]) cx += np.any(errors[len(errors)//2:]) if both: - full_log[i,:-2] = errors - full_log[i,-2] = attemptsZ - full_log[i,-1] = attemptsX + full_log[i, :-2] = errors + full_log[i, -2] = attemptsZ + full_log[i, -1] = attemptsX else: - full_log[i,:-1] = errors - full_log[i,-1] = attempts + full_log[i, :-1] = errors + full_log[i, -1] = attempts with open(args.out+'.eval', 'w') as f: if both: - f.write(str(((1-c/size),(1-cz/size),(1-cx/size)))) + f.write(str(((1-c/size), (1-cz/size), (1-cx/size)))) else: f.write(str(((1-c/size),))) np.savetxt(args.out+'.eval.log', full_log, fmt='%d')