Skip to content

Commit 4707d3e

Browse files
fix test (#2528)
1 parent 6f75a45 commit 4707d3e

File tree

3 files changed

+33
-12
lines changed

3 files changed

+33
-12
lines changed

Diff for: aeon/clustering/deep_learning/_ae_fcn.py

+6-3
Original file line numberDiff line numberDiff line change
@@ -317,6 +317,7 @@ def _fit(self, X):
317317
outputs=X,
318318
batch_size=mini_batch_size,
319319
epochs=self.n_epochs,
320+
verbose=self.verbose,
320321
)
321322

322323
try:
@@ -345,6 +346,7 @@ def _fit_multi_rec_model(
345346
outputs,
346347
batch_size,
347348
epochs,
349+
verbose,
348350
):
349351
import tensorflow as tf
350352

@@ -451,9 +453,10 @@ def loss(y_true, y_pred):
451453
epoch_loss /= num_batches
452454
history["loss"].append(epoch_loss)
453455

454-
sys.stdout.write(
455-
"Training loss at epoch %d: %.4f\n" % (epoch, float(epoch_loss))
456-
)
456+
if verbose:
457+
sys.stdout.write(
458+
"Training loss at epoch %d: %.4f\n" % (epoch, float(epoch_loss))
459+
)
457460

458461
for callback in self.callbacks_:
459462
callback.on_epoch_end(epoch, {"loss": float(epoch_loss)})

Diff for: aeon/clustering/deep_learning/_ae_resnet.py

+6-3
Original file line numberDiff line numberDiff line change
@@ -329,6 +329,7 @@ def _fit(self, X):
329329
outputs=X,
330330
batch_size=mini_batch_size,
331331
epochs=self.n_epochs,
332+
verbose=self.verbose,
332333
)
333334

334335
try:
@@ -359,6 +360,7 @@ def _fit_multi_rec_model(
359360
outputs,
360361
batch_size,
361362
epochs,
363+
verbose,
362364
):
363365
import tensorflow as tf
364366

@@ -463,9 +465,10 @@ def loss(y_true, y_pred):
463465
epoch_loss /= num_batches
464466
history["loss"].append(epoch_loss)
465467

466-
sys.stdout.write(
467-
"Training loss at epoch %d: %.4f\n" % (epoch, float(epoch_loss))
468-
)
468+
if verbose:
469+
sys.stdout.write(
470+
"Training loss at epoch %d: %.4f\n" % (epoch, float(epoch_loss))
471+
)
469472

470473
for callback in self.callbacks_:
471474
callback.on_epoch_end(epoch, {"loss": float(epoch_loss)})

Diff for: aeon/testing/estimator_checking/_yield_clustering_checks.py

+21-6
Original file line numberDiff line numberDiff line change
@@ -77,18 +77,33 @@ def check_clustering_random_state_deep_learning(estimator, datatype):
7777
deep_clr1 = _clone_estimator(estimator, random_state=random_state)
7878
deep_clr1.fit(FULL_TEST_DATA_DICT[datatype]["train"][0])
7979

80-
layers1 = deep_clr1.training_model_.layers[1:]
80+
encoder_layers1 = deep_clr1.training_model_.layers[1].layers[1:]
81+
decoder_layers1 = deep_clr1.training_model_.layers[2].layers[1:]
8182

8283
deep_clr2 = _clone_estimator(estimator, random_state=random_state)
8384
deep_clr2.fit(FULL_TEST_DATA_DICT[datatype]["train"][0])
8485

85-
layers2 = deep_clr2.training_model_.layers[1:]
86+
encoder_layers2 = deep_clr2.training_model_.layers[1].layers[1:]
87+
decoder_layers2 = deep_clr2.training_model_.layers[2].layers[1:]
8688

87-
assert len(layers1) == len(layers2)
89+
assert len(encoder_layers1) == len(encoder_layers2)
90+
assert len(decoder_layers1) == len(decoder_layers2)
8891

89-
for i in range(len(layers1)):
90-
weights1 = layers1[i].get_weights()
91-
weights2 = layers2[i].get_weights()
92+
for i in range(len(encoder_layers1)):
93+
weights1 = encoder_layers1[i].get_weights()
94+
weights2 = encoder_layers2[i].get_weights()
95+
96+
assert len(weights1) == len(weights2)
97+
98+
for j in range(len(weights1)):
99+
_weight1 = np.asarray(weights1[j])
100+
_weight2 = np.asarray(weights2[j])
101+
102+
np.testing.assert_almost_equal(_weight1, _weight2, 4)
103+
104+
for i in range(len(decoder_layers1)):
105+
weights1 = decoder_layers1[i].get_weights()
106+
weights2 = decoder_layers2[i].get_weights()
92107

93108
assert len(weights1) == len(weights2)
94109

0 commit comments

Comments
 (0)