diff --git a/runtime/onert/api/python/package/experimental/train/losses/registry.py b/runtime/onert/api/python/package/experimental/train/losses/registry.py index 9e558101fa5..1255493bf3f 100644 --- a/runtime/onert/api/python/package/experimental/train/losses/registry.py +++ b/runtime/onert/api/python/package/experimental/train/losses/registry.py @@ -4,11 +4,11 @@ class LossRegistry: """ - Registry for creating losses by name. + Registry for creating and mapping losses by name or instance. """ _losses = { "categorical_crossentropy": CategoricalCrossentropy, - "mean_squred_error": MeanSquaredError + "mean_squared_error": MeanSquaredError } @staticmethod @@ -23,3 +23,27 @@ def create_loss(name): if name not in LossRegistry._losses: raise ValueError(f"Unknown Loss: {name}. Custom loss is not supported yet") return LossRegistry._losses[name]() + + @staticmethod + def map_loss_function_to_enum(loss_instance): + """ + Maps a LossFunction instance to the appropriate enum value. + Args: + loss_instance (BaseLoss): An instance of a loss function. + Returns: + loss_type: Corresponding enum value for the loss function. + Raises: + TypeError: If the loss_instance is not a recognized LossFunction type. + """ + # Loss to Enum mapping + loss_to_enum = { + CategoricalCrossentropy: "CATEGORICAL_CROSSENTROPY", + MeanSquaredError: "MEAN_SQUARED_ERROR", + } + for loss_class, enum_value in loss_to_enum.items(): + if isinstance(loss_instance, loss_class): + return enum_value + raise TypeError( + f"Unsupported loss function type: {type(loss_instance).__name__}. " + f"Supported types are: {list(loss_to_enum.keys())}." + )