-
Notifications
You must be signed in to change notification settings - Fork 159
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
[onert/python] Introduce LossRegistry (#14574)
This commit introduces LossRegistries that can create loss functions by names. ONE-DCO-1.0-Signed-off-by: ragmani <[email protected]>
- Loading branch information
Showing
1 changed file
with
49 additions
and
0 deletions.
There are no files selected for viewing
49 changes: 49 additions & 0 deletions
49
runtime/onert/api/python/package/experimental/train/losses/registry.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,49 @@ | ||
from onert.native.libnnfw_api_pybind import loss as loss_type | ||
from .cce import CategoricalCrossentropy | ||
from .mse import MeanSquaredError | ||
|
||
|
||
class LossRegistry: | ||
""" | ||
Registry for creating and mapping losses by name or instance. | ||
""" | ||
_losses = { | ||
"categorical_crossentropy": CategoricalCrossentropy, | ||
"mean_squared_error": MeanSquaredError | ||
} | ||
|
||
@staticmethod | ||
def create_loss(name): | ||
""" | ||
Create a loss instance by name. | ||
Args: | ||
name (str): Name of the loss. | ||
Returns: | ||
BaseLoss: Loss instance. | ||
""" | ||
if name not in LossRegistry._losses: | ||
raise ValueError(f"Unknown Loss: {name}. Custom loss is not supported yet") | ||
return LossRegistry._losses[name]() | ||
|
||
@staticmethod | ||
def map_loss_function_to_enum(loss_instance): | ||
""" | ||
Maps a LossFunction instance to the appropriate enum value. | ||
Args: | ||
loss_instance (BaseLoss): An instance of a loss function. | ||
Returns: | ||
loss_type: Corresponding enum value for the loss function. | ||
Raises: | ||
TypeError: If the loss_instance is not a recognized LossFunction type. | ||
""" | ||
# Loss to Enum mapping | ||
loss_to_enum = { | ||
CategoricalCrossentropy: loss_type.CATEGORICAL_CROSSENTROPY, | ||
MeanSquaredError: loss_type.MEAN_SQUARED_ERROR | ||
} | ||
for loss_class, enum_value in loss_to_enum.items(): | ||
if isinstance(loss_instance, loss_class): | ||
return enum_value | ||
raise TypeError( | ||
f"Unsupported loss function type: {type(loss_instance).__name__}. " | ||
f"Supported types are: {list(loss_to_enum.keys())}.") |