diff --git a/src/ctapipe/reco/reconstructor.py b/src/ctapipe/reco/reconstructor.py index d09c7d4b362..ac662e790f3 100644 --- a/src/ctapipe/reco/reconstructor.py +++ b/src/ctapipe/reco/reconstructor.py @@ -1,4 +1,3 @@ -import weakref from abc import abstractmethod from enum import Flag, auto @@ -6,10 +5,16 @@ import joblib import numpy as np from astropy.coordinates import AltAz, SkyCoord +from traitlets.config import Config from ctapipe.containers import ArrayEventContainer, TelescopeImpactParameterContainer -from ctapipe.core import Provenance, QualityQuery, TelescopeComponent -from ctapipe.core.traits import Integer, List +from ctapipe.core import ( + Provenance, + QualityQuery, + TelescopeComponent, + ToolConfigurationError, +) +from ctapipe.core.traits import Integer, List, classes_with_traits from ..coordinates import shower_impact_distance @@ -107,38 +112,89 @@ def __call__(self, event: ArrayEventContainer): @classmethod def read(cls, path, parent=None, subarray=None, **kwargs): - """Read a joblib-pickled reconstructor from ``path`` + """ + Read a dictionary from ``path`` containing all necessary information + to construct an instance of a reconstructor (subclass). Parameters ---------- path : str or pathlib.Path - Path to a Reconstructor instance pickled using joblib + Path to a dictionary containing all information about a + ``Reconstructor`` (subclass). parent : None or Component or Tool - Attach a new parent to the loaded class, this will properly + Attach a new parent to the loaded class. subarray : SubarrayDescription Attach a new subarray to the loaded reconstructor A warning will be raised if the telescope types of the subarray stored in the pickled class do not match with the provided subarray. - **kwargs are set on the loaded instance + **kwargs are set on the constructed instance Returns ------- - Reconstructor instance loaded from file + Reconstructor instance """ with open(path, "rb") as f: - instance = joblib.load(f) + dictionary = joblib.load(f) + + meta = dictionary.pop("meta") + name = dictionary.pop("name") + config = Config(dictionary.pop("config")) + cls_attributes = dictionary.pop("cls_attributes") - if not isinstance(instance, cls): + if name not in [c.__name__ for c in classes_with_traits(cls)]: raise TypeError( - f"{path} did not contain an instance of {cls}, got {instance}" + f"{path} does not contain information about {cls.__name__} or " + f"one of its subclasses, but instead about {name}." ) - # first deal with kwargs that would need "special" treatmet, parent and subarray if parent is not None: - instance.parent = weakref.proxy(parent) - instance.log = parent.log.getChild(instance.__class__.__name__) + if name in parent.config.keys(): + # Some configuration options should not be changed on a trained model. + forbidden_changes = [ + "model_cls", + "norm_cls", + "sign_cls", + "model_config", + "norm_config", + "sign_config", + "log_target", + "features", + ] + for trait_name in forbidden_changes: + if trait_name in parent.config[name].keys(): + raise ToolConfigurationError( + f"{name}.{trait_name} can not be changed when " + f"a {name} is loaded." + ) + + changed_traits = parent.config[name] + # add loaded config of reconstructor to current config + parent.config.update(config) + # re-add changes to config done when the tool is called + parent.config[name].update(changed_traits) + else: + parent.config.update(config) + + instance = Reconstructor.from_name( + name=name, + parent=parent, + subarray=dictionary["subarray"], + models=dictionary["models"], + ) + else: + instance = Reconstructor.from_name( + name=name, + config=config, + subarray=dictionary["subarray"], + models=dictionary["models"], + ) + + # set class attributes not handled by __init__, + # e.g. the unit defined during SKLearnReconstructor.fit() + for attr, value in cls_attributes.items(): + setattr(instance, attr, value) if subarray is not None: if instance.subarray.telescope_types != subarray.telescope_types: @@ -147,11 +203,7 @@ def read(cls, path, parent=None, subarray=None, **kwargs): ) instance.subarray = subarray - for attr, value in kwargs.items(): - setattr(instance, attr, value) - - # FIXME: we currently don't store metadata in the joblib / pickle files, see #2603 - Provenance().add_input_file(path, role="reconstructor", add_meta=False) + Provenance().add_input_file(path, role="reconstructor", reference_meta=meta) return instance diff --git a/src/ctapipe/reco/sklearn.py b/src/ctapipe/reco/sklearn.py index acc77e964ad..0e711ed1a44 100644 --- a/src/ctapipe/reco/sklearn.py +++ b/src/ctapipe/reco/sklearn.py @@ -1,6 +1,7 @@ """ Component Wrappers around sklearn models """ + import pathlib from abc import abstractmethod from collections import defaultdict @@ -123,9 +124,7 @@ class SKLearnReconstructor(Reconstructor): help="If given, load serialized model from this path.", ).tag(config=True) - def __init__( - self, subarray=None, atmosphere_profile=None, models=None, n_jobs=None, **kwargs - ): + def __init__(self, subarray=None, atmosphere_profile=None, models=None, **kwargs): # Run the Component __init__ first to handle the configuration # and make `self.load_path` available Component.__init__(self, **kwargs) @@ -188,7 +187,7 @@ def __call__(self, event: ArrayEventContainer) -> None: """ @abstractmethod - def predict_table(self, key, table: Table) -> Table: + def predict_table(self, key, table: Table) -> dict[ReconstructionProperty, Table]: """ Predict on a table of events. @@ -207,16 +206,6 @@ def predict_table(self, key, table: Table) -> Table: container definition(s) """ - def write(self, path, overwrite=False): - path = pathlib.Path(path) - - if path.exists() and not overwrite: - raise OSError(f"Path {path} exists and overwrite=False") - - with path.open("wb") as f: - Provenance().add_output_file(path, role="ml-models") - joblib.dump(self, f, compress=True) - @lazyproperty def instrument_table(self): return QTable(self.subarray.to_table("joined")) @@ -311,6 +300,52 @@ def _table_to_y(self, table, mask=None): return np.log(y) return y + def write(self, path, meta=None, overwrite=False): + """ + Save a dictionary using joblib-pickle, which contains all + information/settings about an instance of a + ``SKLearnRegressionReconstructor`` (subclass). + + Parameters + ---------- + path : str or pathlib.Path + Path to which the dictionary will be saved. + meta : dict + Metadata + overwrite : Bool + Whether to overwrite, if ``path`` already exists. + """ + path = pathlib.Path(path) + + if path.exists() and not overwrite: + raise OSError(f"Path {path} exists and overwrite=False") + + dictionary = { + "name": self.__class__.__name__, + "subarray": self.subarray, + "models": self._models, + "config": { + self.__class__.__name__: { + "prefix": self.prefix, + "log_target": self.log_target, + "model_cls": self.model_cls, + "model_config": self.model_config, + "features": self.features, + "stereo_combiner_cls": self.stereo_combiner_cls, + "FeatureGenerator": {"features": self.feature_generator.features}, + "QualityQuery": { + "quality_criteria": self.quality_query.quality_criteria + }, + self.stereo_combiner_cls: {"weights": self.stereo_combiner.weights}, + } + }, + "cls_attributes": {"unit": self.unit}, + "meta": meta, + } + with path.open("wb") as f: + Provenance().add_output_file(path, role="ml-reconstructor") + joblib.dump(dictionary, f, compress=True) + class SKLearnClassificationReconstructor(SKLearnReconstructor): """Base class for classification tasks.""" @@ -386,6 +421,53 @@ def _predict_score(self, key, table): def _get_positive_index(self, key): return np.nonzero(self._models[key].classes_ == self.positive_class)[0][0] + def write(self, path, meta=None, overwrite=False): + """ + Save a dictionary using joblib-pickle, which contains all + information/settings about an instance of a + ``SKLearnClassificationReconstructor`` (subclass). + + Parameters + ---------- + path : str or pathlib.Path + Path to which the dictionary will be saved. + meta : dict + Metadata + overwrite : Bool + Whether to overwrite, if ``path`` already exists. + """ + path = pathlib.Path(path) + + if path.exists() and not overwrite: + raise OSError(f"Path {path} exists and overwrite=False") + + dictionary = { + "name": self.__class__.__name__, + "subarray": self.subarray, + "models": self._models, + "config": { + self.__class__.__name__: { + "prefix": self.prefix, + "invalid_class": self.invalid_class, + "positive_class": self.positive_class, + "model_cls": self.model_cls, + "model_config": self.model_config, + "features": self.features, + "stereo_combiner_cls": self.stereo_combiner_cls, + "FeatureGenerator": {"features": self.feature_generator.features}, + "QualityQuery": { + "quality_criteria": self.quality_query.quality_criteria + }, + self.stereo_combiner_cls: {"weights": self.stereo_combiner.weights}, + } + }, + "cls_attributes": {"unit": self.unit}, + "meta": meta, + } + with path.open("wb") as f: + Provenance().add_output_file(path, role="ml-reconstructor") + joblib.dump(dictionary, f, compress=True) + class EnergyRegressor(SKLearnRegressionReconstructor): """ @@ -452,14 +534,13 @@ class ParticleClassifier(SKLearnClassificationReconstructor): """Predict dl2 particle classification.""" target = "true_shower_primary_id" + property = ReconstructionProperty.PARTICLE_TYPE positive_class = traits.Integer( default_value=0, help="Particle id (in simtel system) of the positive class. Default is 0 for gammas.", ).tag(config=True) - property = ReconstructionProperty.PARTICLE_TYPE - def __call__(self, event: ArrayEventContainer) -> None: for tel_id in event.trigger.tels_with_trigger: table = collect_features(event, tel_id, self.instrument_table) @@ -518,6 +599,7 @@ class DispReconstructor(Reconstructor): """ target = "true_disp" + property = ReconstructionProperty.GEOMETRY prefix = traits.Unicode( default_value="disp", @@ -599,7 +681,7 @@ def __init__(self, subarray=None, atmosphere_profile=None, models=None, **kwargs self.stereo_combiner = StereoCombiner.from_name( self.stereo_combiner_cls, prefix=self.prefix, - property=ReconstructionProperty.GEOMETRY, + property=self.property, parent=self, ) else: @@ -614,6 +696,9 @@ def __init__(self, subarray=None, atmosphere_profile=None, models=None, **kwargs self.__dict__.update(loaded.__dict__) self.subarray = subarray + if self.prefix is None: + self.prefix = "disp" + def _new_models(self): norm_cfg = self.norm_config sign_cfg = self.sign_config @@ -654,33 +739,6 @@ def fit(self, key, table): self._models[key][0].fit(X, norm) self._models[key][1].fit(X, sign) - def write(self, path, overwrite=False): - path = pathlib.Path(path) - - if path.exists() and not overwrite: - raise OSError(f"Path {path} exists and overwrite=False") - - with path.open("wb") as f: - Provenance().add_output_file(path, role="ml-models") - joblib.dump(self, f, compress=True) - - @classmethod - def read(cls, path, **kwargs): - with open(path, "rb") as f: - instance = joblib.load(f) - - for attr, value in kwargs.items(): - setattr(instance, attr, value) - - if not isinstance(instance, cls): - raise TypeError( - f"{path} did not contain an instance of {cls}, got {instance}" - ) - - # FIXME: we currently don't store metadata in the joblib / pickle files, see #2603 - Provenance().add_input_file(path, role="ml-models", add_meta=False) - return instance - @lazyproperty def instrument_table(self): return self.subarray.to_table("joined") @@ -870,6 +928,53 @@ def _set_n_jobs(self, n_jobs): disp.n_jobs = n_jobs.new sign.n_jobs = n_jobs.new + def write(self, path, meta=None, overwrite=False): + """ + Save a dictionary using joblib-pickle, which contains all + information/settings about an instance of a ``DispReconstructor`` . + + Parameters + ---------- + path : str or pathlib.Path + Path to which the dictionary will be saved. + meta : dict + Metadata + overwrite : Bool + Whether to overwrite, if ``path`` already exists. + """ + path = pathlib.Path(path) + + if path.exists() and not overwrite: + raise OSError(f"Path {path} exists and overwrite=False") + + dictionary = { + "name": self.__class__.__name__, + "subarray": self.subarray, + "models": self._models, + "config": { + self.__class__.__name__: { + "prefix": self.prefix, + "log_target": self.log_target, + "norm_cls": self.norm_cls, + "sign_cls": self.sign_cls, + "norm_config": self.norm_config, + "sign_config": self.sign_config, + "features": self.features, + "stereo_combiner_cls": self.stereo_combiner_cls, + "FeatureGenerator": {"features": self.feature_generator.features}, + "QualityQuery": { + "quality_criteria": self.quality_query.quality_criteria + }, + self.stereo_combiner_cls: {"weights": self.stereo_combiner.weights}, + } + }, + "cls_attributes": {"unit": self.unit}, + "meta": meta, + } + with path.open("wb") as f: + Provenance().add_output_file(path, role="ml-reconstructor") + joblib.dump(dictionary, f, compress=True) + class CrossValidator(Component): """Class to train sklearn based reconstructors in a cross validation.""" diff --git a/src/ctapipe/tools/train_disp_reconstructor.py b/src/ctapipe/tools/train_disp_reconstructor.py index a125ff753cf..8435735db7f 100644 --- a/src/ctapipe/tools/train_disp_reconstructor.py +++ b/src/ctapipe/tools/train_disp_reconstructor.py @@ -186,7 +186,6 @@ def finish(self): Write-out trained models and cross-validation results. """ self.log.info("Writing output") - self.models.n_jobs = None self.models.write(self.output_path, overwrite=self.overwrite) self.loader.close() self.cross_validate.close() diff --git a/src/ctapipe/tools/train_energy_regressor.py b/src/ctapipe/tools/train_energy_regressor.py index 408e1ed7d51..87773279fe2 100644 --- a/src/ctapipe/tools/train_energy_regressor.py +++ b/src/ctapipe/tools/train_energy_regressor.py @@ -1,6 +1,7 @@ """ Tool for training the EnergyRegressor """ + import numpy as np from ctapipe.core import Tool @@ -141,7 +142,6 @@ def finish(self): Write-out trained models and cross-validation results. """ self.log.info("Writing output") - self.regressor.n_jobs = None self.regressor.write(self.output_path, overwrite=self.overwrite) self.loader.close() self.cross_validate.close() diff --git a/src/ctapipe/tools/train_particle_classifier.py b/src/ctapipe/tools/train_particle_classifier.py index 7b235ab0bae..c74278df26f 100644 --- a/src/ctapipe/tools/train_particle_classifier.py +++ b/src/ctapipe/tools/train_particle_classifier.py @@ -1,6 +1,7 @@ """ Tool for training the ParticleClassifier """ + import numpy as np from astropy.table import vstack @@ -232,7 +233,6 @@ def finish(self): Write-out trained models and cross-validation results. """ self.log.info("Writing output") - self.classifier.n_jobs = None self.classifier.write(self.output_path, overwrite=self.overwrite) self.signal_loader.close() self.background_loader.close()