From ea8a975191e164ebe9128ea98e8973c927645f41 Mon Sep 17 00:00:00 2001 From: Raj Sinha Date: Mon, 22 Jul 2024 15:36:13 -0700 Subject: [PATCH] Switch to weighted metrics in the compile step of the supervised model. PiperOrigin-RevId: 654919610 --- spade_anomaly_detection/supervised_model.py | 14 +++++++++----- 1 file changed, 9 insertions(+), 5 deletions(-) diff --git a/spade_anomaly_detection/supervised_model.py b/spade_anomaly_detection/supervised_model.py index 5592a50..71b315e 100644 --- a/spade_anomaly_detection/supervised_model.py +++ b/spade_anomaly_detection/supervised_model.py @@ -64,12 +64,16 @@ def save(self, save_location: str) -> None: save_location: String denoting a Google Cloud Storage location, or local disk path. Note that local assets will be deleted when the VM running this container is shutdown at the end of the training job. + + Raises: + ValueError: If the supervised model was not initialized. """ - if self.supervised_model is not None: - self.supervised_model.save(save_location) - logging.info('Saved model assets to %s', save_location) + if self.supervised_model is None: + raise ValueError('Supervised model was not initialized.') else: - logging.warning('No model to save.') + self.supervised_model.save(save_location) # pytype: disable=attribute-error + + logging.info('Saved model assets to %s', save_location) @dataclasses.dataclass @@ -133,7 +137,7 @@ def __init__( **dataclasses.asdict(self.supervised_parameters) ) self.supervised_model.compile( - metrics=[ + weighted_metrics=[ tf.keras.metrics.AUC(name='Supervised_Model_AUC'), tf.keras.metrics.Precision( thresholds=[0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9],