Skip to content

Commit

Permalink
Add mapping function
Browse files Browse the repository at this point in the history
  • Loading branch information
ragmani committed Jan 22, 2025
1 parent ff4393b commit fb36281
Showing 1 changed file with 26 additions and 2 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -4,11 +4,11 @@

class LossRegistry:
"""
Registry for creating losses by name.
Registry for creating and mapping losses by name or instance.
"""
_losses = {
"categorical_crossentropy": CategoricalCrossentropy,
"mean_squred_error": MeanSquaredError
"mean_squared_error": MeanSquaredError
}

@staticmethod
Expand All @@ -23,3 +23,27 @@ def create_loss(name):
if name not in LossRegistry._losses:
raise ValueError(f"Unknown Loss: {name}. Custom loss is not supported yet")
return LossRegistry._losses[name]()

@staticmethod
def map_loss_function_to_enum(loss_instance):
"""
Maps a LossFunction instance to the appropriate enum value.
Args:
loss_instance (BaseLoss): An instance of a loss function.
Returns:
loss_type: Corresponding enum value for the loss function.
Raises:
TypeError: If the loss_instance is not a recognized LossFunction type.
"""
# Loss to Enum mapping
loss_to_enum = {
CategoricalCrossentropy: "CATEGORICAL_CROSSENTROPY",
MeanSquaredError: "MEAN_SQUARED_ERROR",
}
for loss_class, enum_value in loss_to_enum.items():
if isinstance(loss_instance, loss_class):
return enum_value
raise TypeError(
f"Unsupported loss function type: {type(loss_instance).__name__}. "
f"Supported types are: {list(loss_to_enum.keys())}."
)

0 comments on commit fb36281

Please sign in to comment.