Skip to content

Commit d3c02b0

Browse files
authored
Mnist_binarize
1 parent aa4a7bc commit d3c02b0

File tree

1 file changed

+9
-2
lines changed

1 file changed

+9
-2
lines changed

dpcl_classifier/classifier_numba.py

+9-2
Original file line numberDiff line numberDiff line change
@@ -170,13 +170,20 @@ def run_experiment(states=10000, epochs=1, clauses=150, runs=100, pl=0.6, pu=0.8
170170
X_train = np.where(X_train.reshape((X_train.shape[0], 28 * 28)) > 75, 1, 0)
171171
X_test = np.where(X_test.reshape((X_test.shape[0], 28 * 28)) > 75, 1, 0)
172172

173-
mask_train = np.isin(Y_train, [0, 1])
173+
digit1 = 1
174+
digit2 = 8
175+
176+
mask_train = np.isin(Y_train, [digit1, digit2])
174177
X_train_filtered = X_train[mask_train]
175178
Y_train_filtered = Y_train[mask_train]
179+
Y_train_filtered[Y_train_filtered == digit1] = 0
180+
Y_train_filtered[Y_train_filtered == digit2] = 1
176181

177-
mask_test = np.isin(Y_test, [0, 1])
182+
mask_test = np.isin(Y_test, [digit1, digit2])
178183
X_test_filtered = X_test[mask_test]
179184
Y_test_filtered = Y_test[mask_test]
185+
Y_test_filtered[Y_test_filtered == digit1] = 0
186+
Y_test_filtered[Y_test_filtered == digit2] = 1
180187

181188
Accuracies = []
182189

0 commit comments

Comments
 (0)