diff --git a/src/anemoi/utils/config.py b/src/anemoi/utils/config.py index 9c89b4b..c547672 100644 --- a/src/anemoi/utils/config.py +++ b/src/anemoi/utils/config.py @@ -50,14 +50,14 @@ def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) for k, v in self.items(): - if isinstance(v, dict): + if isinstance(v, dict) or is_omegaconf_dict(v): self[k] = DotDict(v) - if isinstance(v, list): - self[k] = [DotDict(i) if isinstance(i, dict) else i for i in v] + if isinstance(v, list) or is_omegaconf_list(v): + self[k] = [DotDict(i) if isinstance(i, dict) or is_omegaconf_dict(i) else i for i in v] if isinstance(v, tuple): - self[k] = [DotDict(i) if isinstance(i, dict) else i for i in v] + self[k] = [DotDict(i) if isinstance(i, dict) or is_omegaconf_dict(i) else i for i in v] @classmethod def from_file(cls, path: str): @@ -106,6 +106,24 @@ def __repr__(self) -> str: return f"DotDict({super().__repr__()})" +def is_omegaconf_dict(value) -> bool: + try: + from omegaconf import DictConfig + + return isinstance(value, DictConfig) + except ImportError: + return False + + +def is_omegaconf_list(value) -> bool: + try: + from omegaconf import ListConfig + + return isinstance(value, ListConfig) + except ImportError: + return False + + CONFIG = {} CHECKED = {} CONFIG_LOCK = threading.RLock()