Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
90 changes: 71 additions & 19 deletions src/ctapipe/reco/reconstructor.py
Original file line number Diff line number Diff line change
@@ -1,15 +1,20 @@
import weakref
from abc import abstractmethod
from enum import Flag, auto

import astropy.units as u
import joblib
import numpy as np
from astropy.coordinates import AltAz, SkyCoord
from traitlets.config import Config

from ctapipe.containers import ArrayEventContainer, TelescopeImpactParameterContainer
from ctapipe.core import Provenance, QualityQuery, TelescopeComponent
from ctapipe.core.traits import Integer, List
from ctapipe.core import (
Provenance,
QualityQuery,
TelescopeComponent,
ToolConfigurationError,
)
from ctapipe.core.traits import Integer, List, classes_with_traits

from ..coordinates import shower_impact_distance

Expand Down Expand Up @@ -106,39 +111,90 @@
"""

@classmethod
def read(cls, path, parent=None, subarray=None, **kwargs):

Check failure on line 114 in src/ctapipe/reco/reconstructor.py

View check run for this annotation

CTAO-DPPS-SonarQube / ctapipe Sonarqube Results

src/ctapipe/reco/reconstructor.py#L114

Refactor this function to reduce its Cognitive Complexity from 17 to the 15 allowed.
"""Read a joblib-pickled reconstructor from ``path``
"""
Read a dictionary from ``path`` containing all necessary information
to construct an instance of a reconstructor (subclass).

Parameters
----------
path : str or pathlib.Path
Path to a Reconstructor instance pickled using joblib
Path to a dictionary containing all information about a
``Reconstructor`` (subclass).
parent : None or Component or Tool
Attach a new parent to the loaded class, this will properly
Attach a new parent to the loaded class.
subarray : SubarrayDescription
Attach a new subarray to the loaded reconstructor
A warning will be raised if the telescope types of the
subarray stored in the pickled class do not match with the
provided subarray.

**kwargs are set on the loaded instance
**kwargs are set on the constructed instance

Returns
-------
Reconstructor instance loaded from file
Reconstructor instance
"""
with open(path, "rb") as f:
instance = joblib.load(f)
dictionary = joblib.load(f)

meta = dictionary.pop("meta")
name = dictionary.pop("name")
config = Config(dictionary.pop("config"))
cls_attributes = dictionary.pop("cls_attributes")

if not isinstance(instance, cls):
if name not in [c.__name__ for c in classes_with_traits(cls)]:
raise TypeError(
f"{path} did not contain an instance of {cls}, got {instance}"
f"{path} does not contain information about {cls.__name__} or "
f"one of its subclasses, but instead about {name}."
)

# first deal with kwargs that would need "special" treatmet, parent and subarray
if parent is not None:
instance.parent = weakref.proxy(parent)
instance.log = parent.log.getChild(instance.__class__.__name__)
if name in parent.config.keys():
# Some configuration options should not be changed on a trained model.
forbidden_changes = [
"model_cls",
"norm_cls",
"sign_cls",
"model_config",
"norm_config",
"sign_config",
"log_target",
"features",
]
for trait_name in forbidden_changes:
if trait_name in parent.config[name].keys():
raise ToolConfigurationError(
f"{name}.{trait_name} can not be changed when "
f"a {name} is loaded."
)

changed_traits = parent.config[name]
# add loaded config of reconstructor to current config
parent.config.update(config)
# re-add changes to config done when the tool is called
parent.config[name].update(changed_traits)
else:
parent.config.update(config)

instance = Reconstructor.from_name(
name=name,
parent=parent,
subarray=dictionary["subarray"],
models=dictionary["models"],
)
else:
instance = Reconstructor.from_name(
name=name,
config=config,
subarray=dictionary["subarray"],
models=dictionary["models"],
)

# set class attributes not handled by __init__,
# e.g. the unit defined during SKLearnReconstructor.fit()
for attr, value in cls_attributes.items():
setattr(instance, attr, value)

if subarray is not None:
if instance.subarray.telescope_types != subarray.telescope_types:
Expand All @@ -147,11 +203,7 @@
)
instance.subarray = subarray

for attr, value in kwargs.items():
setattr(instance, attr, value)

# FIXME: we currently don't store metadata in the joblib / pickle files, see #2603
Provenance().add_input_file(path, role="reconstructor", add_meta=False)
Provenance().add_input_file(path, role="reconstructor", reference_meta=meta)
return instance


Expand Down
Loading
Loading