From 655c1be8ca805bb9d3f07b28c2d6c131cf01c173 Mon Sep 17 00:00:00 2001 From: Mathieu de Bony Date: Thu, 19 Mar 2026 17:28:51 +0100 Subject: [PATCH 1/7] Raw import from DL3 branch --- .codespell-ignores | 1 + src/ctapipe/io/dl3.py | 1005 ++++++++++++++++++++++++++++++ src/ctapipe/io/tests/test_dl3.py | 611 ++++++++++++++++++ 3 files changed, 1617 insertions(+) create mode 100644 src/ctapipe/io/dl3.py create mode 100644 src/ctapipe/io/tests/test_dl3.py diff --git a/.codespell-ignores b/.codespell-ignores index 5a3c6e0794e..2b905ce498a 100644 --- a/.codespell-ignores +++ b/.codespell-ignores @@ -4,3 +4,4 @@ nd studi referenc FRAM +livetime diff --git a/src/ctapipe/io/dl3.py b/src/ctapipe/io/dl3.py new file mode 100644 index 00000000000..eedba19617f --- /dev/null +++ b/src/ctapipe/io/dl3.py @@ -0,0 +1,1005 @@ +from abc import abstractmethod +from collections.abc import Mapping +from datetime import UTC, datetime +from functools import lru_cache +from typing import Any, Dict, List, Tuple + +import astropy.units as u +import numpy as np +from astropy.coordinates import ( + ICRS, + AltAz, + BaseCoordinateFrame, + EarthLocation, + SkyCoord, +) +from astropy.io import fits +from astropy.io.fits import Header +from astropy.io.fits.hdu.base import ExtensionHDU +from astropy.table import QTable, Table +from astropy.time import Time, TimeDelta + +from ..compat import COPY_IF_NEEDED +from ..core import Component +from ..core.traits import AstroTime, Bool +from ..version import version as ctapipe_version + +__all__ = ["DL3EventsWriter", "DL3GADFEventsWriter"] + + +class DL3EventsWriter(Component): + """ + Base class for writing a DL3 file + """ + + overwrite = Bool( + default_value=False, + help="If true, allow to overwrite already existing output file", + ).tag(config=True) + + optional_dl3_columns = Bool( + default_value=False, help="If true add optional columns to produce file" + ).tag(config=True) + + raise_error_for_optional = Bool( + default_value=True, + help="If true will raise error in the case optional column are missing", + ).tag(config=True) + + reference_time = AstroTime( + default_value=Time("2018-01-01T00:00:00", scale="tai"), + help="The reference time that will be used in the FITS file", + ).tag(config=True) + + def __init__(self, **kwargs): + super().__init__(**kwargs) + self._obs_id = None + self._events = None + self._pointing = None + self._pointing_mode = None + self._gti = None + self._aeff = None + self._psf = None + self._edisp = None + self._bkg = None + self._livetime_fraction = None + self._location = None + self._telescope_information = None + self._target_information = None + self._software_information = None + + @abstractmethod + def write_file(self, path): + pass + + @property + def obs_id(self) -> int: + return self._obs_id + + @obs_id.setter + def obs_id(self, obs_id: int): + """ + Parameters + ---------- + obs_id : int + Observation ID + """ + if self._obs_id is not None: + self.log.warning( + "Obs id for DL3 file was already set, replacing current obs id" + ) + if obs_id is not None: + if not isinstance(obs_id, (int, np.integer)) or isinstance(obs_id, bool): + raise TypeError("obs_id must be an integer.") + if obs_id < 0: + raise ValueError("obs_id must be >= 0.") + self._obs_id = obs_id + + @property + def events(self) -> QTable: + return self._events + + @events.setter + def events(self, events: QTable): + """ + Parameters + ---------- + events : QTable + A table with a line for each event + """ + if self._events is not None: + self.log.warning( + "Events table for DL3 file was already set, replacing current event table" + ) + if events is not None and not isinstance(events, (QTable, Table)): + raise TypeError("events must be an astropy Table or QTable.") + self._events = events + + @property + def pointing(self) -> List[Tuple[Time, SkyCoord]]: + return self._pointing + + @pointing.setter + def pointing(self, pointing: List[Tuple[Time, SkyCoord]]): + """ + Parameters + ---------- + pointing : List[Tuple[Time, SkyCoord]] + A list with for each entry containing the time at which the coordinate where evaluated and the associated coordinates + """ + if self._pointing is not None: + self.log.warning( + "Pointing for DL3 file was already set, replacing current pointing" + ) + if pointing is not None: + if not isinstance(pointing, (list, tuple)): + raise TypeError("pointing must be a list of (time, coordinate) pairs.") + + for i, value in enumerate(pointing): + if not isinstance(value, (list, tuple)) or len(value) != 2: + raise ValueError( + f"pointing[{i}] must be a (time, coordinate) pair." + ) + + coordinate = value[1] + if not isinstance(coordinate, (SkyCoord, BaseCoordinateFrame)): + raise TypeError( + f"pointing[{i}].coordinate must be a SkyCoord or coordinate frame." + ) + self._pointing = pointing + + @property + def pointing_mode(self) -> str: + return self._pointing_mode + + @pointing_mode.setter + def pointing_mode(self, pointing_mode: str): + """ + Parameters + ---------- + pointing_mode : str + The name of the pointing mode used for the observation + """ + if self._pointing_mode is not None: + self.log.warning( + "Pointing for DL3 file was already set, replacing current pointing" + ) + if pointing_mode is not None: + if not isinstance(pointing_mode, str): + raise TypeError("pointing_mode must be a string.") + + pointing_mode = pointing_mode.strip().upper() + if pointing_mode not in {"TRACK", "DRIFT"}: + raise ValueError("pointing_mode must be either 'TRACK' or 'DRIFT'.") + self._pointing_mode = pointing_mode + + @property + def gti(self) -> List[Tuple[Time, Time]]: + return self._gti + + @gti.setter + def gti(self, gti: List[Tuple[Time, Time]]): + """ + Parameters + ---------- + gti : List[Tuple[Time, Time]] + A list with for each entry containing the time the start and stop time of the good time intervals + """ + if self._gti is not None: + self.log.warning("GTI for DL3 file was already set, replacing current gti") + if gti is not None: + if not isinstance(gti, (list, tuple)): + raise TypeError("gti must be a list of (start, stop) pairs.") + + for i, value in enumerate(gti): + if not isinstance(value, (list, tuple)) or len(value) != 2: + raise ValueError(f"gti[{i}] must be a (start, stop) pair.") + self._gti = gti + + @property + def aeff(self) -> ExtensionHDU: + return self._aeff + + @aeff.setter + def aeff(self, aeff: ExtensionHDU): + """ + Parameters + ---------- + aeff : ExtensionHDU + The effective area HDU read from the fits file containing IRFs + """ + if self._aeff is not None: + self.log.warning( + "Effective area for DL3 file was already set, replacing current effective area" + ) + if aeff is not None and not isinstance(aeff, ExtensionHDU): + raise TypeError("aeff must be a FITS ExtensionHDU.") + self._aeff = aeff + + @property + def psf(self) -> ExtensionHDU: + return self._psf + + @psf.setter + def psf(self, psf: ExtensionHDU): + """ + Parameters + ---------- + psf : ExtensionHDU + The PSF HDU read from the fits file containing IRFs + """ + if self._psf is not None: + self.log.warning("PSF for DL3 file was already set, replacing current PSF") + if psf is not None and not isinstance(psf, ExtensionHDU): + raise TypeError("psf must be a FITS ExtensionHDU.") + self._psf = psf + + @property + def edisp(self) -> ExtensionHDU: + return self._edisp + + @edisp.setter + def edisp(self, edisp: ExtensionHDU): + """ + Parameters + ---------- + edisp : ExtensionHDU + The EDISP HDU read from the fits file containing IRFs + """ + if self._edisp is not None: + self.log.warning( + "EDISP for DL3 file was already set, replacing current EDISP" + ) + if edisp is not None and not isinstance(edisp, ExtensionHDU): + raise TypeError("edisp must be a FITS ExtensionHDU.") + self._edisp = edisp + + @property + def bkg(self) -> ExtensionHDU: + return self._bkg + + @bkg.setter + def bkg(self, bkg: ExtensionHDU): + """ + Parameters + ---------- + bkg : ExtensionHDU + The background HDU read from the fits file containing IRFs + """ + if self._bkg is not None: + self.log.warning( + "Background for DL3 file was already set, replacing current background" + ) + if bkg is not None and not isinstance(bkg, ExtensionHDU): + raise TypeError("bkg must be a FITS ExtensionHDU.") + self._bkg = bkg + + @property + def location(self) -> EarthLocation: + return self._location + + @location.setter + def location(self, location: EarthLocation): + """ + Parameters + ---------- + location : EarthLocation + The location of the telescope + """ + if self._location is not None: + self.log.warning( + "Telescope location for DL3 file was already set, replacing current location" + ) + if location is not None and not isinstance(location, EarthLocation): + raise TypeError("location must be an astropy EarthLocation.") + self._location = location + + @property + def livetime_fraction(self) -> float: + return self._livetime_fraction + + @livetime_fraction.setter + def livetime_fraction(self, livetime_fraction: float): + """ + Parameters + ---------- + livetime_fraction : float + The livetime fraction for the observations (DEADC correction factor) + """ + if self.livetime_fraction is not None: + self.log.warning( + "Livetime fraction for DL3 file was already set, replacing current livetime fraction" + ) + + if livetime_fraction is None: + self._livetime_fraction = None + return + + if isinstance(livetime_fraction, (bool, np.bool_)) or ( + not np.isscalar(livetime_fraction) or not np.isreal(livetime_fraction) + ): + raise TypeError("livetime_fraction must be a real scalar.") + if not np.isfinite(livetime_fraction) or (not 0.0 <= livetime_fraction <= 1.0): + raise ValueError("livetime_fraction must be in the range [0, 1].") + + self._livetime_fraction = livetime_fraction + + @property + def telescope_information(self) -> Dict[str, Any]: + return self._telescope_information + + @telescope_information.setter + def telescope_information(self, telescope_information: Dict[str, Any]): + """ + Parameters + ---------- + telescope_information : dict[str, any] + A dictionary containing general information about telescope with as key : organisation, array, subarray, telescope_list + """ + if self._telescope_information is not None: + self.log.warning( + "Telescope information for DL3 file was already set, replacing current information" + ) + if telescope_information is not None: + if not isinstance(telescope_information, Mapping): + raise TypeError("telescope_information must be a mapping.") + required = {"organisation", "array", "subarray", "telescope_list"} + missing = required - set(telescope_information) + if missing: + raise ValueError( + "telescope_information is missing keys: " + + ", ".join(sorted(missing)) + ) + self._telescope_information = telescope_information + + @property + def target_information(self) -> Dict[str, Any]: + return self._target_information + + @target_information.setter + def target_information(self, target_information: Dict[str, Any]): + """ + Parameters + ---------- + target_information : dict[str, any] + A dictionary containing general information about the targeted source with as key : observer, object_name, object_coordinate + """ + if self._target_information is not None: + self.log.warning( + "Target information for DL3 file was already set, replacing current target information" + ) + if target_information is not None: + if not isinstance(target_information, Mapping): + raise TypeError("target_information must be a mapping.") + required = {"observer", "object_name", "object_coordinate"} + missing = required - set(target_information) + if missing: + raise ValueError( + "target_information is missing keys: " + ", ".join(sorted(missing)) + ) + + coordinate = target_information["object_coordinate"] + if not isinstance(coordinate, (SkyCoord, BaseCoordinateFrame)): + raise TypeError( + "target_information['object_coordinate'] must be a SkyCoord or coordinate frame." + ) + self._target_information = target_information + + @property + def software_information(self) -> Dict[str, Any]: + return self._software_information + + @software_information.setter + def software_information(self, software_information: Dict[str, Any]): + """ + Parameters + ---------- + software_information : dict[str, any] + A dictionary containing general information about the software used to produce the file with as key : analysis_version, calibration_version, dst_version + """ + if self._software_information is not None: + self.log.warning( + "Software information for DL3 file was already set, replacing current software information" + ) + if software_information is not None: + if not isinstance(software_information, Mapping): + raise TypeError("software_information must be a mapping.") + required = {"analysis_version", "calibration_version", "dst_version"} + missing = required - set(software_information) + if missing: + raise ValueError( + "software_information is missing keys: " + + ", ".join(sorted(missing)) + ) + self._software_information = software_information + + +class DL3GADFEventsWriter(DL3EventsWriter): + """ + Class to write DL3 in GADF format, subclass of DL3_Format + """ + + def __init__(self, **kwargs): + super().__init__(**kwargs) + self.file_creation_time = datetime.now(tz=UTC) + self._reference_time = self.reference_time.tai + + def _to_tai_time(self, value: Any, value_name: str) -> Time: + """ + Normalize input to an absolute TAI ``Time`` object. + + Parameters + ---------- + value : Any + Input time-like value. Supported types are ``Time``, ``TimeDelta``, + time ``Quantity`` and scalar numeric values interpreted as seconds + relative to ``reference_time``. + value_name : str + Name of the value used in error messages. + """ + if isinstance(value, Time): + return value.tai + + if isinstance(value, TimeDelta): + return self._reference_time + value + + if isinstance(value, u.Quantity): + if not value.unit.is_equivalent(u.s): + raise ValueError( + f"{value_name} must be a time quantity equivalent to seconds." + ) + return self._reference_time + TimeDelta(value.to(u.s)) + + if np.isscalar(value) and np.isreal(value): + return self._reference_time + TimeDelta(float(value) * u.s) + + raise TypeError( + f"{value_name} must be Time, TimeDelta, a time Quantity, or a scalar number of seconds." + ) + + def _to_relative_time_seconds(self, value: Any, value_name: str) -> Any: + """ + Normalize input to seconds relative to ``reference_time``. + + Parameters + ---------- + value : Any + Input time-like value. Supported types are ``Time``, ``TimeDelta``, + time ``Quantity`` and numeric values assumed to already be in seconds. + value_name : str + Name of the value used in error messages. + """ + if isinstance(value, Time): + return (value.tai - self._reference_time).to_value(u.s) + + if isinstance(value, TimeDelta): + return value.to_value(u.s) + + if isinstance(value, u.Quantity): + if not value.unit.is_equivalent(u.s): + raise ValueError( + f"{value_name} must be a time quantity equivalent to seconds." + ) + return value.to_value(u.s) + + values = np.asarray(value) + if np.issubdtype(values.dtype, np.number): + return values.astype(np.float64, copy=False) + + raise TypeError( + f"{value_name} must be Time, TimeDelta, a time Quantity, or numeric seconds." + ) + + def _to_relative_time_quantity(self, value: Any, value_name: str) -> u.Quantity: + """Normalize input to a quantity in seconds relative to ``reference_time``.""" + return u.Quantity( + self._to_relative_time_seconds(value, value_name), + u.s, + copy=COPY_IF_NEEDED, + ) + + @staticmethod + def _circular_interp(x, xp, fp_deg): + """ + Interpolate angular values in degrees, handling the 0/360 wrap-around. + + Uses ``np.unwrap`` to remove discontinuities before interpolation, + then wraps the result back to [0, 360). + + Parameters + ---------- + x : array-like + The x-coordinates at which to evaluate the interpolated values. + xp : array-like + The x-coordinates of the data points (must be increasing). + fp_deg : array-like + The y-coordinates of the data points, in degrees. + + Returns + ------- + np.ndarray + Interpolated angular values in degrees, in [0, 360). + """ + fp_rad = np.deg2rad(np.asarray(fp_deg, dtype=float)) + fp_unwrapped = np.unwrap(fp_rad) + result_rad = np.interp(x, xp, fp_unwrapped) + return np.rad2deg(result_rad) % 360 + + @staticmethod + def _circular_mean(angles_deg): + """ + Compute the mean of angular values in degrees, handling the 0/360 + wrap-around. + + Uses the ``atan2(mean(sin), mean(cos))`` formula for circular + statistics. + + Parameters + ---------- + angles_deg : array-like + Angular values in degrees. + + Returns + ------- + float + Mean angle in degrees, in [0, 360). + """ + angles_rad = np.deg2rad(np.asarray(angles_deg, dtype=float)) + return float( + np.rad2deg( + np.arctan2( + np.mean(np.sin(angles_rad)), + np.mean(np.cos(angles_rad)), + ) + ) + % 360 + ) + + def write_file(self, path): + """ + This function will write the new DL3 file + All the content associated with the file should have been specified previously, otherwise error will be raised + + Parameters + ---------- + path : str + The full path and filename of the new file to write + """ + self.file_creation_time = datetime.now(tz=UTC) + + hdu_dl3 = fits.HDUList( + [fits.PrimaryHDU(header=Header(self.get_hdu_header_base_format()))] + ) + hdu_dl3.append( + fits.BinTableHDU( + data=self.transform_events_columns_for_gadf_format(self.events), + name="EVENTS", + header=Header(self.get_hdu_header_events()), + ) + ) + hdu_dl3.append( + fits.BinTableHDU( + data=self.create_gti_table(), + name="GTI", + header=Header(self.get_hdu_header_gti()), + ) + ) + hdu_dl3.append( + fits.BinTableHDU( + data=self.create_pointing_table(), + name="POINTING", + header=Header(self.get_hdu_header_pointing()), + ) + ) + + if self.aeff is None: + raise ValueError("Missing effective area IRF") + hdu_dl3.append(self.aeff) + hdu_dl3[-1].header["OBS_ID"] = self.obs_id + if self.psf is None: + raise ValueError("Missing PSF IRF") + hdu_dl3.append(self.psf) + hdu_dl3[-1].header["OBS_ID"] = self.obs_id + if self.edisp is None: + raise ValueError("Missing EDISP IRF") + hdu_dl3.append(self.edisp) + hdu_dl3[-1].header["OBS_ID"] = self.obs_id + if self.bkg is not None: + hdu_dl3.append(self.bkg) + hdu_dl3[-1].header["OBS_ID"] = self.obs_id + + hdu_dl3.writeto(path, checksum=True, overwrite=self.overwrite) + + def get_hdu_header_base_format(self) -> Dict[str, Any]: + """ + Return the base information that should be included in all HDU of the final fits file + """ + return { + "HDUCLASS": "GADF", + "HDUVERS": "v0.3", + "HDUDOC": "https://gamma-astro-data-formats.readthedocs.io/en/v0.3/index.html", + "CREATOR": "ctapipe " + ctapipe_version, + "CREATED": self.file_creation_time.isoformat(), + } + + def get_hdu_header_time_reference(self) -> Dict[str, Any]: + """ + Return the time reference keywords needed to interpret TIME columns. + + These keywords (MJDREFI, MJDREFF, TIMEUNIT, TIMESYS, TIMEREF) should + be present in every HDU that contains a TIME column or time-related + header values. + """ + return { + "MJDREFI": int(self._reference_time.mjd), + "MJDREFF": self._reference_time.mjd % 1, + "TIMEUNIT": "s", + "TIMEREF": "TOPOCENTER", + "TIMESYS": "TAI", + } + + @lru_cache(maxsize=1) + def get_hdu_header_base_time(self) -> Dict[str, Any]: + """ + Return the information about time parameters used in several HDU + """ + if self.gti is None: + raise ValueError("No available time information for the DL3 file") + if self.livetime_fraction is None: + raise ValueError("No available livetime fraction for the DL3 file") + + start_time = None + stop_time = None + ontime = TimeDelta(0.0 * u.s) + for i, gti_interval in enumerate(self.gti): + interval_start = self._to_tai_time(gti_interval[0], f"gti[{i}].start") + interval_stop = self._to_tai_time(gti_interval[1], f"gti[{i}].stop") + if interval_stop < interval_start: + raise ValueError( + f"Invalid GTI interval at index {i}: stop time is before start time." + ) + + ontime += interval_stop - interval_start + start_time = ( + interval_start + if start_time is None + else min(start_time, interval_start) + ) + stop_time = ( + interval_stop if stop_time is None else max(stop_time, interval_stop) + ) + + header = self.get_hdu_header_time_reference() + header.update( + { + "TSTART": self._to_relative_time_seconds( + start_time, "observation start" + ), + "TSTOP": self._to_relative_time_seconds(stop_time, "observation stop"), + "ONTIME": ontime.to_value(u.s), + "LIVETIME": ontime.to_value(u.s) * self.livetime_fraction, + "DEADC": self.livetime_fraction, + "TELAPSE": (stop_time - start_time).to_value(u.s), + "DATE-OBS": start_time.fits, + "DATE-BEG": start_time.fits, + "DATE-AVG": (start_time + (stop_time - start_time) / 2.0).fits, + "DATE-END": stop_time.fits, + } + ) + return header + + def get_hdu_header_base_observation_information( + self, obs_id_only: bool = False + ) -> Dict[str, Any]: + """ + Return generic information on the observation setting (id, target, ...) + + Parameters + ---------- + obs_id_only : bool + If true, will return a dict with as only information the obs_id + """ + if self.obs_id is None: + raise ValueError("Observation ID is missing.") + header = {"OBS_ID": self.obs_id} + if self.target_information is not None and not obs_id_only: + header["OBSERVER"] = self.target_information["observer"] + header["OBJECT"] = self.target_information["object_name"] + object_coordinate = self.target_information[ + "object_coordinate" + ].transform_to(ICRS()) + if not np.isnan(object_coordinate.ra.to_value(u.deg)): + header["RA_OBJ"] = object_coordinate.ra.to_value(u.deg) + if not np.isnan(object_coordinate.dec.to_value(u.deg)): + header["DEC_OBJ"] = object_coordinate.dec.to_value(u.deg) + return header + + def get_hdu_header_base_subarray_information(self) -> Dict[str, Any]: + """ + Return generic information on the array used for observations + """ + if self.telescope_information is None: + raise ValueError("Telescope information are missing.") + header = { + "ORIGIN": self.telescope_information["organisation"], + "TELESCOP": self.telescope_information["array"], + "INSTRUME": self.telescope_information["subarray"], + "TELLIST": str(self.telescope_information["telescope_list"]), + "N_TELS": len(self.telescope_information["telescope_list"]), + } + return header + + def get_hdu_header_base_software_information(self) -> Dict[str, Any]: + """ + Return information about the software versions used to process the observation + """ + header = {} + if self.software_information is not None: + header["DST_VER"] = self.software_information["dst_version"] + header["ANA_VER"] = self.software_information["analysis_version"] + header["CAL_VER"] = self.software_information["calibration_version"] + return header + + @lru_cache(maxsize=1) + def get_hdu_header_base_pointing(self) -> Dict[str, Any]: + """ + Return information on the pointing during the observation + """ + if self.pointing is None: + raise ValueError("Pointing information are missing") + if self.pointing_mode is None: + raise ValueError("Pointing mode is missing") + if self.location is None: + raise ValueError("Telescope location information are missing") + + gti_table = self.create_gti_table() + delta_time_evaluation = [] + for i in range(len(gti_table)): + delta_time_evaluation += list( + np.linspace(gti_table["START"][i], gti_table["STOP"][i], 100) + ) + delta_time_evaluation = u.Quantity(delta_time_evaluation) + time_evaluation = self._reference_time + TimeDelta(delta_time_evaluation) + + pointing_table = self.create_pointing_table() + if self.pointing_mode == "TRACK": + obs_mode = "POINTING" + icrs_coordinate = SkyCoord( + ra=self._circular_interp( + delta_time_evaluation, + xp=pointing_table["TIME"], + fp_deg=pointing_table["RA_PNT"], + ), + dec=np.interp( + delta_time_evaluation, + xp=pointing_table["TIME"], + fp=pointing_table["DEC_PNT"], + ), + unit=u.deg, + ) + altaz_coordinate = icrs_coordinate.transform_to( + AltAz(location=self.location, obstime=time_evaluation) + ) + elif self.pointing_mode == "DRIFT": + obs_mode = "DRIFT" + altaz_coordinate = AltAz( + alt=u.Quantity( + np.interp( + delta_time_evaluation, + xp=pointing_table["TIME"], + fp=pointing_table["ALT_PNT"], + ), + u.deg, + copy=COPY_IF_NEEDED, + ), + az=self._circular_interp( + delta_time_evaluation, + xp=pointing_table["TIME"], + fp_deg=pointing_table["AZ_PNT"], + ) + * u.deg, + location=self.location, + obstime=time_evaluation, + ) + icrs_coordinate = altaz_coordinate.transform_to(ICRS()) + else: + raise ValueError("Unknown pointing mode") + + header = { + "RADESYS": "ICRS", + "RADECSYS": "ICRS", + "EQUINOX": 2000.0, + "OBS_MODE": obs_mode, + "RA_PNT": self._circular_mean(icrs_coordinate.ra.to_value(u.deg)), + "DEC_PNT": np.mean(icrs_coordinate.dec.to_value(u.deg)), + "ALT_PNT": np.mean(altaz_coordinate.alt.to_value(u.deg)), + "AZ_PNT": self._circular_mean(altaz_coordinate.az.to_value(u.deg)), + "GEOLON": self.location.lon.to_value(u.deg), + "GEOLAT": self.location.lat.to_value(u.deg), + "ALTITUDE": self.location.height.to_value(u.m), + "OBSGEO-X": self.location.x.to_value(u.m), + "OBSGEO-Y": self.location.y.to_value(u.m), + "OBSGEO-Z": self.location.z.to_value(u.m), + } + return header + + def get_hdu_header_events(self) -> Dict[str, Any]: + """ + The output dictionary contain all the necessary information that should be added to the header of the events HDU + """ + header = self.get_hdu_header_base_format() + header.update({"HDUCLAS1": "EVENTS"}) + header.update(self.get_hdu_header_base_time()) + header.update(self.get_hdu_header_base_pointing()) + header.update(self.get_hdu_header_base_observation_information()) + header.update(self.get_hdu_header_base_subarray_information()) + header.update(self.get_hdu_header_base_software_information()) + return header + + def get_hdu_header_gti(self) -> Dict[str, Any]: + """ + The output dictionary contain all the necessary information that should be added to the header of the GTI HDU + """ + header = self.get_hdu_header_base_format() + header.update({"HDUCLAS1": "GTI"}) + header.update(self.get_hdu_header_base_time()) + header.update( + self.get_hdu_header_base_observation_information(obs_id_only=True) + ) + return header + + def get_hdu_header_pointing(self) -> Dict[str, Any]: + """ + The output dictionary contain all the necessary information that should be added to the header of the pointing HDU + """ + header = self.get_hdu_header_base_format() + header.update({"HDUCLAS1": "POINTING"}) + header.update(self.get_hdu_header_time_reference()) + header.update(self.get_hdu_header_base_pointing()) + header.update( + self.get_hdu_header_base_observation_information(obs_id_only=True) + ) + return header + + def transform_events_columns_for_gadf_format(self, events: QTable) -> QTable: + """ + Return an event table containing only the columns that should be added to the EVENTS HDU + It also rename all the columns to match the name expected in the GADF format + + Parameters + ---------- + events : QTable + The base events table to process + """ + rename_from = ["event_id", "time", "reco_ra", "reco_dec", "reco_energy"] + rename_to = ["EVENT_ID", "TIME", "RA", "DEC", "ENERGY"] + + if self.optional_dl3_columns: + rename_from_optional = [ + "multiplicity", + "reco_glon", + "reco_glat", + "reco_alt", + "reco_az", + "reco_fov_lon", + "reco_fov_lat", + "reco_source_fov_offset", + "reco_source_fov_position_angle", + "gh_score", + "reco_dir_uncert", + "reco_energy_uncert", + "reco_core_x", + "reco_core_y", + "reco_core_uncert", + "reco_h_max", + "reco_h_max_uncert", + "reco_x_max", + "reco_x_max_uncert", + ] + rename_to_optional = [ + "MULTIP", + "GLON", + "GLAT", + "ALT", + "AZ", + "DETX", + "DETY", + "THETA", + "PHI", + "GAMMANESS", + "DIR_ERR", + "ENERGY_ERR", + "COREX", + "COREY", + "CORE_ERR", + "HMAX", + "HMAX_ERR", + "XMAX", + "XMAX_ERR", + ] + + for i, c in enumerate(rename_from_optional): + if c not in events.colnames: + self.log.warning( + f"Optional column {c} is missing from the events table." + ) + if self.raise_error_for_optional: + raise ValueError( + f"Optional column {c} is missing from the events table." + ) + else: + rename_from.append(rename_from_optional[i]) + rename_to.append(rename_to_optional[i]) + + for c in rename_from: + if c not in events.colnames: + raise ValueError( + f"Required column {c} is missing from the events table." + ) + + renamed_events = QTable(events, copy=COPY_IF_NEEDED) + renamed_events["time"] = self._to_relative_time_quantity( + renamed_events["time"], "events.time" + ) + renamed_events.rename_columns(rename_from, rename_to) + renamed_events = renamed_events[rename_to] + return renamed_events + + def create_gti_table(self) -> QTable: + """ + Build a table that contains GTI information with the GADF names and format, to be concerted directly as a TableHDU + """ + table_structure = {"START": [], "STOP": []} + for i, gti_interval in enumerate(self.gti): + interval_start = self._to_tai_time(gti_interval[0], f"gti[{i}].start") + interval_stop = self._to_tai_time(gti_interval[1], f"gti[{i}].stop") + table_structure["START"].append( + self._to_relative_time_quantity(interval_start, f"gti[{i}].start") + ) + table_structure["STOP"].append( + self._to_relative_time_quantity(interval_stop, f"gti[{i}].stop") + ) + + table = QTable(table_structure) + table.sort("START") + for i in range(len(table) - 1): + if table["STOP"][i] > table["START"][i + 1]: + self.log.warning("Overlapping GTI intervals") + break + + return table + + def create_pointing_table(self) -> QTable: + """ + Build a table that contains pointing information with the GADF names and format, to be concerted directly as a TableHDU + """ + if self.pointing is None: + raise ValueError("Pointing information are missing") + if self.location is None: + raise ValueError("Telescope location information are missing") + + table_structure = { + "TIME": [], + "RA_PNT": [], + "DEC_PNT": [], + "ALT_PNT": [], + "AZ_PNT": [], + } + + for i, pointing in enumerate(self.pointing): + time = self._to_tai_time(pointing[0], f"pointing[{i}].time") + pointing_icrs = pointing[1].transform_to(ICRS()) + pointing_altaz = pointing[1].transform_to( + AltAz(location=self.location, obstime=time) + ) + table_structure["TIME"].append( + self._to_relative_time_quantity(time, f"pointing[{i}].time") + ) + table_structure["RA_PNT"].append(pointing_icrs.ra.to(u.deg)) + table_structure["DEC_PNT"].append(pointing_icrs.dec.to(u.deg)) + table_structure["ALT_PNT"].append(pointing_altaz.alt.to(u.deg)) + table_structure["AZ_PNT"].append(pointing_altaz.az.to(u.deg)) + + table = QTable(table_structure) + table.sort("TIME") + return table diff --git a/src/ctapipe/io/tests/test_dl3.py b/src/ctapipe/io/tests/test_dl3.py new file mode 100644 index 00000000000..b727f80aa90 --- /dev/null +++ b/src/ctapipe/io/tests/test_dl3.py @@ -0,0 +1,611 @@ +from datetime import UTC, datetime, timedelta + +import astropy.units as u +import numpy as np +import pytest +from astropy.io import fits +from astropy.table import Column, Table +from astropy.time import Time +from traitlets.config import Config + +from ...irf.cuts import DL2EventSelection +from ...tools.create_dl3 import DL3Tool +from ..dl3 import DL3GADFEventsWriter + + +@pytest.fixture +def hdu_irfs(dummy_irf_file): + with fits.open(dummy_irf_file, checksum=True) as hdus: + yield hdus + + +@pytest.fixture(scope="session") +def dl2_events_for_dl3(single_obs_gamma_diffuse_full_reco_file, dummy_cuts_file): + tool = DL3Tool( + config=Config( + { + "EventPreprocessor": { + "energy_reconstructor": "ExtraTreesRegressor", + "geometry_reconstructor": "HillasReconstructor", + "gammaness_reconstructor": "ExtraTreesClassifier", + } + } + ) + ) + tool.dl2_file = single_obs_gamma_diffuse_full_reco_file + tool.event_selection = DL2EventSelection(parent=tool, cuts_file=dummy_cuts_file) + tool._configure_event_preprocessor_for_dl3() + + events = tool._load_preselected_events(chunk_size=1000) + events = tool.event_selection.calculate_gamma_selection( + events, + apply_spatial_selection=False, + ) + events = events[events["selected_gamma"]] + + meta = tool._get_observation_information() + events["time"] = ( + Time("2020-01-01T00:00:00", scale="tai") + np.arange(len(events)) * u.s + ) + tool.output_table_schema = [ + Column(name="reco_az", unit=u.deg), + Column(name="reco_alt", unit=u.deg), + Column(name="pointing_az", unit=u.deg), + Column(name="pointing_alt", unit=u.deg), + Column(name="time"), + Column(name="reco_ra", unit=u.deg), + Column(name="reco_dec", unit=u.deg), + ] + events = tool._make_derived_columns(events, location_subarray=meta["location"]) + return events + + +@pytest.fixture(scope="session") +def dl2_meta_for_dl3(single_obs_gamma_diffuse_full_reco_file, dummy_cuts_file): + tool = DL3Tool( + config=Config( + { + "EventPreprocessor": { + "energy_reconstructor": "ExtraTreesRegressor", + "geometry_reconstructor": "HillasReconstructor", + "gammaness_reconstructor": "ExtraTreesClassifier", + } + } + ) + ) + tool.dl2_file = single_obs_gamma_diffuse_full_reco_file + tool.event_selection = DL2EventSelection(parent=tool, cuts_file=dummy_cuts_file) + return tool._get_observation_information() + + +@pytest.fixture +def dl3_writer(dl2_events_for_dl3, dl2_meta_for_dl3, hdu_irfs): + dl3_format_optional = DL3GADFEventsWriter() + + # Load events + dl3_format_optional.events = dl2_events_for_dl3 + + # Load metadata + dl3_format_optional.obs_id = dl2_meta_for_dl3["obs_id"] + dl3_format_optional.pointing = dl2_meta_for_dl3["pointing"]["pointing_list"] + dl3_format_optional.pointing_mode = dl2_meta_for_dl3["pointing"]["pointing_mode"] + dl3_format_optional.gti = dl2_meta_for_dl3["gti"] + dl3_format_optional.livetime_fraction = dl2_meta_for_dl3["livetime_fraction"] + dl3_format_optional.location = dl2_meta_for_dl3["location"] + dl3_format_optional.telescope_information = dl2_meta_for_dl3[ + "telescope_information" + ] + dl3_format_optional.target_information = dl2_meta_for_dl3["target"] + dl3_format_optional.software_information = dl2_meta_for_dl3["software_version"] + + # Load IRFs + for i in range(1, len(hdu_irfs)): + if "HDUCLAS2" in hdu_irfs[i].header.keys(): + if hdu_irfs[i].header["HDUCLAS2"] == "EFF_AREA": + if dl3_format_optional.aeff is None: + dl3_format_optional.aeff = hdu_irfs[i] + elif "EXTNAME" in hdu_irfs[i].header and not ( + "PROTONS" in hdu_irfs[i].header["EXTNAME"] + or "ELECTRONS" in hdu_irfs[i].header["EXTNAME"] + ): + dl3_format_optional.aeff = hdu_irfs[i] + elif hdu_irfs[i].header["HDUCLAS2"] == "EDISP": + dl3_format_optional.edisp = hdu_irfs[i] + elif hdu_irfs[i].header["HDUCLAS2"] == "PSF": + dl3_format_optional.psf = hdu_irfs[i] + elif hdu_irfs[i].header["HDUCLAS2"] == "BKG": + dl3_format_optional.bkg = hdu_irfs[i] + return dl3_format_optional + + +class TestDL3GADFEventsWriter: + def test_dl3_file(self, tmp_path, dl3_writer): + output_path = tmp_path / "dl3_gadf.fits" + + dl3_writer.write_file(output_path) + + with fits.open(output_path, checksum=True) as hdul: + assert isinstance(hdul[0], fits.PrimaryHDU) + + names = [hdu.name for hdu in hdul] + assert "EVENTS" in names + assert "GTI" in names + assert "POINTING" in names + + irf_kinds = { + hdu.header.get("HDUCLAS2") + for hdu in hdul[1:] + if "HDUCLAS2" in hdu.header + } + assert {"EFF_AREA", "EDISP", "PSF", "BKG"}.issubset(irf_kinds) + + for hdu in hdul: + if "OBS_ID" in hdu.header: + assert hdu.header["OBS_ID"] == dl3_writer.obs_id + + def test_dl3_file_missing_aeff(self, tmp_path, dl3_writer): + output_path = tmp_path / "dl3_gadf_aeff.fits" + + dl3_writer._aeff = None + with pytest.raises(ValueError): + dl3_writer.write_file(output_path) + + def test_dl3_file_missing_edisp(self, tmp_path, dl3_writer): + output_path = tmp_path / "dl3_gadf_edisp.fits" + + dl3_writer._edisp = None + with pytest.raises(ValueError): + dl3_writer.write_file(output_path) + + def test_dl3_file_missing_psf(self, tmp_path, dl3_writer): + output_path = tmp_path / "dl3_gadf_psf.fits" + + dl3_writer._psf = None + with pytest.raises(ValueError): + dl3_writer.write_file(output_path) + + def test_dl3_file_overwrite(self, tmp_path, dl3_writer): + output_path = tmp_path / "dl3_gadf_overwrite.fits" + + dl3_writer.write_file(output_path) + with pytest.raises(OSError): + dl3_writer.write_file(output_path) + + def test_hdu_header_base(self, dl3_writer): + header = dl3_writer.get_hdu_header_base_format() + + assert header["HDUCLASS"] == "GADF" + assert header["HDUVERS"] == "v0.3" + assert header["CREATOR"].startswith("ctapipe") + + file_time = datetime.fromisoformat(header["CREATED"]) + assert (datetime.now(UTC) - file_time) < timedelta(hours=1) + + def test_hdu_header_time(self, dl3_writer): + header = dl3_writer.get_hdu_header_base_time() + + for key in [ + "MJDREFI", + "MJDREFF", + "TIMEUNIT", + "TIMEREF", + "TIMESYS", + "TSTART", + "TSTOP", + "ONTIME", + "LIVETIME", + "DEADC", + "TELAPSE", + "DATE-OBS", + "DATE-BEG", + "DATE-AVG", + "DATE-END", + ]: + assert key in header + + assert isinstance(header["MJDREFI"], int) + assert header["MJDREFI"] == 58119 + assert isinstance(header["MJDREFF"], float) + assert 0.0 <= header["MJDREFF"] < 1.0 + assert header["TIMEREF"] == "TOPOCENTER" + assert header["TIMESYS"] == "TAI" + assert header["TIMEUNIT"] == "s" + + assert header["TSTOP"] > header["TSTART"] + + assert header["DEADC"] <= 1 + assert header["LIVETIME"] == pytest.approx(header["ONTIME"] * header["DEADC"]) + assert header["LIVETIME"] <= header["TELAPSE"] + assert header["TELAPSE"] == pytest.approx(header["TSTOP"] - header["TSTART"]) + + ref_mjd = header["MJDREFI"] + header["MJDREFF"] + tref = Time(ref_mjd, format="mjd", scale="tai") + tstart = Time(header["DATE-BEG"], format="fits", scale="tai") + tavg = Time(header["DATE-AVG"], format="fits", scale="tai") + tstop = Time(header["DATE-END"], format="fits", scale="tai") + assert (tstart - tref).to_value(u.s) == pytest.approx( + header["TSTART"], rel=1e-6 + ) + assert (tstop - tref).to_value(u.s) == pytest.approx(header["TSTOP"], rel=1e-6) + assert (tavg >= tstart) & (tavg <= tstop) + + def test_hdu_header_time_missing_gti(self, dl3_writer): + dl3_writer._gti = None + with pytest.raises(ValueError): + dl3_writer.get_hdu_header_base_time() + + def test_hdu_header_time_missing_deadtime(self, dl3_writer): + dl3_writer._livetime_fraction = None + with pytest.raises(ValueError): + dl3_writer.get_hdu_header_base_time() + + def test_livetime_fraction_setter_validation(self, dl3_writer): + dl3_writer.livetime_fraction = 0.0 + assert dl3_writer.livetime_fraction == 0.0 + + dl3_writer.livetime_fraction = 1.0 + assert dl3_writer.livetime_fraction == 1.0 + + for invalid in (-1e-3, 1.001, np.nan, np.inf, -np.inf): + with pytest.raises(ValueError): + dl3_writer.livetime_fraction = invalid + + for invalid in ([0.5], "0.5", True): + with pytest.raises(TypeError): + dl3_writer.livetime_fraction = invalid + + dl3_writer.livetime_fraction = None + assert dl3_writer.livetime_fraction is None + + def test_obs_id_setter_validation(self, dl3_writer): + dl3_writer.obs_id = np.int64(1234) + assert dl3_writer.obs_id == 1234 + + with pytest.raises(ValueError): + dl3_writer.obs_id = -1 + + for invalid in (1.2, "1", True): + with pytest.raises(TypeError): + dl3_writer.obs_id = invalid + + dl3_writer.obs_id = None + assert dl3_writer.obs_id is None + + def test_events_setter_validation(self, dl3_writer): + table = Table(dl3_writer.events, copy=False) + dl3_writer.events = table + assert dl3_writer.events is table + + with pytest.raises(TypeError): + dl3_writer.events = {"not": "a table"} + + dl3_writer.events = None + assert dl3_writer.events is None + + def test_pointing_setter_validation(self, dl3_writer): + with pytest.raises(TypeError): + dl3_writer.pointing = "not-a-sequence" + + with pytest.raises(ValueError): + dl3_writer.pointing = [(Time("2020-01-01T00:00:00", scale="tai"),)] + + with pytest.raises(TypeError): + dl3_writer.pointing = [(Time("2020-01-01T00:00:00", scale="tai"), object())] + + dl3_writer.pointing = None + assert dl3_writer.pointing is None + + def test_pointing_mode_setter_validation(self, dl3_writer): + dl3_writer.pointing_mode = "track" + assert dl3_writer.pointing_mode == "TRACK" + + dl3_writer.pointing_mode = " drift " + assert dl3_writer.pointing_mode == "DRIFT" + + with pytest.raises(TypeError): + dl3_writer.pointing_mode = 1 + + with pytest.raises(ValueError): + dl3_writer.pointing_mode = "WOBBLE" + + def test_gti_setter_validation(self, dl3_writer): + with pytest.raises(TypeError): + dl3_writer.gti = "not-a-sequence" + + with pytest.raises(ValueError): + dl3_writer.gti = [(Time("2020-01-01T00:00:00", scale="tai"),)] + + dl3_writer.gti = None + assert dl3_writer.gti is None + + def test_location_setter_validation(self, dl3_writer): + with pytest.raises(TypeError): + dl3_writer.location = "not-a-location" + + dl3_writer.location = None + assert dl3_writer.location is None + + @pytest.mark.parametrize("setter", ["aeff", "psf", "edisp", "bkg"]) + def test_irf_setter_validation(self, dl3_writer, setter): + with pytest.raises(TypeError): + setattr(dl3_writer, setter, "not-an-hdu") + + def test_telescope_information_setter_validation(self, dl3_writer): + with pytest.raises(TypeError): + dl3_writer.telescope_information = "not-a-mapping" + + with pytest.raises(ValueError, match="missing keys"): + dl3_writer.telescope_information = {"organisation": "CTAO"} + + def test_target_information_setter_validation(self, dl3_writer): + with pytest.raises(TypeError): + dl3_writer.target_information = "not-a-mapping" + + with pytest.raises(ValueError, match="missing keys"): + dl3_writer.target_information = {"observer": "UNKNOWN"} + + with pytest.raises(TypeError): + dl3_writer.target_information = { + "observer": "UNKNOWN", + "object_name": "UNKNOWN", + "object_coordinate": object(), + } + + def test_software_information_setter_validation(self, dl3_writer): + with pytest.raises(TypeError): + dl3_writer.software_information = "not-a-mapping" + + with pytest.raises(ValueError, match="missing keys"): + dl3_writer.software_information = {"analysis_version": "ctapipe X"} + + def test_hdu_header_obs_info(self, dl3_writer, dl2_meta_for_dl3): + obs_only = dl3_writer.get_hdu_header_base_observation_information( + obs_id_only=True + ) + assert obs_only["OBS_ID"] == dl3_writer.obs_id + assert len(obs_only) == 1 + + full_header = dl3_writer.get_hdu_header_base_observation_information( + obs_id_only=False + ) + assert full_header["OBS_ID"] == dl3_writer.obs_id + target = dl2_meta_for_dl3["target"] + assert full_header["OBSERVER"] == target["observer"] + assert full_header["OBJECT"] == target["object_name"] + + def test_hdu_header_obs_info_missing_obs_id(self, dl3_writer): + dl3_writer._obs_id = None + with pytest.raises(ValueError): + dl3_writer.get_hdu_header_base_observation_information(obs_id_only=True) + with pytest.raises(ValueError): + dl3_writer.get_hdu_header_base_observation_information(obs_id_only=False) + + def test_hdu_header_subarray_info(self, dl3_writer, dl2_meta_for_dl3): + header = dl3_writer.get_hdu_header_base_subarray_information() + + tel_info = dl2_meta_for_dl3["telescope_information"] + assert header["ORIGIN"] == tel_info["organisation"] + assert header["TELESCOP"] == tel_info["array"] + assert header["INSTRUME"] == tel_info["subarray"] + assert header["TELLIST"] == str(tel_info["telescope_list"]) + assert header["N_TELS"] == len(tel_info["telescope_list"]) + + def test_hdu_header_software_info(self, dl3_writer, dl2_meta_for_dl3): + header = dl3_writer.get_hdu_header_base_software_information() + soft = dl2_meta_for_dl3["software_version"] + assert header["DST_VER"] == soft["dst_version"] + assert header["ANA_VER"] == soft["analysis_version"] + assert header["CAL_VER"] == soft["calibration_version"] + + dl3_writer._software_information = None + header = dl3_writer.get_hdu_header_base_software_information() + assert len(header) == 0 + + def test_hdu_header_pointing(self, dl3_writer, dl2_meta_for_dl3): + header = dl3_writer.get_hdu_header_base_pointing() + + assert header["RADESYS"] == "ICRS" + assert header["RADECSYS"] == "ICRS" + assert header["EQUINOX"] == 2000.0 + assert header["OBS_MODE"] == dl2_meta_for_dl3["pointing"]["pointing_mode"] + + for key in ["RA_PNT", "DEC_PNT", "ALT_PNT", "AZ_PNT"]: + assert np.isfinite(header[key]) + + loc = dl2_meta_for_dl3["location"] + assert header["GEOLON"] == pytest.approx(loc.lon.to_value(u.deg)) + assert header["GEOLAT"] == pytest.approx(loc.lat.to_value(u.deg)) + assert header["ALTITUDE"] == pytest.approx(loc.height.to_value(u.m)) + assert header["OBSGEO-X"] == pytest.approx(loc.x.to_value(u.m)) + assert header["OBSGEO-Y"] == pytest.approx(loc.y.to_value(u.m)) + assert header["OBSGEO-Z"] == pytest.approx(loc.z.to_value(u.m)) + + def test_hdu_header_pointing_track_mode_regression(self, dl3_writer): + """Regression: TRACK mode must not fail when interpolating RA around 0/360.""" + dl3_writer.pointing_mode = "TRACK" + header = dl3_writer.get_hdu_header_base_pointing() + + assert header["OBS_MODE"] == "POINTING" + for key in ["RA_PNT", "DEC_PNT", "ALT_PNT", "AZ_PNT"]: + assert np.isfinite(header[key]) + + def test_hdu_header_pointing_drift_mode_regression(self, dl3_writer): + """Regression: DRIFT mode interpolation must provide valid angular quantities.""" + dl3_writer.pointing_mode = "DRIFT" + header = dl3_writer.get_hdu_header_base_pointing() + + assert header["OBS_MODE"] == "DRIFT" + for key in ["RA_PNT", "DEC_PNT", "ALT_PNT", "AZ_PNT"]: + assert np.isfinite(header[key]) + + def test_hdu_header_pointing_missing_pointing(self, dl3_writer): + dl3_writer._pointing = None + with pytest.raises(ValueError): + dl3_writer.get_hdu_header_base_pointing() + + def test_hdu_header_pointing_missing_pointing_mode(self, dl3_writer): + dl3_writer._pointing_mode = None + with pytest.raises(ValueError): + dl3_writer.get_hdu_header_base_pointing() + + def test_hdu_header_pointing_missing_location(self, dl3_writer): + dl3_writer._location = None + with pytest.raises(ValueError): + dl3_writer.get_hdu_header_base_pointing() + + def test_hdu_header_events_hdu(self, dl3_writer): + header = dl3_writer.get_hdu_header_events() + + assert header["HDUCLASS"] == "GADF" + assert header["HDUCLAS1"] == "EVENTS" + # some representative keys from the different helper headers + for key in [ + "HDUCLASS", + "HDUDOC", + "HDUVERS", + "HDUCLAS1", + "OBS_ID", + "TSTART", + "TSTOP", + "ONTIME", + "LIVETIME", + "DEADC", + "OBS_MODE", + "RA_PNT", + "DEC_PNT", + "ALT_PNT", + "AZ_PNT", + "RADESYS", + "RADECSYS", + "EQUINOX", + "ORIGIN", + "TELESCOP", + "INSTRUME", + "CREATOR", + ]: + assert key in header + + def test_hdu_header_gti_hdu(self, dl3_writer): + header = dl3_writer.get_hdu_header_gti() + + for key in [ + "MJDREFI", + "MJDREFF", + "TIMEUNIT", + "TIMEREF", + "TIMESYS", + "TSTART", + "TSTOP", + "ONTIME", + "LIVETIME", + "TELAPSE", + "DATE-OBS", + "DATE-BEG", + "DATE-AVG", + "DATE-END", + ]: + assert key in header + + assert header["HDUCLASS"] == "GADF" + assert header["HDUCLAS1"] == "GTI" + + def test_hdu_header_pointing_hdu(self, dl3_writer): + header = dl3_writer.get_hdu_header_pointing() + + assert header["HDUCLASS"] == "GADF" + assert header["HDUCLAS1"] == "POINTING" + assert "TSTART" not in header + assert "TSTOP" not in header + # Time reference keywords must be present so the TIME column + # can be interpreted correctly. + for key in ["MJDREFI", "MJDREFF", "TIMEUNIT", "TIMESYS", "TIMEREF"]: + assert key in header, f"Time reference keyword {key} missing from POINTING" + assert header["TIMEREF"] == "TOPOCENTER" + assert header["TIMESYS"] == "TAI" + assert header["TIMEUNIT"] == "s" + for key in ["RA_PNT", "DEC_PNT", "ALT_PNT", "AZ_PNT", "OBS_ID"]: + assert key in header + + def test_column_renaming(self, dl3_writer): + events = dl3_writer.events + renamed = dl3_writer.transform_events_columns_for_gadf_format(events) + + assert renamed.colnames == ["EVENT_ID", "TIME", "RA", "DEC", "ENERGY"] + assert len(renamed) == len(events) + assert renamed["TIME"].unit.is_equivalent(u.s) + assert renamed["TIME"].ndim == 1 + assert renamed["TIME"].dtype.kind == "f" + assert np.all(np.isfinite(renamed["TIME"])) + if len(renamed) > 1: + np.testing.assert_allclose( + np.diff(renamed["TIME"].to_value(u.s)), + 1.0, + ) + + bad_events = events.copy() + bad_events.remove_column("reco_energy") + with pytest.raises(ValueError, match="Required column reco_energy is missing"): + dl3_writer.transform_events_columns_for_gadf_format(bad_events) + + def test_gti_table(self, dl3_writer, dl2_meta_for_dl3): + gti_table = dl3_writer.create_gti_table() + + assert gti_table.colnames == ["START", "STOP"] + assert len(gti_table) == len(dl2_meta_for_dl3["gti"]) + + def test_pointing_table(self, dl3_writer): + pointing_table = dl3_writer.create_pointing_table() + + assert pointing_table.colnames == [ + "TIME", + "RA_PNT", + "DEC_PNT", + "ALT_PNT", + "AZ_PNT", + ] + assert len(pointing_table) >= 1 + + times = pointing_table["TIME"].to_value(u.s) + assert np.all(np.diff(times) >= 0) + + assert np.all( + (-90.0 <= pointing_table["DEC_PNT"].to_value(u.deg)) + & (pointing_table["DEC_PNT"].to_value(u.deg) <= 90.0) + ) + assert np.all( + (-90.0 <= pointing_table["ALT_PNT"].to_value(u.deg)) + & (pointing_table["ALT_PNT"].to_value(u.deg) <= 90.0) + ) + assert np.all(np.isfinite(pointing_table["RA_PNT"].to_value(u.deg))) + + def test_pointing_table_missing_pointing(self, dl3_writer): + dl3_writer._pointing = None + with pytest.raises(ValueError): + dl3_writer.create_pointing_table() + + def test_pointing_table_missing_location(self, dl3_writer): + dl3_writer._location = None + with pytest.raises(ValueError): + dl3_writer.create_pointing_table() + + def test_gti_table_is_sorted(self, dl3_writer, dl2_meta_for_dl3): + """Regression test: GTI table must be sorted by START (bug #1.3).""" + original_gti = dl3_writer.gti + + # Build GTI intervals in reverse chronological order + ref = Time("2020-06-01T00:00:00", scale="tai") + reversed_gti = [ + (ref + 200 * u.s, ref + 300 * u.s), + (ref + 100 * u.s, ref + 200 * u.s), + (ref + 0 * u.s, ref + 100 * u.s), + ] + dl3_writer.gti = reversed_gti + + gti_table = dl3_writer.create_gti_table() + start_values = gti_table["START"].to_value(u.s) + assert np.all(np.diff(start_values) >= 0), ( + "GTI START column must be sorted in ascending order" + ) + + # Restore original GTI + dl3_writer.gti = original_gti From 675619c4a62da1a26d82ac902952f6f3b8c3bcbd Mon Sep 17 00:00:00 2001 From: Mathieu de Bony Date: Thu, 2 Apr 2026 12:08:37 +0200 Subject: [PATCH 2/7] Adapt to context of dedicated branch --- src/ctapipe/conftest.py | 82 ++++++++++ src/ctapipe/io/dl3.py | 262 +++++++++++++++---------------- src/ctapipe/io/tests/test_dl3.py | 180 ++++++++++++++------- 3 files changed, 340 insertions(+), 184 deletions(-) diff --git a/src/ctapipe/conftest.py b/src/ctapipe/conftest.py index dbdaae58f54..ce80c6bead2 100644 --- a/src/ctapipe/conftest.py +++ b/src/ctapipe/conftest.py @@ -871,6 +871,40 @@ def gamma_diffuse_full_reco_file( return output_path +@pytest.fixture(scope="session") +def single_obs_gamma_diffuse_full_reco_file(gamma_diffuse_full_reco_file, irf_tmp_path): + """ + Copy of gamma_diffuse_full_reco_file restricted to its first observation block. + + The full multi-observation file cannot be used for DL3 production, which + requires a single obs_id per output file. + """ + output_path = irf_tmp_path / "gamma_diffuse_single_obs.dl2.h5" + shutil.copy(gamma_diffuse_full_reco_file, output_path) + + obs_table = read_table(output_path, "/configuration/observation/observation_block") + first_obs_id = obs_table["obs_id"][0] + single_obs = obs_table[obs_table["obs_id"] == first_obs_id] + single_obs["actual_duration"] = 1800.0 * u.s + + sched_table = read_table(output_path, "/configuration/observation/scheduling_block") + single_sched = sched_table[sched_table["sb_id"] == single_obs["sb_id"][0]] + + write_table( + single_obs, + output_path, + "/configuration/observation/observation_block", + overwrite=True, + ) + write_table( + single_sched, + output_path, + "/configuration/observation/scheduling_block", + overwrite=True, + ) + return output_path + + @pytest.fixture(scope="session") def proton_full_reco_file( proton_train_clf, @@ -984,6 +1018,54 @@ def irf_events_table(): return ev +@pytest.fixture(scope="session") +def dummy_cuts_file( + gamma_diffuse_full_reco_file, + proton_full_reco_file, + event_loader_config_path, + irf_tmp_path, +): + from ctapipe.tools.optimize_event_selection import EventSelectionOptimizer + + output_path = irf_tmp_path / "dummy_cuts.fits" + run_tool( + EventSelectionOptimizer(), + argv=[ + f"--gamma-file={gamma_diffuse_full_reco_file}", + f"--proton-file={proton_full_reco_file}", + f"--electron-file={gamma_diffuse_full_reco_file}", + f"--output={output_path}", + f"--config={event_loader_config_path}", + ], + ) + return output_path + + +@pytest.fixture(scope="session") +def dummy_irf_file( + gamma_diffuse_full_reco_file, + proton_full_reco_file, + dummy_cuts_file, + event_loader_config_path, + irf_tmp_path, +): + from ctapipe.tools.compute_irf import IrfTool + + output_path = irf_tmp_path / "dummy_irf.fits" + run_tool( + IrfTool(), + argv=[ + f"--cuts={dummy_cuts_file}", + f"--gamma-file={gamma_diffuse_full_reco_file}", + f"--proton-file={proton_full_reco_file}", + f"--electron-file={gamma_diffuse_full_reco_file}", + f"--output={output_path}", + f"--config={event_loader_config_path}", + ], + ) + return output_path + + @pytest.fixture(scope="function") def test_config(): return { diff --git a/src/ctapipe/io/dl3.py b/src/ctapipe/io/dl3.py index eedba19617f..77a391b59e4 100644 --- a/src/ctapipe/io/dl3.py +++ b/src/ctapipe/io/dl3.py @@ -424,137 +424,6 @@ def __init__(self, **kwargs): self.file_creation_time = datetime.now(tz=UTC) self._reference_time = self.reference_time.tai - def _to_tai_time(self, value: Any, value_name: str) -> Time: - """ - Normalize input to an absolute TAI ``Time`` object. - - Parameters - ---------- - value : Any - Input time-like value. Supported types are ``Time``, ``TimeDelta``, - time ``Quantity`` and scalar numeric values interpreted as seconds - relative to ``reference_time``. - value_name : str - Name of the value used in error messages. - """ - if isinstance(value, Time): - return value.tai - - if isinstance(value, TimeDelta): - return self._reference_time + value - - if isinstance(value, u.Quantity): - if not value.unit.is_equivalent(u.s): - raise ValueError( - f"{value_name} must be a time quantity equivalent to seconds." - ) - return self._reference_time + TimeDelta(value.to(u.s)) - - if np.isscalar(value) and np.isreal(value): - return self._reference_time + TimeDelta(float(value) * u.s) - - raise TypeError( - f"{value_name} must be Time, TimeDelta, a time Quantity, or a scalar number of seconds." - ) - - def _to_relative_time_seconds(self, value: Any, value_name: str) -> Any: - """ - Normalize input to seconds relative to ``reference_time``. - - Parameters - ---------- - value : Any - Input time-like value. Supported types are ``Time``, ``TimeDelta``, - time ``Quantity`` and numeric values assumed to already be in seconds. - value_name : str - Name of the value used in error messages. - """ - if isinstance(value, Time): - return (value.tai - self._reference_time).to_value(u.s) - - if isinstance(value, TimeDelta): - return value.to_value(u.s) - - if isinstance(value, u.Quantity): - if not value.unit.is_equivalent(u.s): - raise ValueError( - f"{value_name} must be a time quantity equivalent to seconds." - ) - return value.to_value(u.s) - - values = np.asarray(value) - if np.issubdtype(values.dtype, np.number): - return values.astype(np.float64, copy=False) - - raise TypeError( - f"{value_name} must be Time, TimeDelta, a time Quantity, or numeric seconds." - ) - - def _to_relative_time_quantity(self, value: Any, value_name: str) -> u.Quantity: - """Normalize input to a quantity in seconds relative to ``reference_time``.""" - return u.Quantity( - self._to_relative_time_seconds(value, value_name), - u.s, - copy=COPY_IF_NEEDED, - ) - - @staticmethod - def _circular_interp(x, xp, fp_deg): - """ - Interpolate angular values in degrees, handling the 0/360 wrap-around. - - Uses ``np.unwrap`` to remove discontinuities before interpolation, - then wraps the result back to [0, 360). - - Parameters - ---------- - x : array-like - The x-coordinates at which to evaluate the interpolated values. - xp : array-like - The x-coordinates of the data points (must be increasing). - fp_deg : array-like - The y-coordinates of the data points, in degrees. - - Returns - ------- - np.ndarray - Interpolated angular values in degrees, in [0, 360). - """ - fp_rad = np.deg2rad(np.asarray(fp_deg, dtype=float)) - fp_unwrapped = np.unwrap(fp_rad) - result_rad = np.interp(x, xp, fp_unwrapped) - return np.rad2deg(result_rad) % 360 - - @staticmethod - def _circular_mean(angles_deg): - """ - Compute the mean of angular values in degrees, handling the 0/360 - wrap-around. - - Uses the ``atan2(mean(sin), mean(cos))`` formula for circular - statistics. - - Parameters - ---------- - angles_deg : array-like - Angular values in degrees. - - Returns - ------- - float - Mean angle in degrees, in [0, 360). - """ - angles_rad = np.deg2rad(np.asarray(angles_deg, dtype=float)) - return float( - np.rad2deg( - np.arctan2( - np.mean(np.sin(angles_rad)), - np.mean(np.cos(angles_rad)), - ) - ) - % 360 - ) - def write_file(self, path): """ This function will write the new DL3 file @@ -1003,3 +872,134 @@ def create_pointing_table(self) -> QTable: table = QTable(table_structure) table.sort("TIME") return table + + def _to_tai_time(self, value: Any, value_name: str) -> Time: + """ + Normalize input to an absolute TAI ``Time`` object. + + Parameters + ---------- + value : Any + Input time-like value. Supported types are ``Time``, ``TimeDelta``, + time ``Quantity`` and scalar numeric values interpreted as seconds + relative to ``reference_time``. + value_name : str + Name of the value used in error messages. + """ + if isinstance(value, Time): + return value.tai + + if isinstance(value, TimeDelta): + return self._reference_time + value + + if isinstance(value, u.Quantity): + if not value.unit.is_equivalent(u.s): + raise ValueError( + f"{value_name} must be a time quantity equivalent to seconds." + ) + return self._reference_time + TimeDelta(value.to(u.s)) + + if np.isscalar(value) and np.isreal(value): + return self._reference_time + TimeDelta(float(value) * u.s) + + raise TypeError( + f"{value_name} must be Time, TimeDelta, a time Quantity, or a scalar number of seconds." + ) + + def _to_relative_time_seconds(self, value: Any, value_name: str) -> Any: + """ + Normalize input to seconds relative to ``reference_time``. + + Parameters + ---------- + value : Any + Input time-like value. Supported types are ``Time``, ``TimeDelta``, + time ``Quantity`` and numeric values assumed to already be in seconds. + value_name : str + Name of the value used in error messages. + """ + if isinstance(value, Time): + return (value.tai - self._reference_time).to_value(u.s) + + if isinstance(value, TimeDelta): + return value.to_value(u.s) + + if isinstance(value, u.Quantity): + if not value.unit.is_equivalent(u.s): + raise ValueError( + f"{value_name} must be a time quantity equivalent to seconds." + ) + return value.to_value(u.s) + + values = np.asarray(value) + if np.issubdtype(values.dtype, np.number): + return values.astype(np.float64, copy=False) + + raise TypeError( + f"{value_name} must be Time, TimeDelta, a time Quantity, or numeric seconds." + ) + + def _to_relative_time_quantity(self, value: Any, value_name: str) -> u.Quantity: + """Normalize input to a quantity in seconds relative to ``reference_time``.""" + return u.Quantity( + self._to_relative_time_seconds(value, value_name), + u.s, + copy=COPY_IF_NEEDED, + ) + + @staticmethod + def _circular_interp(x, xp, fp_deg): + """ + Interpolate angular values in degrees, handling the 0/360 wrap-around. + + Uses ``np.unwrap`` to remove discontinuities before interpolation, + then wraps the result back to [0, 360). + + Parameters + ---------- + x : array-like + The x-coordinates at which to evaluate the interpolated values. + xp : array-like + The x-coordinates of the data points (must be increasing). + fp_deg : array-like + The y-coordinates of the data points, in degrees. + + Returns + ------- + np.ndarray + Interpolated angular values in degrees, in [0, 360). + """ + fp_rad = np.deg2rad(np.asarray(fp_deg, dtype=float)) + fp_unwrapped = np.unwrap(fp_rad) + result_rad = np.interp(x, xp, fp_unwrapped) + return np.rad2deg(result_rad) % 360 + + @staticmethod + def _circular_mean(angles_deg): + """ + Compute the mean of angular values in degrees, handling the 0/360 + wrap-around. + + Uses the ``atan2(mean(sin), mean(cos))`` formula for circular + statistics. + + Parameters + ---------- + angles_deg : array-like + Angular values in degrees. + + Returns + ------- + float + Mean angle in degrees, in [0, 360). + """ + angles_rad = np.deg2rad(np.asarray(angles_deg, dtype=float)) + return float( + np.rad2deg( + np.arctan2( + np.mean(np.sin(angles_rad)), + np.mean(np.cos(angles_rad)), + ) + ) + % 360 + ) diff --git a/src/ctapipe/io/tests/test_dl3.py b/src/ctapipe/io/tests/test_dl3.py index b727f80aa90..57cd1dff142 100644 --- a/src/ctapipe/io/tests/test_dl3.py +++ b/src/ctapipe/io/tests/test_dl3.py @@ -3,13 +3,17 @@ import astropy.units as u import numpy as np import pytest +from astropy.coordinates import ICRS, AltAz, SkyCoord from astropy.io import fits -from astropy.table import Column, Table -from astropy.time import Time -from traitlets.config import Config - -from ...irf.cuts import DL2EventSelection -from ...tools.create_dl3 import DL3Tool +from astropy.table import Column, QTable, Table, vstack +from astropy.time import Time, TimeDelta + +from ...containers import PointingMode +from ...core import QualityQuery +from ...io import TableLoader +from ...io.astropy_helpers import join_allow_empty +from ...io.dl2_tables_preprocessing import DL2EventPreprocessor +from ...version import version as ctapipe_version from ..dl3 import DL3GADFEventsWriter @@ -20,62 +24,132 @@ def hdu_irfs(dummy_irf_file): @pytest.fixture(scope="session") -def dl2_events_for_dl3(single_obs_gamma_diffuse_full_reco_file, dummy_cuts_file): - tool = DL3Tool( - config=Config( - { - "EventPreprocessor": { - "energy_reconstructor": "ExtraTreesRegressor", - "geometry_reconstructor": "HillasReconstructor", - "gammaness_reconstructor": "ExtraTreesClassifier", - } - } +def dl2_meta_for_dl3(single_obs_gamma_diffuse_full_reco_file): + with TableLoader( + single_obs_gamma_diffuse_full_reco_file, + dl2=True, + observation_info=True, + simulated=False, + ) as loader: + meta = { + "location": loader.subarray.reference_location, + "telescope_information": { + "organisation": "CTAO", + "array": "CTAO-North", + "subarray": "4LST", + "telescope_list": np.array( + loader.subarray.get_tel_ids(loader.subarray.tel) + ), + }, + "target": { + "observer": "SuperObserver", + "object_name": "Crab", + "object_coordinate": SkyCoord( + ra=83.6331 * u.deg, dec=22.0145 * u.deg, frame="icrs" + ), + }, + "software_version": { + "analysis_version": "ctapipe " + ctapipe_version, + "calibration_version": "UNKNOWN", + "dst_version": "UNKNOWN", + }, + "livetime_fraction": 0.97, + } + + obs_info = loader.read_observation_information() + sched_info = loader.read_scheduling_blocks() + obs_all_info = join_allow_empty(obs_info, sched_info, "sb_id", "inner") + row = obs_all_info[0] + meta["obs_id"] = int(row["obs_id"]) + + start_time = Time(row["actual_start_time"]).tai + stop_time = start_time + TimeDelta(obs_all_info["actual_duration"].quantity[0]) + meta["gti"] = [(start_time, stop_time)] + + pointing = AltAz( + alt=obs_all_info["subarray_pointing_lat"].quantity[0], + az=obs_all_info["subarray_pointing_lon"].quantity[0], + location=meta["location"], + obstime=start_time, ) + meta["pointing"] = { + "pointing_mode": PointingMode(row["pointing_mode"]).name, + "pointing_list": [(start_time, pointing), (stop_time, pointing)], + } + return meta + + +@pytest.fixture(scope="session") +def dl2_events_for_dl3(single_obs_gamma_diffuse_full_reco_file, dl2_meta_for_dl3): + preprocessor = DL2EventPreprocessor( + energy_reconstructor="ExtraTreesRegressor", + geometry_reconstructor="HillasReconstructor", + gammaness_classifier="ExtraTreesClassifier", + apply_derived_columns=False, + allow_unsupported_pointing_frames=True, + output_table_schema=[ + Column(name="obs_id", dtype=np.uint64), + Column(name="event_id", dtype=np.uint64), + Column(name="reco_energy", unit=u.TeV), + Column(name="reco_az", unit=u.deg), + Column(name="reco_alt", unit=u.deg), + Column(name="pointing_az", unit=u.deg), + Column(name="pointing_alt", unit=u.deg), + Column(name="gh_score", dtype=np.float64), + ], ) - tool.dl2_file = single_obs_gamma_diffuse_full_reco_file - tool.event_selection = DL2EventSelection(parent=tool, cuts_file=dummy_cuts_file) - tool._configure_event_preprocessor_for_dl3() - - events = tool._load_preselected_events(chunk_size=1000) - events = tool.event_selection.calculate_gamma_selection( - events, - apply_spatial_selection=False, + preprocessor.quality_query = QualityQuery( + parent=preprocessor, + quality_criteria=[ + ( + "multiplicity 4", + "np.count_nonzero(HillasReconstructor_telescopes,axis=1) >= 4", + ), + ("valid classifier", "ExtraTreesClassifier_is_valid"), + ("valid geom reco", "HillasReconstructor_is_valid"), + ("valid energy reco", "ExtraTreesRegressor_is_valid"), + ], ) - events = events[events["selected_gamma"]] - meta = tool._get_observation_information() + chunks = [] + with TableLoader( + single_obs_gamma_diffuse_full_reco_file, + dl2=True, + simulated=False, + observation_info=True, + ) as loader: + reader = loader.read_subarray_events_chunked( + 1000, + dl2=True, + simulated=False, + observation_info=True, + ) + for _, _, events in reader: + selected = events[preprocessor.quality_query.get_table_mask(events)] + if len(selected) == 0: + continue + chunks.append(preprocessor.normalise_column_names(selected)) + + if len(chunks) == 0: + raise ValueError("No events available for DL3 writer tests") + + events = QTable(vstack(chunks, join_type="exact", metadata_conflicts="silent")) events["time"] = ( Time("2020-01-01T00:00:00", scale="tai") + np.arange(len(events)) * u.s ) - tool.output_table_schema = [ - Column(name="reco_az", unit=u.deg), - Column(name="reco_alt", unit=u.deg), - Column(name="pointing_az", unit=u.deg), - Column(name="pointing_alt", unit=u.deg), - Column(name="time"), - Column(name="reco_ra", unit=u.deg), - Column(name="reco_dec", unit=u.deg), - ] - events = tool._make_derived_columns(events, location_subarray=meta["location"]) - return events - -@pytest.fixture(scope="session") -def dl2_meta_for_dl3(single_obs_gamma_diffuse_full_reco_file, dummy_cuts_file): - tool = DL3Tool( - config=Config( - { - "EventPreprocessor": { - "energy_reconstructor": "ExtraTreesRegressor", - "geometry_reconstructor": "HillasReconstructor", - "gammaness_reconstructor": "ExtraTreesClassifier", - } - } - ) + reco = SkyCoord( + alt=events["reco_alt"], + az=events["reco_az"], + frame=AltAz( + location=dl2_meta_for_dl3["location"], + obstime=Time(events["time"]), + ), ) - tool.dl2_file = single_obs_gamma_diffuse_full_reco_file - tool.event_selection = DL2EventSelection(parent=tool, cuts_file=dummy_cuts_file) - return tool._get_observation_information() + reco_icrs = reco.transform_to(ICRS()) + events["reco_ra"] = reco_icrs.ra.to(u.deg) + events["reco_dec"] = reco_icrs.dec.to(u.deg) + return events @pytest.fixture From 5a358964e6da808aa1057fe3e677b39aef6b9dcd Mon Sep 17 00:00:00 2001 From: Mathieu de Bony Date: Thu, 2 Apr 2026 12:57:25 +0200 Subject: [PATCH 3/7] Prepare for the PR --- docs/changes/2979.feature.rst | 1 + src/ctapipe/io/dl3.py | 5 +---- 2 files changed, 2 insertions(+), 4 deletions(-) create mode 100644 docs/changes/2979.feature.rst diff --git a/docs/changes/2979.feature.rst b/docs/changes/2979.feature.rst new file mode 100644 index 00000000000..85cd5593145 --- /dev/null +++ b/docs/changes/2979.feature.rst @@ -0,0 +1 @@ +Add class to write DL3 files diff --git a/src/ctapipe/io/dl3.py b/src/ctapipe/io/dl3.py index 77a391b59e4..b65a84c999c 100644 --- a/src/ctapipe/io/dl3.py +++ b/src/ctapipe/io/dl3.py @@ -1,7 +1,6 @@ from abc import abstractmethod from collections.abc import Mapping from datetime import UTC, datetime -from functools import lru_cache from typing import Any, Dict, List, Tuple import astropy.units as u @@ -507,7 +506,6 @@ def get_hdu_header_time_reference(self) -> Dict[str, Any]: "TIMESYS": "TAI", } - @lru_cache(maxsize=1) def get_hdu_header_base_time(self) -> Dict[str, Any]: """ Return the information about time parameters used in several HDU @@ -609,7 +607,6 @@ def get_hdu_header_base_software_information(self) -> Dict[str, Any]: header["CAL_VER"] = self.software_information["calibration_version"] return header - @lru_cache(maxsize=1) def get_hdu_header_base_pointing(self) -> Dict[str, Any]: """ Return information on the pointing during the observation @@ -697,7 +694,7 @@ def get_hdu_header_events(self) -> Dict[str, Any]: The output dictionary contain all the necessary information that should be added to the header of the events HDU """ header = self.get_hdu_header_base_format() - header.update({"HDUCLAS1": "EVENTS"}) + header.update({"HDUCLAS1": "EVENTS", "FOVALIGN": "ALTAZ"}) header.update(self.get_hdu_header_base_time()) header.update(self.get_hdu_header_base_pointing()) header.update(self.get_hdu_header_base_observation_information()) From a63cc55c623d958cc6607e786233a77f0b18a00e Mon Sep 17 00:00:00 2001 From: Mathieu de Bony Date: Thu, 2 Apr 2026 16:26:45 +0200 Subject: [PATCH 4/7] Fix test --- src/ctapipe/conftest.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/ctapipe/conftest.py b/src/ctapipe/conftest.py index ce80c6bead2..5542cff7440 100644 --- a/src/ctapipe/conftest.py +++ b/src/ctapipe/conftest.py @@ -1027,7 +1027,7 @@ def dummy_cuts_file( ): from ctapipe.tools.optimize_event_selection import EventSelectionOptimizer - output_path = irf_tmp_path / "dummy_cuts.fits" + output_path = irf_tmp_path / "test_dummy_cuts.fits" run_tool( EventSelectionOptimizer(), argv=[ From 40f6a0b1efa07c765ad66ee8f43449701a971810 Mon Sep 17 00:00:00 2001 From: Mathieu de Bony Date: Thu, 9 Apr 2026 18:56:32 +0200 Subject: [PATCH 5/7] Add skip for pyirf on fixtures --- src/ctapipe/conftest.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/src/ctapipe/conftest.py b/src/ctapipe/conftest.py index 5542cff7440..a2fb770967a 100644 --- a/src/ctapipe/conftest.py +++ b/src/ctapipe/conftest.py @@ -879,6 +879,8 @@ def single_obs_gamma_diffuse_full_reco_file(gamma_diffuse_full_reco_file, irf_tm The full multi-observation file cannot be used for DL3 production, which requires a single obs_id per output file. """ + pytest.importorskip("pyirf", reason="pyirf is an optional dependency") + output_path = irf_tmp_path / "gamma_diffuse_single_obs.dl2.h5" shutil.copy(gamma_diffuse_full_reco_file, output_path) @@ -1025,6 +1027,7 @@ def dummy_cuts_file( event_loader_config_path, irf_tmp_path, ): + pytest.importorskip("pyirf", reason="pyirf is an optional dependency") from ctapipe.tools.optimize_event_selection import EventSelectionOptimizer output_path = irf_tmp_path / "test_dummy_cuts.fits" @@ -1049,6 +1052,7 @@ def dummy_irf_file( event_loader_config_path, irf_tmp_path, ): + pytest.importorskip("pyirf", reason="pyirf is an optional dependency") from ctapipe.tools.compute_irf import IrfTool output_path = irf_tmp_path / "dummy_irf.fits" From 9744378c2d195df66ee5de62d8c9bc908761ecbe Mon Sep 17 00:00:00 2001 From: Mathieu de Bony Date: Thu, 21 May 2026 15:13:05 +0200 Subject: [PATCH 6/7] Replace custom function for circular mean by the astropy one --- src/ctapipe/io/dl3.py | 40 +++++++------------------------- src/ctapipe/io/tests/test_dl3.py | 2 -- 2 files changed, 8 insertions(+), 34 deletions(-) diff --git a/src/ctapipe/io/dl3.py b/src/ctapipe/io/dl3.py index b65a84c999c..8a9f7660883 100644 --- a/src/ctapipe/io/dl3.py +++ b/src/ctapipe/io/dl3.py @@ -8,6 +8,7 @@ from astropy.coordinates import ( ICRS, AltAz, + Angle, BaseCoordinateFrame, EarthLocation, SkyCoord, @@ -15,6 +16,7 @@ from astropy.io import fits from astropy.io.fits import Header from astropy.io.fits.hdu.base import ExtensionHDU +from astropy.stats import circmean from astropy.table import QTable, Table from astropy.time import Time, TimeDelta @@ -676,10 +678,14 @@ def get_hdu_header_base_pointing(self) -> Dict[str, Any]: "RADECSYS": "ICRS", "EQUINOX": 2000.0, "OBS_MODE": obs_mode, - "RA_PNT": self._circular_mean(icrs_coordinate.ra.to_value(u.deg)), + "RA_PNT": Angle(circmean(icrs_coordinate.ra)) + .wrap_at(360 * u.deg) + .to_value(u.deg), "DEC_PNT": np.mean(icrs_coordinate.dec.to_value(u.deg)), "ALT_PNT": np.mean(altaz_coordinate.alt.to_value(u.deg)), - "AZ_PNT": self._circular_mean(altaz_coordinate.az.to_value(u.deg)), + "AZ_PNT": Angle(circmean(altaz_coordinate.az)) + .wrap_at(360 * u.deg) + .to_value(u.deg), "GEOLON": self.location.lon.to_value(u.deg), "GEOLAT": self.location.lat.to_value(u.deg), "ALTITUDE": self.location.height.to_value(u.m), @@ -970,33 +976,3 @@ def _circular_interp(x, xp, fp_deg): fp_unwrapped = np.unwrap(fp_rad) result_rad = np.interp(x, xp, fp_unwrapped) return np.rad2deg(result_rad) % 360 - - @staticmethod - def _circular_mean(angles_deg): - """ - Compute the mean of angular values in degrees, handling the 0/360 - wrap-around. - - Uses the ``atan2(mean(sin), mean(cos))`` formula for circular - statistics. - - Parameters - ---------- - angles_deg : array-like - Angular values in degrees. - - Returns - ------- - float - Mean angle in degrees, in [0, 360). - """ - angles_rad = np.deg2rad(np.asarray(angles_deg, dtype=float)) - return float( - np.rad2deg( - np.arctan2( - np.mean(np.sin(angles_rad)), - np.mean(np.cos(angles_rad)), - ) - ) - % 360 - ) diff --git a/src/ctapipe/io/tests/test_dl3.py b/src/ctapipe/io/tests/test_dl3.py index 57cd1dff142..98402f9390a 100644 --- a/src/ctapipe/io/tests/test_dl3.py +++ b/src/ctapipe/io/tests/test_dl3.py @@ -495,7 +495,6 @@ def test_hdu_header_pointing(self, dl3_writer, dl2_meta_for_dl3): assert header["OBSGEO-Z"] == pytest.approx(loc.z.to_value(u.m)) def test_hdu_header_pointing_track_mode_regression(self, dl3_writer): - """Regression: TRACK mode must not fail when interpolating RA around 0/360.""" dl3_writer.pointing_mode = "TRACK" header = dl3_writer.get_hdu_header_base_pointing() @@ -504,7 +503,6 @@ def test_hdu_header_pointing_track_mode_regression(self, dl3_writer): assert np.isfinite(header[key]) def test_hdu_header_pointing_drift_mode_regression(self, dl3_writer): - """Regression: DRIFT mode interpolation must provide valid angular quantities.""" dl3_writer.pointing_mode = "DRIFT" header = dl3_writer.get_hdu_header_base_pointing() From 7c5ef0a94adc1cd0aa01750cb9cc5e368d900b2c Mon Sep 17 00:00:00 2001 From: Mathieu de Bony Date: Thu, 28 May 2026 18:29:16 +0200 Subject: [PATCH 7/7] Add a data class for DL3 information to be written --- src/ctapipe/io/dl3.py | 1095 +++++++++++++++++++----------- src/ctapipe/io/tests/test_dl3.py | 384 ++++++----- 2 files changed, 919 insertions(+), 560 deletions(-) diff --git a/src/ctapipe/io/dl3.py b/src/ctapipe/io/dl3.py index 8a9f7660883..4fd87a759b0 100644 --- a/src/ctapipe/io/dl3.py +++ b/src/ctapipe/io/dl3.py @@ -1,5 +1,6 @@ from abc import abstractmethod from collections.abc import Mapping +from dataclasses import dataclass, fields from datetime import UTC, datetime from typing import Any, Dict, List, Tuple @@ -17,7 +18,7 @@ from astropy.io.fits import Header from astropy.io.fits.hdu.base import ExtensionHDU from astropy.stats import circmean -from astropy.table import QTable, Table +from astropy.table import QTable from astropy.time import Time, TimeDelta from ..compat import COPY_IF_NEEDED @@ -25,394 +26,511 @@ from ..core.traits import AstroTime, Bool from ..version import version as ctapipe_version -__all__ = ["DL3EventsWriter", "DL3GADFEventsWriter"] +__all__ = ["DL3EventsData", "DL3EventsWriter", "DL3GADFEventsWriter"] -class DL3EventsWriter(Component): +@dataclass(slots=True) +class DL3EventsData: """ - Base class for writing a DL3 file + The class contain all information required to generate DL3 file + + Parameters + ---------- + events : QTable + A table with a line for each event and column for each of the parameters required for the DL3 creation. + obs_id : int + Observation ID. + pointing : list[tuple[Time, SkyCoord]] + A list with for each entry containing the time at which the coordinate where evaluated and the associated coordinates. + pointing_mode : str + The name of the pointing mode used for the observation. Must be ``TRACK`` or ``DRIFT``. + gti : list[tuple[Time, Time]] + A list with for each entry containing the start and stop time of the good time intervals. + livetime_fraction : float + The livetime fraction for the observation. + location : EarthLocation + The location of the telescope. + telescope_information : dict[str, any] + A dictionary containing general information about telescope with as key: organisation, array, subarray, telescope_list. + aeff : ExtensionHDU + The effective area HDU read from the fits file containing IRFs. + psf : ExtensionHDU + The PSF HDU read from the fits file containing IRFs. + edisp : ExtensionHDU + The EDISP HDU read from the fits file containing IRFs. + bkg : ExtensionHDU, optional + The background HDU read from the fits file containing IRFs. + target_information : dict[str, any], optional + A dictionary containing general information about the targeted source with as key: observer, object_name, object_coordinate. + software_information : dict[str, any], optional + A dictionary containing general information about the software used to produce the file with as key: analysis_version, calibration_version, dst_version. """ - overwrite = Bool( - default_value=False, - help="If true, allow to overwrite already existing output file", - ).tag(config=True) - - optional_dl3_columns = Bool( - default_value=False, help="If true add optional columns to produce file" - ).tag(config=True) + events: QTable + obs_id: int + pointing: List[Tuple[Time, SkyCoord]] + pointing_mode: str + gti: List[Tuple[Time, Time]] + livetime_fraction: float + location: EarthLocation + telescope_information: Dict[str, Any] + aeff: ExtensionHDU + psf: ExtensionHDU + edisp: ExtensionHDU + bkg: ExtensionHDU | None = None + target_information: Dict[str, Any] | None = None + software_information: Dict[str, Any] | None = None + + def __setattr__(self, name: str, value: Any): + """ + Set a DL3 payload field after validating its value. - raise_error_for_optional = Bool( - default_value=True, - help="If true will raise error in the case optional column are missing", - ).tag(config=True) + Parameters + ---------- + name : str + Name of the field to set. + value : any + New value for the field. + """ + object.__setattr__(self, name, self._validate_field(name, value)) - reference_time = AstroTime( - default_value=Time("2018-01-01T00:00:00", scale="tai"), - help="The reference time that will be used in the FITS file", - ).tag(config=True) + def __post_init__(self): + """ + Validate and normalize all DL3 payload fields after construction. + """ + for field in fields(self): + object.__setattr__( + self, + field.name, + self._validate_field(field.name, getattr(self, field.name)), + ) - def __init__(self, **kwargs): - super().__init__(**kwargs) - self._obs_id = None - self._events = None - self._pointing = None - self._pointing_mode = None - self._gti = None - self._aeff = None - self._psf = None - self._edisp = None - self._bkg = None - self._livetime_fraction = None - self._location = None - self._telescope_information = None - self._target_information = None - self._software_information = None + def _validate_field(self, name: str, value: Any) -> Any: + """ + Validate and normalize a DL3 payload field. - @abstractmethod - def write_file(self, path): - pass + Parameters + ---------- + name : str + Name of the field to validate. + value : any + Value to validate. - @property - def obs_id(self) -> int: - return self._obs_id + Returns + ------- + any + The validated and normalized value. + """ + validator = getattr(self, f"_validate_{name}", None) + if validator is not None: + value = validator(value) + return value - @obs_id.setter - def obs_id(self, obs_id: int): + @staticmethod + def _validate_obs_id(obs_id: int) -> int: """ + Validate observation ID. + Parameters ---------- obs_id : int - Observation ID - """ - if self._obs_id is not None: - self.log.warning( - "Obs id for DL3 file was already set, replacing current obs id" - ) - if obs_id is not None: - if not isinstance(obs_id, (int, np.integer)) or isinstance(obs_id, bool): - raise TypeError("obs_id must be an integer.") - if obs_id < 0: - raise ValueError("obs_id must be >= 0.") - self._obs_id = obs_id + Observation ID. - @property - def events(self) -> QTable: - return self._events + Returns + ------- + int + Observation ID cast to a Python ``int``. + """ + if obs_id is None: + raise ValueError("obs_id is required.") + if not isinstance(obs_id, (int, np.integer)) or isinstance(obs_id, bool): + raise TypeError("obs_id must be an integer.") + if obs_id < 0: + raise ValueError("obs_id must be >= 0") + return int(obs_id) - @events.setter - def events(self, events: QTable): + @staticmethod + def _validate_events(events: QTable) -> QTable: """ + Validate the events table. + Parameters ---------- events : QTable - A table with a line for each event - """ - if self._events is not None: - self.log.warning( - "Events table for DL3 file was already set, replacing current event table" - ) - if events is not None and not isinstance(events, (QTable, Table)): - raise TypeError("events must be an astropy Table or QTable.") - self._events = events + A table with a line for each event. - @property - def pointing(self) -> List[Tuple[Time, SkyCoord]]: - return self._pointing + Returns + ------- + QTable + The validated events table. + """ + if events is None: + raise ValueError("events is required.") + if not isinstance(events, QTable): + raise TypeError("events must be an astropy QTable.") + return events - @pointing.setter - def pointing(self, pointing: List[Tuple[Time, SkyCoord]]): + @staticmethod + def _validate_pointing( + pointing: list[tuple[Time, SkyCoord]], + ) -> list[tuple[Time, SkyCoord]]: """ + Validate the pointing information. + Parameters ---------- - pointing : List[Tuple[Time, SkyCoord]] - A list with for each entry containing the time at which the coordinate where evaluated and the associated coordinates + pointing : list[tuple[Time, SkyCoord]] + A list with for each entry containing the time at which the + coordinate where evaluated and the associated coordinates. + + Returns + ------- + list[tuple[Time, SkyCoord]] + The validated pointing information. """ - if self._pointing is not None: - self.log.warning( - "Pointing for DL3 file was already set, replacing current pointing" - ) - if pointing is not None: - if not isinstance(pointing, (list, tuple)): - raise TypeError("pointing must be a list of (time, coordinate) pairs.") - - for i, value in enumerate(pointing): - if not isinstance(value, (list, tuple)) or len(value) != 2: - raise ValueError( - f"pointing[{i}] must be a (time, coordinate) pair." - ) + if pointing is None: + raise ValueError("pointing is required.") + if not isinstance(pointing, (list, tuple)): + raise TypeError("pointing must be a list of (time, coordinate) pairs.") - coordinate = value[1] - if not isinstance(coordinate, (SkyCoord, BaseCoordinateFrame)): - raise TypeError( - f"pointing[{i}].coordinate must be a SkyCoord or coordinate frame." - ) - self._pointing = pointing + for i, value in enumerate(pointing): + if not isinstance(value, (list, tuple)) or len(value) != 2: + raise ValueError(f"pointing[{i}] must be a (time, coordinate) pair.") - @property - def pointing_mode(self) -> str: - return self._pointing_mode + coordinate = value[1] + if not isinstance(coordinate, (SkyCoord, BaseCoordinateFrame)): + raise TypeError( + f"pointing[{i}].coordinate must be a SkyCoord or coordinate frame." + ) + return pointing - @pointing_mode.setter - def pointing_mode(self, pointing_mode: str): + @staticmethod + def _validate_pointing_mode(pointing_mode: str) -> str: """ + Validate and normalize the pointing mode. + Parameters ---------- pointing_mode : str - The name of the pointing mode used for the observation - """ - if self._pointing_mode is not None: - self.log.warning( - "Pointing for DL3 file was already set, replacing current pointing" - ) - if pointing_mode is not None: - if not isinstance(pointing_mode, str): - raise TypeError("pointing_mode must be a string.") + The name of the pointing mode used for the observation. - pointing_mode = pointing_mode.strip().upper() - if pointing_mode not in {"TRACK", "DRIFT"}: - raise ValueError("pointing_mode must be either 'TRACK' or 'DRIFT'.") - self._pointing_mode = pointing_mode + Returns + ------- + str + Pointing mode normalized to ``TRACK`` or ``DRIFT``. + """ + if pointing_mode is None: + raise ValueError("pointing_mode is required.") + if not isinstance(pointing_mode, str): + raise TypeError("pointing_mode must be a string.") - @property - def gti(self) -> List[Tuple[Time, Time]]: - return self._gti + pointing_mode = pointing_mode.strip().upper() + if pointing_mode not in {"TRACK", "DRIFT"}: + raise ValueError("pointing_mode must be either 'TRACK' or 'DRIFT'.") + return pointing_mode - @gti.setter - def gti(self, gti: List[Tuple[Time, Time]]): + @staticmethod + def _validate_gti(gti: list[tuple[Time, Time]]) -> list[tuple[Time, Time]]: """ + Validate the good time intervals. + Parameters ---------- - gti : List[Tuple[Time, Time]] - A list with for each entry containing the time the start and stop time of the good time intervals - """ - if self._gti is not None: - self.log.warning("GTI for DL3 file was already set, replacing current gti") - if gti is not None: - if not isinstance(gti, (list, tuple)): - raise TypeError("gti must be a list of (start, stop) pairs.") + gti : list[tuple[Time, Time]] + A list with for each entry containing the start and stop time of + the good time intervals. - for i, value in enumerate(gti): - if not isinstance(value, (list, tuple)) or len(value) != 2: - raise ValueError(f"gti[{i}] must be a (start, stop) pair.") - self._gti = gti + Returns + ------- + list[tuple[Time, Time]] + The validated good time intervals. + """ + if gti is None: + raise ValueError("gti is required.") + if not isinstance(gti, (list, tuple)): + raise TypeError("gti must be a list of (start, stop) pairs.") - @property - def aeff(self) -> ExtensionHDU: - return self._aeff + for i, value in enumerate(gti): + if not isinstance(value, (list, tuple)) or len(value) != 2: + raise ValueError(f"gti[{i}] must be a (start, stop) pair.") + return gti - @aeff.setter - def aeff(self, aeff: ExtensionHDU): + @staticmethod + def _validate_livetime_fraction(livetime_fraction: float) -> float: """ + Validate the livetime fraction. + Parameters ---------- - aeff : ExtensionHDU - The effective area HDU read from the fits file containing IRFs - """ - if self._aeff is not None: - self.log.warning( - "Effective area for DL3 file was already set, replacing current effective area" - ) - if aeff is not None and not isinstance(aeff, ExtensionHDU): - raise TypeError("aeff must be a FITS ExtensionHDU.") - self._aeff = aeff + livetime_fraction : float + The livetime fraction for the observations (DEADC correction + factor). - @property - def psf(self) -> ExtensionHDU: - return self._psf + Returns + ------- + float + The validated livetime fraction. + """ + if livetime_fraction is None: + raise ValueError("livetime_fraction is required.") + if isinstance(livetime_fraction, (bool, np.bool_)) or ( + not np.isscalar(livetime_fraction) or not np.isreal(livetime_fraction) + ): + raise TypeError("livetime_fraction must be a real scalar.") + if not np.isfinite(livetime_fraction) or (not 0.0 <= livetime_fraction <= 1.0): + raise ValueError("livetime_fraction must be in the range [0, 1].") + return livetime_fraction - @psf.setter - def psf(self, psf: ExtensionHDU): + @staticmethod + def _validate_location(location: EarthLocation) -> EarthLocation: """ + Validate the telescope location. + Parameters ---------- - psf : ExtensionHDU - The PSF HDU read from the fits file containing IRFs - """ - if self._psf is not None: - self.log.warning("PSF for DL3 file was already set, replacing current PSF") - if psf is not None and not isinstance(psf, ExtensionHDU): - raise TypeError("psf must be a FITS ExtensionHDU.") - self._psf = psf + location : EarthLocation + The location of the telescope. - @property - def edisp(self) -> ExtensionHDU: - return self._edisp + Returns + ------- + EarthLocation + The validated telescope location. + """ + if location is None: + raise ValueError("location is required.") + if not isinstance(location, EarthLocation): + raise TypeError("location must be an astropy EarthLocation.") + return location - @edisp.setter - def edisp(self, edisp: ExtensionHDU): + @staticmethod + def _validate_telescope_information( + telescope_information: dict[str, Any], + ) -> dict[str, Any]: """ + Validate the telescope information. + Parameters ---------- - edisp : ExtensionHDU - The EDISP HDU read from the fits file containing IRFs + telescope_information : dict[str, any] + A dictionary containing general information about telescope with as + key: organisation, array, subarray, telescope_list. + + Returns + ------- + dict[str, any] + The validated telescope information. """ - if self._edisp is not None: - self.log.warning( - "EDISP for DL3 file was already set, replacing current EDISP" + if telescope_information is None: + raise ValueError("telescope_information is required.") + if not isinstance(telescope_information, Mapping): + raise TypeError("telescope_information must be a mapping.") + required = {"organisation", "array", "subarray", "telescope_list"} + missing = required - set(telescope_information) + if missing: + raise ValueError( + "telescope_information is missing keys: " + ", ".join(sorted(missing)) ) - if edisp is not None and not isinstance(edisp, ExtensionHDU): - raise TypeError("edisp must be a FITS ExtensionHDU.") - self._edisp = edisp - - @property - def bkg(self) -> ExtensionHDU: - return self._bkg + return telescope_information - @bkg.setter - def bkg(self, bkg: ExtensionHDU): + @staticmethod + def _validate_target_information( + target_information: dict[str, Any], + ) -> dict[str, Any]: """ + Validate the target information. + Parameters ---------- - bkg : ExtensionHDU - The background HDU read from the fits file containing IRFs + target_information : dict[str, any] + A dictionary containing general information about the targeted + source with as key: observer, object_name, object_coordinate. + + Returns + ------- + dict[str, any] or None + The validated target information, or ``None`` if omitted. """ - if self._bkg is not None: - self.log.warning( - "Background for DL3 file was already set, replacing current background" + if target_information is None: + return None + if not isinstance(target_information, Mapping): + raise TypeError("target_information must be a mapping.") + required = {"observer", "object_name", "object_coordinate"} + missing = required - set(target_information) + if missing: + raise ValueError( + "target_information is missing keys: " + ", ".join(sorted(missing)) ) - if bkg is not None and not isinstance(bkg, ExtensionHDU): - raise TypeError("bkg must be a FITS ExtensionHDU.") - self._bkg = bkg - @property - def location(self) -> EarthLocation: - return self._location + coordinate = target_information["object_coordinate"] + if not isinstance(coordinate, SkyCoord): + raise TypeError( + "target_information['object_coordinate'] must be a SkyCoord" + ) + return target_information - @location.setter - def location(self, location: EarthLocation): + @staticmethod + def _validate_software_information( + software_information: dict[str, Any], + ) -> dict[str, Any]: """ + Validate the software information. + Parameters ---------- - location : EarthLocation - The location of the telescope + software_information : dict[str, any] + A dictionary containing general information about the software used + to produce the file with as key: analysis_version, + calibration_version, dst_version. + + Returns + ------- + dict[str, any] or None + The validated software information, or ``None`` if omitted. """ - if self._location is not None: - self.log.warning( - "Telescope location for DL3 file was already set, replacing current location" + if software_information is None: + return None + if not isinstance(software_information, Mapping): + raise TypeError("software_information must be a mapping.") + required = {"analysis_version", "calibration_version", "dst_version"} + missing = required - set(software_information) + if missing: + raise ValueError( + "software_information is missing keys: " + ", ".join(sorted(missing)) ) - if location is not None and not isinstance(location, EarthLocation): - raise TypeError("location must be an astropy EarthLocation.") - self._location = location + return software_information - @property - def livetime_fraction(self) -> float: - return self._livetime_fraction - - @livetime_fraction.setter - def livetime_fraction(self, livetime_fraction: float): + @staticmethod + def _validate_irf(value: ExtensionHDU, name: str) -> ExtensionHDU: """ + Validate a required IRF HDU. + Parameters ---------- - livetime_fraction : float - The livetime fraction for the observations (DEADC correction factor) - """ - if self.livetime_fraction is not None: - self.log.warning( - "Livetime fraction for DL3 file was already set, replacing current livetime fraction" - ) + value : ExtensionHDU + IRF HDU read from the fits file containing IRFs. + name : str + Name of the IRF field used in error messages. - if livetime_fraction is None: - self._livetime_fraction = None - return + Returns + ------- + ExtensionHDU + The validated IRF HDU. + """ + if value is None: + raise ValueError(f"{name} is required.") + if not isinstance(value, ExtensionHDU): + raise TypeError(f"{name} must be a FITS ExtensionHDU.") + return value - if isinstance(livetime_fraction, (bool, np.bool_)) or ( - not np.isscalar(livetime_fraction) or not np.isreal(livetime_fraction) - ): - raise TypeError("livetime_fraction must be a real scalar.") - if not np.isfinite(livetime_fraction) or (not 0.0 <= livetime_fraction <= 1.0): - raise ValueError("livetime_fraction must be in the range [0, 1].") + def _validate_aeff(self, aeff: ExtensionHDU) -> ExtensionHDU: + """ + Validate the effective area HDU. - self._livetime_fraction = livetime_fraction + Parameters + ---------- + aeff : ExtensionHDU + The effective area HDU read from the fits file containing IRFs. - @property - def telescope_information(self) -> Dict[str, Any]: - return self._telescope_information + Returns + ------- + ExtensionHDU + The validated effective area HDU. + """ + return self._validate_irf(aeff, "aeff") - @telescope_information.setter - def telescope_information(self, telescope_information: Dict[str, Any]): + def _validate_psf(self, psf: ExtensionHDU) -> ExtensionHDU: """ + Validate the PSF HDU. + Parameters ---------- - telescope_information : dict[str, any] - A dictionary containing general information about telescope with as key : organisation, array, subarray, telescope_list + psf : ExtensionHDU + The PSF HDU read from the fits file containing IRFs. + + Returns + ------- + ExtensionHDU + The validated PSF HDU. """ - if self._telescope_information is not None: - self.log.warning( - "Telescope information for DL3 file was already set, replacing current information" - ) - if telescope_information is not None: - if not isinstance(telescope_information, Mapping): - raise TypeError("telescope_information must be a mapping.") - required = {"organisation", "array", "subarray", "telescope_list"} - missing = required - set(telescope_information) - if missing: - raise ValueError( - "telescope_information is missing keys: " - + ", ".join(sorted(missing)) - ) - self._telescope_information = telescope_information + return self._validate_irf(psf, "psf") - @property - def target_information(self) -> Dict[str, Any]: - return self._target_information + def _validate_edisp(self, edisp: ExtensionHDU) -> ExtensionHDU: + """ + Validate the EDISP HDU. - @target_information.setter - def target_information(self, target_information: Dict[str, Any]): + Parameters + ---------- + edisp : ExtensionHDU + The EDISP HDU read from the fits file containing IRFs. + + Returns + ------- + ExtensionHDU + The validated EDISP HDU. """ + return self._validate_irf(edisp, "edisp") + + @staticmethod + def _validate_bkg(bkg: ExtensionHDU) -> ExtensionHDU: + """ + Validate the background HDU. + Parameters ---------- - target_information : dict[str, any] - A dictionary containing general information about the targeted source with as key : observer, object_name, object_coordinate + bkg : ExtensionHDU + The background HDU read from the fits file containing IRFs. + + Returns + ------- + ExtensionHDU or None + The validated background HDU, or ``None`` if omitted. """ - if self._target_information is not None: - self.log.warning( - "Target information for DL3 file was already set, replacing current target information" - ) - if target_information is not None: - if not isinstance(target_information, Mapping): - raise TypeError("target_information must be a mapping.") - required = {"observer", "object_name", "object_coordinate"} - missing = required - set(target_information) - if missing: - raise ValueError( - "target_information is missing keys: " + ", ".join(sorted(missing)) - ) + if bkg is not None and not isinstance(bkg, ExtensionHDU): + raise TypeError("bkg must be a FITS ExtensionHDU.") + return bkg - coordinate = target_information["object_coordinate"] - if not isinstance(coordinate, (SkyCoord, BaseCoordinateFrame)): - raise TypeError( - "target_information['object_coordinate'] must be a SkyCoord or coordinate frame." - ) - self._target_information = target_information - @property - def software_information(self) -> Dict[str, Any]: - return self._software_information +class DL3EventsWriter(Component): + """ + Base class for writing a DL3 file + """ + + overwrite = Bool( + default_value=False, + help="If true, allow to overwrite already existing output file", + ).tag(config=True) + + optional_dl3_columns = Bool( + default_value=False, help="If true add optional columns to produce file" + ).tag(config=True) + + raise_error_for_optional = Bool( + default_value=True, + help="If true will raise error in the case optional column are missing", + ).tag(config=True) - @software_information.setter - def software_information(self, software_information: Dict[str, Any]): + reference_time = AstroTime( + default_value=Time("2018-01-01T00:00:00", scale="tai"), + help="The reference time that will be used in the FITS file", + ).tag(config=True) + + @abstractmethod + def write_file(self, path: str, data: DL3EventsData): """ + This function will write the new DL3 file. + Parameters ---------- - software_information : dict[str, any] - A dictionary containing general information about the software used to produce the file with as key : analysis_version, calibration_version, dst_version + path : str + The full path and filename of the new file to write. + data : DL3EventsData + The DL3 file payload to write. + + Returns + ------- + None """ - if self._software_information is not None: - self.log.warning( - "Software information for DL3 file was already set, replacing current software information" - ) - if software_information is not None: - if not isinstance(software_information, Mapping): - raise TypeError("software_information must be a mapping.") - required = {"analysis_version", "calibration_version", "dst_version"} - missing = required - set(software_information) - if missing: - raise ValueError( - "software_information is missing keys: " - + ", ".join(sorted(missing)) - ) - self._software_information = software_information + pass class DL3GADFEventsWriter(DL3EventsWriter): @@ -422,74 +540,98 @@ class DL3GADFEventsWriter(DL3EventsWriter): def __init__(self, **kwargs): super().__init__(**kwargs) - self.file_creation_time = datetime.now(tz=UTC) self._reference_time = self.reference_time.tai - def write_file(self, path): + def write_file(self, path: str, data: DL3EventsData): """ - This function will write the new DL3 file - All the content associated with the file should have been specified previously, otherwise error will be raised + This function will write the new DL3 file. Parameters ---------- path : str - The full path and filename of the new file to write + The full path and filename of the new file to write. + data : DL3EventsData + The DL3 file payload to write. + + Returns + ------- + None """ - self.file_creation_time = datetime.now(tz=UTC) + if not isinstance(data, DL3EventsData): + raise TypeError("data must be a DL3EventsData instance.") + creation_time = datetime.now(tz=UTC) hdu_dl3 = fits.HDUList( - [fits.PrimaryHDU(header=Header(self.get_hdu_header_base_format()))] + [ + fits.PrimaryHDU( + header=Header(self.get_hdu_header_base_format(creation_time)) + ) + ] ) hdu_dl3.append( fits.BinTableHDU( - data=self.transform_events_columns_for_gadf_format(self.events), + data=self.transform_events_columns_for_gadf_format(data.events), name="EVENTS", - header=Header(self.get_hdu_header_events()), + header=Header(self.get_hdu_header_events(data, creation_time)), ) ) hdu_dl3.append( fits.BinTableHDU( - data=self.create_gti_table(), + data=self.create_gti_table(data), name="GTI", - header=Header(self.get_hdu_header_gti()), + header=Header(self.get_hdu_header_gti(data, creation_time)), ) ) hdu_dl3.append( fits.BinTableHDU( - data=self.create_pointing_table(), + data=self.create_pointing_table(data), name="POINTING", - header=Header(self.get_hdu_header_pointing()), + header=Header(self.get_hdu_header_pointing(data, creation_time)), ) ) - if self.aeff is None: + if data.aeff is None: raise ValueError("Missing effective area IRF") - hdu_dl3.append(self.aeff) - hdu_dl3[-1].header["OBS_ID"] = self.obs_id - if self.psf is None: + if data.psf is None: raise ValueError("Missing PSF IRF") - hdu_dl3.append(self.psf) - hdu_dl3[-1].header["OBS_ID"] = self.obs_id - if self.edisp is None: + if data.edisp is None: raise ValueError("Missing EDISP IRF") - hdu_dl3.append(self.edisp) - hdu_dl3[-1].header["OBS_ID"] = self.obs_id - if self.bkg is not None: - hdu_dl3.append(self.bkg) - hdu_dl3[-1].header["OBS_ID"] = self.obs_id + + for irf in (data.aeff, data.psf, data.edisp): + output_hdu = irf.copy() + output_hdu.header["OBS_ID"] = data.obs_id + hdu_dl3.append(output_hdu) + + if data.bkg is not None: + output_hdu = data.bkg.copy() + output_hdu.header["OBS_ID"] = data.obs_id + hdu_dl3.append(output_hdu) hdu_dl3.writeto(path, checksum=True, overwrite=self.overwrite) - def get_hdu_header_base_format(self) -> Dict[str, Any]: + def get_hdu_header_base_format( + self, creation_time: datetime | None = None + ) -> Dict[str, Any]: """ - Return the base information that should be included in all HDU of the final fits file + Return the base information that should be included in all HDU of the final fits file. + + Parameters + ---------- + creation_time : datetime, optional + The file creation time to write into the header. If omitted, the + current UTC time is used. + + Returns + ------- + dict[str, any] + Header keywords common to all HDUs in the DL3 file. """ return { "HDUCLASS": "GADF", "HDUVERS": "v0.3", "HDUDOC": "https://gamma-astro-data-formats.readthedocs.io/en/v0.3/index.html", "CREATOR": "ctapipe " + ctapipe_version, - "CREATED": self.file_creation_time.isoformat(), + "CREATED": (creation_time or datetime.now(tz=UTC)).isoformat(), } def get_hdu_header_time_reference(self) -> Dict[str, Any]: @@ -499,6 +641,11 @@ def get_hdu_header_time_reference(self) -> Dict[str, Any]: These keywords (MJDREFI, MJDREFF, TIMEUNIT, TIMESYS, TIMEREF) should be present in every HDU that contains a TIME column or time-related header values. + + Returns + ------- + dict[str, any] + Header keywords defining the FITS time reference. """ return { "MJDREFI": int(self._reference_time.mjd), @@ -508,19 +655,29 @@ def get_hdu_header_time_reference(self) -> Dict[str, Any]: "TIMESYS": "TAI", } - def get_hdu_header_base_time(self) -> Dict[str, Any]: + def get_hdu_header_base_time(self, data: DL3EventsData) -> Dict[str, Any]: """ - Return the information about time parameters used in several HDU + Return the information about time parameters used in several HDU. + + Parameters + ---------- + data : DL3EventsData + The DL3 file payload containing the GTI and livetime fraction. + + Returns + ------- + dict[str, any] + Header keywords describing the observation time range and livetime. """ - if self.gti is None: + if data.gti is None: raise ValueError("No available time information for the DL3 file") - if self.livetime_fraction is None: + if data.livetime_fraction is None: raise ValueError("No available livetime fraction for the DL3 file") start_time = None stop_time = None ontime = TimeDelta(0.0 * u.s) - for i, gti_interval in enumerate(self.gti): + for i, gti_interval in enumerate(data.gti): interval_start = self._to_tai_time(gti_interval[0], f"gti[{i}].start") interval_stop = self._to_tai_time(gti_interval[1], f"gti[{i}].stop") if interval_stop < interval_start: @@ -546,8 +703,8 @@ def get_hdu_header_base_time(self) -> Dict[str, Any]: ), "TSTOP": self._to_relative_time_seconds(stop_time, "observation stop"), "ONTIME": ontime.to_value(u.s), - "LIVETIME": ontime.to_value(u.s) * self.livetime_fraction, - "DEADC": self.livetime_fraction, + "LIVETIME": ontime.to_value(u.s) * data.livetime_fraction, + "DEADC": data.livetime_fraction, "TELAPSE": (stop_time - start_time).to_value(u.s), "DATE-OBS": start_time.fits, "DATE-BEG": start_time.fits, @@ -558,23 +715,32 @@ def get_hdu_header_base_time(self) -> Dict[str, Any]: return header def get_hdu_header_base_observation_information( - self, obs_id_only: bool = False + self, data: DL3EventsData, obs_id_only: bool = False ) -> Dict[str, Any]: """ - Return generic information on the observation setting (id, target, ...) + Return generic information on the observation setting (id, target, ...). Parameters ---------- + data : DL3EventsData + The DL3 file payload containing the observation and target + information. obs_id_only : bool - If true, will return a dict with as only information the obs_id + If true, will return a dict with as only information the obs_id. + + Returns + ------- + dict[str, any] + Header keywords describing the observation and, if requested, the + target information. """ - if self.obs_id is None: + if data.obs_id is None: raise ValueError("Observation ID is missing.") - header = {"OBS_ID": self.obs_id} - if self.target_information is not None and not obs_id_only: - header["OBSERVER"] = self.target_information["observer"] - header["OBJECT"] = self.target_information["object_name"] - object_coordinate = self.target_information[ + header = {"OBS_ID": data.obs_id} + if data.target_information is not None and not obs_id_only: + header["OBSERVER"] = data.target_information["observer"] + header["OBJECT"] = data.target_information["object_name"] + object_coordinate = data.target_information[ "object_coordinate" ].transform_to(ICRS()) if not np.isnan(object_coordinate.ra.to_value(u.deg)): @@ -583,44 +749,81 @@ def get_hdu_header_base_observation_information( header["DEC_OBJ"] = object_coordinate.dec.to_value(u.deg) return header - def get_hdu_header_base_subarray_information(self) -> Dict[str, Any]: + def get_hdu_header_base_subarray_information( + self, data: DL3EventsData + ) -> Dict[str, Any]: """ - Return generic information on the array used for observations + Return generic information on the array used for observations. + + Parameters + ---------- + data : DL3EventsData + The DL3 file payload containing the telescope information. + + Returns + ------- + dict[str, any] + Header keywords describing the array and telescope list. """ - if self.telescope_information is None: + if data.telescope_information is None: raise ValueError("Telescope information are missing.") header = { - "ORIGIN": self.telescope_information["organisation"], - "TELESCOP": self.telescope_information["array"], - "INSTRUME": self.telescope_information["subarray"], - "TELLIST": str(self.telescope_information["telescope_list"]), - "N_TELS": len(self.telescope_information["telescope_list"]), + "ORIGIN": data.telescope_information["organisation"], + "TELESCOP": data.telescope_information["array"], + "INSTRUME": data.telescope_information["subarray"], + "TELLIST": str(data.telescope_information["telescope_list"]), + "N_TELS": len(data.telescope_information["telescope_list"]), } return header - def get_hdu_header_base_software_information(self) -> Dict[str, Any]: + def get_hdu_header_base_software_information( + self, data: DL3EventsData + ) -> Dict[str, Any]: """ - Return information about the software versions used to process the observation + Return information about the software versions used to process the observation. + + Parameters + ---------- + data : DL3EventsData + The DL3 file payload containing the software information. + + Returns + ------- + dict[str, any] + Header keywords describing software versions used to process the + observation. """ header = {} - if self.software_information is not None: - header["DST_VER"] = self.software_information["dst_version"] - header["ANA_VER"] = self.software_information["analysis_version"] - header["CAL_VER"] = self.software_information["calibration_version"] + if data.software_information is not None: + header["DST_VER"] = data.software_information["dst_version"] + header["ANA_VER"] = data.software_information["analysis_version"] + header["CAL_VER"] = data.software_information["calibration_version"] return header - def get_hdu_header_base_pointing(self) -> Dict[str, Any]: + def get_hdu_header_base_pointing(self, data: DL3EventsData) -> Dict[str, Any]: """ - Return information on the pointing during the observation + Return information on the pointing during the observation. + + Parameters + ---------- + data : DL3EventsData + The DL3 file payload containing pointing, pointing mode, GTI and + telescope location information. + + Returns + ------- + dict[str, any] + Header keywords describing the observation pointing and telescope + location. """ - if self.pointing is None: + if data.pointing is None: raise ValueError("Pointing information are missing") - if self.pointing_mode is None: + if data.pointing_mode is None: raise ValueError("Pointing mode is missing") - if self.location is None: + if data.location is None: raise ValueError("Telescope location information are missing") - gti_table = self.create_gti_table() + gti_table = self.create_gti_table(data) delta_time_evaluation = [] for i in range(len(gti_table)): delta_time_evaluation += list( @@ -629,8 +832,8 @@ def get_hdu_header_base_pointing(self) -> Dict[str, Any]: delta_time_evaluation = u.Quantity(delta_time_evaluation) time_evaluation = self._reference_time + TimeDelta(delta_time_evaluation) - pointing_table = self.create_pointing_table() - if self.pointing_mode == "TRACK": + pointing_table = self.create_pointing_table(data) + if data.pointing_mode == "TRACK": obs_mode = "POINTING" icrs_coordinate = SkyCoord( ra=self._circular_interp( @@ -646,9 +849,9 @@ def get_hdu_header_base_pointing(self) -> Dict[str, Any]: unit=u.deg, ) altaz_coordinate = icrs_coordinate.transform_to( - AltAz(location=self.location, obstime=time_evaluation) + AltAz(location=data.location, obstime=time_evaluation) ) - elif self.pointing_mode == "DRIFT": + elif data.pointing_mode == "DRIFT": obs_mode = "DRIFT" altaz_coordinate = AltAz( alt=u.Quantity( @@ -666,7 +869,7 @@ def get_hdu_header_base_pointing(self) -> Dict[str, Any]: fp_deg=pointing_table["AZ_PNT"], ) * u.deg, - location=self.location, + location=data.location, obstime=time_evaluation, ) icrs_coordinate = altaz_coordinate.transform_to(ICRS()) @@ -686,50 +889,95 @@ def get_hdu_header_base_pointing(self) -> Dict[str, Any]: "AZ_PNT": Angle(circmean(altaz_coordinate.az)) .wrap_at(360 * u.deg) .to_value(u.deg), - "GEOLON": self.location.lon.to_value(u.deg), - "GEOLAT": self.location.lat.to_value(u.deg), - "ALTITUDE": self.location.height.to_value(u.m), - "OBSGEO-X": self.location.x.to_value(u.m), - "OBSGEO-Y": self.location.y.to_value(u.m), - "OBSGEO-Z": self.location.z.to_value(u.m), + "GEOLON": data.location.lon.to_value(u.deg), + "GEOLAT": data.location.lat.to_value(u.deg), + "ALTITUDE": data.location.height.to_value(u.m), + "OBSGEO-X": data.location.x.to_value(u.m), + "OBSGEO-Y": data.location.y.to_value(u.m), + "OBSGEO-Z": data.location.z.to_value(u.m), } return header - def get_hdu_header_events(self) -> Dict[str, Any]: + def get_hdu_header_events( + self, data: DL3EventsData, creation_time=None + ) -> Dict[str, Any]: """ - The output dictionary contain all the necessary information that should be added to the header of the events HDU + Return all the necessary information that should be added to the header of the events HDU. + + Parameters + ---------- + data : DL3EventsData + The DL3 file payload to use for the header. + creation_time : datetime, optional + The file creation time to write into the header. If omitted, the + current UTC time is used. + + Returns + ------- + dict[str, any] + Header keywords for the EVENTS HDU. """ - header = self.get_hdu_header_base_format() + header = self.get_hdu_header_base_format(creation_time) header.update({"HDUCLAS1": "EVENTS", "FOVALIGN": "ALTAZ"}) - header.update(self.get_hdu_header_base_time()) - header.update(self.get_hdu_header_base_pointing()) - header.update(self.get_hdu_header_base_observation_information()) - header.update(self.get_hdu_header_base_subarray_information()) - header.update(self.get_hdu_header_base_software_information()) + header.update(self.get_hdu_header_base_time(data)) + header.update(self.get_hdu_header_base_pointing(data)) + header.update(self.get_hdu_header_base_observation_information(data)) + header.update(self.get_hdu_header_base_subarray_information(data)) + header.update(self.get_hdu_header_base_software_information(data)) return header - def get_hdu_header_gti(self) -> Dict[str, Any]: + def get_hdu_header_gti( + self, data: DL3EventsData, creation_time=None + ) -> Dict[str, Any]: """ - The output dictionary contain all the necessary information that should be added to the header of the GTI HDU + Return all the necessary information that should be added to the header of the GTI HDU. + + Parameters + ---------- + data : DL3EventsData + The DL3 file payload to use for the header. + creation_time : datetime, optional + The file creation time to write into the header. If omitted, the + current UTC time is used. + + Returns + ------- + dict[str, any] + Header keywords for the GTI HDU. """ - header = self.get_hdu_header_base_format() + header = self.get_hdu_header_base_format(creation_time) header.update({"HDUCLAS1": "GTI"}) - header.update(self.get_hdu_header_base_time()) + header.update(self.get_hdu_header_base_time(data)) header.update( - self.get_hdu_header_base_observation_information(obs_id_only=True) + self.get_hdu_header_base_observation_information(data, obs_id_only=True) ) return header - def get_hdu_header_pointing(self) -> Dict[str, Any]: + def get_hdu_header_pointing( + self, data: DL3EventsData, creation_time=None + ) -> Dict[str, Any]: """ - The output dictionary contain all the necessary information that should be added to the header of the pointing HDU + Return all the necessary information that should be added to the header of the pointing HDU. + + Parameters + ---------- + data : DL3EventsData + The DL3 file payload to use for the header. + creation_time : datetime, optional + The file creation time to write into the header. If omitted, the + current UTC time is used. + + Returns + ------- + dict[str, any] + Header keywords for the POINTING HDU. """ - header = self.get_hdu_header_base_format() + header = self.get_hdu_header_base_format(creation_time) header.update({"HDUCLAS1": "POINTING"}) header.update(self.get_hdu_header_time_reference()) - header.update(self.get_hdu_header_base_pointing()) + header.update(self.get_hdu_header_base_pointing(data)) header.update( - self.get_hdu_header_base_observation_information(obs_id_only=True) + self.get_hdu_header_base_observation_information(data, obs_id_only=True) ) return header @@ -741,7 +989,12 @@ def transform_events_columns_for_gadf_format(self, events: QTable) -> QTable: Parameters ---------- events : QTable - The base events table to process + The base events table to process. + + Returns + ------- + QTable + Event table containing the DL3/GADF columns with GADF names. """ rename_from = ["event_id", "time", "reco_ra", "reco_dec", "reco_energy"] rename_to = ["EVENT_ID", "TIME", "RA", "DEC", "ENERGY"] @@ -817,12 +1070,23 @@ def transform_events_columns_for_gadf_format(self, events: QTable) -> QTable: renamed_events = renamed_events[rename_to] return renamed_events - def create_gti_table(self) -> QTable: + def create_gti_table(self, data: DL3EventsData) -> QTable: """ - Build a table that contains GTI information with the GADF names and format, to be concerted directly as a TableHDU + Build a table that contains GTI information with the GADF names and format, to be concerted directly as a TableHDU. + + Parameters + ---------- + data : DL3EventsData + The DL3 file payload containing the good time intervals. + + Returns + ------- + QTable + GTI table with START and STOP columns in seconds relative to the + writer reference time. """ table_structure = {"START": [], "STOP": []} - for i, gti_interval in enumerate(self.gti): + for i, gti_interval in enumerate(data.gti): interval_start = self._to_tai_time(gti_interval[0], f"gti[{i}].start") interval_stop = self._to_tai_time(gti_interval[1], f"gti[{i}].stop") table_structure["START"].append( @@ -841,13 +1105,25 @@ def create_gti_table(self) -> QTable: return table - def create_pointing_table(self) -> QTable: + def create_pointing_table(self, data: DL3EventsData) -> QTable: """ - Build a table that contains pointing information with the GADF names and format, to be concerted directly as a TableHDU + Build a table that contains pointing information with the GADF names and format, to be concerted directly as a TableHDU. + + Parameters + ---------- + data : DL3EventsData + The DL3 file payload containing pointing and telescope location + information. + + Returns + ------- + QTable + Pointing table with TIME, RA_PNT, DEC_PNT, ALT_PNT and AZ_PNT + columns in GADF format. """ - if self.pointing is None: + if data.pointing is None: raise ValueError("Pointing information are missing") - if self.location is None: + if data.location is None: raise ValueError("Telescope location information are missing") table_structure = { @@ -858,11 +1134,11 @@ def create_pointing_table(self) -> QTable: "AZ_PNT": [], } - for i, pointing in enumerate(self.pointing): + for i, pointing in enumerate(data.pointing): time = self._to_tai_time(pointing[0], f"pointing[{i}].time") pointing_icrs = pointing[1].transform_to(ICRS()) pointing_altaz = pointing[1].transform_to( - AltAz(location=self.location, obstime=time) + AltAz(location=data.location, obstime=time) ) table_structure["TIME"].append( self._to_relative_time_quantity(time, f"pointing[{i}].time") @@ -888,6 +1164,11 @@ def _to_tai_time(self, value: Any, value_name: str) -> Time: relative to ``reference_time``. value_name : str Name of the value used in error messages. + + Returns + ------- + Time + Input value converted to an absolute TAI time. """ if isinstance(value, Time): return value.tai @@ -920,6 +1201,12 @@ def _to_relative_time_seconds(self, value: Any, value_name: str) -> Any: time ``Quantity`` and numeric values assumed to already be in seconds. value_name : str Name of the value used in error messages. + + Returns + ------- + float or numpy.ndarray + Input value converted to seconds relative to the writer reference + time. """ if isinstance(value, Time): return (value.tai - self._reference_time).to_value(u.s) @@ -943,7 +1230,23 @@ def _to_relative_time_seconds(self, value: Any, value_name: str) -> Any: ) def _to_relative_time_quantity(self, value: Any, value_name: str) -> u.Quantity: - """Normalize input to a quantity in seconds relative to ``reference_time``.""" + """ + Normalize input to a quantity in seconds relative to ``reference_time``. + + Parameters + ---------- + value : Any + Input time-like value. Supported types are ``Time``, ``TimeDelta``, + time ``Quantity`` and numeric values assumed to already be in seconds. + value_name : str + Name of the value used in error messages. + + Returns + ------- + astropy.units.Quantity + Input value converted to seconds relative to the writer reference + time. + """ return u.Quantity( self._to_relative_time_seconds(value, value_name), u.s, diff --git a/src/ctapipe/io/tests/test_dl3.py b/src/ctapipe/io/tests/test_dl3.py index 98402f9390a..95abba0fc96 100644 --- a/src/ctapipe/io/tests/test_dl3.py +++ b/src/ctapipe/io/tests/test_dl3.py @@ -14,7 +14,7 @@ from ...io.astropy_helpers import join_allow_empty from ...io.dl2_tables_preprocessing import DL2EventPreprocessor from ...version import version as ctapipe_version -from ..dl3 import DL3GADFEventsWriter +from ..dl3 import DL3EventsData, DL3GADFEventsWriter @pytest.fixture @@ -153,50 +153,57 @@ def dl2_events_for_dl3(single_obs_gamma_diffuse_full_reco_file, dl2_meta_for_dl3 @pytest.fixture -def dl3_writer(dl2_events_for_dl3, dl2_meta_for_dl3, hdu_irfs): - dl3_format_optional = DL3GADFEventsWriter() - - # Load events - dl3_format_optional.events = dl2_events_for_dl3 - - # Load metadata - dl3_format_optional.obs_id = dl2_meta_for_dl3["obs_id"] - dl3_format_optional.pointing = dl2_meta_for_dl3["pointing"]["pointing_list"] - dl3_format_optional.pointing_mode = dl2_meta_for_dl3["pointing"]["pointing_mode"] - dl3_format_optional.gti = dl2_meta_for_dl3["gti"] - dl3_format_optional.livetime_fraction = dl2_meta_for_dl3["livetime_fraction"] - dl3_format_optional.location = dl2_meta_for_dl3["location"] - dl3_format_optional.telescope_information = dl2_meta_for_dl3[ - "telescope_information" - ] - dl3_format_optional.target_information = dl2_meta_for_dl3["target"] - dl3_format_optional.software_information = dl2_meta_for_dl3["software_version"] - - # Load IRFs +def dl3_data(dl2_events_for_dl3, dl2_meta_for_dl3, hdu_irfs): + aeff = None + psf = None + edisp = None + bkg = None + for i in range(1, len(hdu_irfs)): if "HDUCLAS2" in hdu_irfs[i].header.keys(): if hdu_irfs[i].header["HDUCLAS2"] == "EFF_AREA": - if dl3_format_optional.aeff is None: - dl3_format_optional.aeff = hdu_irfs[i] + if aeff is None: + aeff = hdu_irfs[i] elif "EXTNAME" in hdu_irfs[i].header and not ( "PROTONS" in hdu_irfs[i].header["EXTNAME"] or "ELECTRONS" in hdu_irfs[i].header["EXTNAME"] ): - dl3_format_optional.aeff = hdu_irfs[i] + aeff = hdu_irfs[i] elif hdu_irfs[i].header["HDUCLAS2"] == "EDISP": - dl3_format_optional.edisp = hdu_irfs[i] + edisp = hdu_irfs[i] elif hdu_irfs[i].header["HDUCLAS2"] == "PSF": - dl3_format_optional.psf = hdu_irfs[i] + psf = hdu_irfs[i] elif hdu_irfs[i].header["HDUCLAS2"] == "BKG": - dl3_format_optional.bkg = hdu_irfs[i] - return dl3_format_optional + bkg = hdu_irfs[i] + + return DL3EventsData( + events=dl2_events_for_dl3, + obs_id=dl2_meta_for_dl3["obs_id"], + pointing=dl2_meta_for_dl3["pointing"]["pointing_list"], + pointing_mode=dl2_meta_for_dl3["pointing"]["pointing_mode"], + gti=dl2_meta_for_dl3["gti"], + livetime_fraction=dl2_meta_for_dl3["livetime_fraction"], + location=dl2_meta_for_dl3["location"], + telescope_information=dl2_meta_for_dl3["telescope_information"], + target_information=dl2_meta_for_dl3["target"], + software_information=dl2_meta_for_dl3["software_version"], + aeff=aeff, + psf=psf, + edisp=edisp, + bkg=bkg, + ) + + +@pytest.fixture +def dl3_writer(): + return DL3GADFEventsWriter() class TestDL3GADFEventsWriter: - def test_dl3_file(self, tmp_path, dl3_writer): + def test_dl3_file(self, tmp_path, dl3_writer, dl3_data): output_path = tmp_path / "dl3_gadf.fits" - dl3_writer.write_file(output_path) + dl3_writer.write_file(output_path, dl3_data) with fits.open(output_path, checksum=True) as hdul: assert isinstance(hdul[0], fits.PrimaryHDU) @@ -215,37 +222,78 @@ def test_dl3_file(self, tmp_path, dl3_writer): for hdu in hdul: if "OBS_ID" in hdu.header: - assert hdu.header["OBS_ID"] == dl3_writer.obs_id + assert hdu.header["OBS_ID"] == dl3_data.obs_id - def test_dl3_file_missing_aeff(self, tmp_path, dl3_writer): + def test_dl3_file_missing_aeff(self, tmp_path, dl3_writer, dl3_data): output_path = tmp_path / "dl3_gadf_aeff.fits" - dl3_writer._aeff = None + object.__setattr__(dl3_data, "aeff", None) with pytest.raises(ValueError): - dl3_writer.write_file(output_path) + dl3_writer.write_file(output_path, dl3_data) - def test_dl3_file_missing_edisp(self, tmp_path, dl3_writer): + def test_dl3_file_missing_edisp(self, tmp_path, dl3_writer, dl3_data): output_path = tmp_path / "dl3_gadf_edisp.fits" - dl3_writer._edisp = None + object.__setattr__(dl3_data, "edisp", None) with pytest.raises(ValueError): - dl3_writer.write_file(output_path) + dl3_writer.write_file(output_path, dl3_data) - def test_dl3_file_missing_psf(self, tmp_path, dl3_writer): + def test_dl3_file_missing_psf(self, tmp_path, dl3_writer, dl3_data): output_path = tmp_path / "dl3_gadf_psf.fits" - dl3_writer._psf = None + object.__setattr__(dl3_data, "psf", None) with pytest.raises(ValueError): - dl3_writer.write_file(output_path) + dl3_writer.write_file(output_path, dl3_data) - def test_dl3_file_overwrite(self, tmp_path, dl3_writer): + def test_dl3_file_overwrite(self, tmp_path, dl3_writer, dl3_data): output_path = tmp_path / "dl3_gadf_overwrite.fits" - dl3_writer.write_file(output_path) + dl3_writer.write_file(output_path, dl3_data) with pytest.raises(OSError): - dl3_writer.write_file(output_path) + dl3_writer.write_file(output_path, dl3_data) + + def test_writer_reuse_does_not_leak_state(self, tmp_path, dl3_writer, dl3_data): + first_path = tmp_path / "dl3_gadf_first.fits" + second_path = tmp_path / "dl3_gadf_second.fits" + original_irf_obs_ids = [ + hdu.header.get("OBS_ID") + for hdu in (dl3_data.aeff, dl3_data.psf, dl3_data.edisp, dl3_data.bkg) + if hdu is not None + ] + + second_data = DL3EventsData( + events=dl3_data.events, + obs_id=dl3_data.obs_id + 1, + pointing=dl3_data.pointing, + pointing_mode=dl3_data.pointing_mode, + gti=dl3_data.gti, + livetime_fraction=dl3_data.livetime_fraction, + location=dl3_data.location, + telescope_information=dl3_data.telescope_information, + target_information=dl3_data.target_information, + software_information=dl3_data.software_information, + aeff=dl3_data.aeff, + psf=dl3_data.psf, + edisp=dl3_data.edisp, + bkg=dl3_data.bkg, + ) + + dl3_writer.write_file(first_path, dl3_data) + dl3_writer.write_file(second_path, second_data) + + with fits.open(first_path, checksum=True) as first_hdul: + assert first_hdul["EVENTS"].header["OBS_ID"] == dl3_data.obs_id + with fits.open(second_path, checksum=True) as second_hdul: + assert second_hdul["EVENTS"].header["OBS_ID"] == second_data.obs_id + + current_irf_obs_ids = [ + hdu.header.get("OBS_ID") + for hdu in (dl3_data.aeff, dl3_data.psf, dl3_data.edisp, dl3_data.bkg) + if hdu is not None + ] + assert current_irf_obs_ids == original_irf_obs_ids - def test_hdu_header_base(self, dl3_writer): + def test_hdu_header_base(self, dl3_writer, dl3_data): header = dl3_writer.get_hdu_header_base_format() assert header["HDUCLASS"] == "GADF" @@ -255,8 +303,8 @@ def test_hdu_header_base(self, dl3_writer): file_time = datetime.fromisoformat(header["CREATED"]) assert (datetime.now(UTC) - file_time) < timedelta(hours=1) - def test_hdu_header_time(self, dl3_writer): - header = dl3_writer.get_hdu_header_base_time() + def test_hdu_header_time(self, dl3_writer, dl3_data): + header = dl3_writer.get_hdu_header_base_time(dl3_data) for key in [ "MJDREFI", @@ -303,159 +351,167 @@ def test_hdu_header_time(self, dl3_writer): assert (tstop - tref).to_value(u.s) == pytest.approx(header["TSTOP"], rel=1e-6) assert (tavg >= tstart) & (tavg <= tstop) - def test_hdu_header_time_missing_gti(self, dl3_writer): - dl3_writer._gti = None + def test_hdu_header_time_missing_gti(self, dl3_writer, dl3_data): + object.__setattr__(dl3_data, "gti", None) with pytest.raises(ValueError): - dl3_writer.get_hdu_header_base_time() + dl3_writer.get_hdu_header_base_time(dl3_data) - def test_hdu_header_time_missing_deadtime(self, dl3_writer): - dl3_writer._livetime_fraction = None + def test_hdu_header_time_missing_deadtime(self, dl3_writer, dl3_data): + object.__setattr__(dl3_data, "livetime_fraction", None) with pytest.raises(ValueError): - dl3_writer.get_hdu_header_base_time() + dl3_writer.get_hdu_header_base_time(dl3_data) - def test_livetime_fraction_setter_validation(self, dl3_writer): - dl3_writer.livetime_fraction = 0.0 - assert dl3_writer.livetime_fraction == 0.0 + def test_livetime_fraction_setter_validation(self, dl3_writer, dl3_data): + dl3_data.livetime_fraction = 0.0 + assert dl3_data.livetime_fraction == 0.0 - dl3_writer.livetime_fraction = 1.0 - assert dl3_writer.livetime_fraction == 1.0 + dl3_data.livetime_fraction = 1.0 + assert dl3_data.livetime_fraction == 1.0 for invalid in (-1e-3, 1.001, np.nan, np.inf, -np.inf): with pytest.raises(ValueError): - dl3_writer.livetime_fraction = invalid + dl3_data.livetime_fraction = invalid for invalid in ([0.5], "0.5", True): with pytest.raises(TypeError): - dl3_writer.livetime_fraction = invalid + dl3_data.livetime_fraction = invalid - dl3_writer.livetime_fraction = None - assert dl3_writer.livetime_fraction is None + with pytest.raises(ValueError): + dl3_data.livetime_fraction = None - def test_obs_id_setter_validation(self, dl3_writer): - dl3_writer.obs_id = np.int64(1234) - assert dl3_writer.obs_id == 1234 + def test_obs_id_setter_validation(self, dl3_writer, dl3_data): + dl3_data.obs_id = np.int64(1234) + assert dl3_data.obs_id == 1234 with pytest.raises(ValueError): - dl3_writer.obs_id = -1 + dl3_data.obs_id = -1 for invalid in (1.2, "1", True): with pytest.raises(TypeError): - dl3_writer.obs_id = invalid + dl3_data.obs_id = invalid + + with pytest.raises(ValueError): + dl3_data.obs_id = None - dl3_writer.obs_id = None - assert dl3_writer.obs_id is None + def test_events_setter_validation(self, dl3_writer, dl3_data): + qtable = QTable(dl3_data.events, copy=True) + dl3_data.events = qtable + assert dl3_data.events is qtable - def test_events_setter_validation(self, dl3_writer): - table = Table(dl3_writer.events, copy=False) - dl3_writer.events = table - assert dl3_writer.events is table + table = Table(dl3_data.events, copy=False) + with pytest.raises(TypeError): + dl3_data.events = table with pytest.raises(TypeError): - dl3_writer.events = {"not": "a table"} + dl3_data.events = {"not": "a table"} - dl3_writer.events = None - assert dl3_writer.events is None + with pytest.raises(ValueError): + dl3_data.events = None - def test_pointing_setter_validation(self, dl3_writer): + def test_pointing_setter_validation(self, dl3_writer, dl3_data): with pytest.raises(TypeError): - dl3_writer.pointing = "not-a-sequence" + dl3_data.pointing = "not-a-sequence" with pytest.raises(ValueError): - dl3_writer.pointing = [(Time("2020-01-01T00:00:00", scale="tai"),)] + dl3_data.pointing = [(Time("2020-01-01T00:00:00", scale="tai"),)] with pytest.raises(TypeError): - dl3_writer.pointing = [(Time("2020-01-01T00:00:00", scale="tai"), object())] + dl3_data.pointing = [(Time("2020-01-01T00:00:00", scale="tai"), object())] - dl3_writer.pointing = None - assert dl3_writer.pointing is None + with pytest.raises(ValueError): + dl3_data.pointing = None - def test_pointing_mode_setter_validation(self, dl3_writer): - dl3_writer.pointing_mode = "track" - assert dl3_writer.pointing_mode == "TRACK" + def test_pointing_mode_setter_validation(self, dl3_writer, dl3_data): + dl3_data.pointing_mode = "track" + assert dl3_data.pointing_mode == "TRACK" - dl3_writer.pointing_mode = " drift " - assert dl3_writer.pointing_mode == "DRIFT" + dl3_data.pointing_mode = " drift " + assert dl3_data.pointing_mode == "DRIFT" with pytest.raises(TypeError): - dl3_writer.pointing_mode = 1 + dl3_data.pointing_mode = 1 with pytest.raises(ValueError): - dl3_writer.pointing_mode = "WOBBLE" + dl3_data.pointing_mode = "WOBBLE" - def test_gti_setter_validation(self, dl3_writer): + def test_gti_setter_validation(self, dl3_writer, dl3_data): with pytest.raises(TypeError): - dl3_writer.gti = "not-a-sequence" + dl3_data.gti = "not-a-sequence" with pytest.raises(ValueError): - dl3_writer.gti = [(Time("2020-01-01T00:00:00", scale="tai"),)] + dl3_data.gti = [(Time("2020-01-01T00:00:00", scale="tai"),)] - dl3_writer.gti = None - assert dl3_writer.gti is None + with pytest.raises(ValueError): + dl3_data.gti = None - def test_location_setter_validation(self, dl3_writer): + def test_location_setter_validation(self, dl3_writer, dl3_data): with pytest.raises(TypeError): - dl3_writer.location = "not-a-location" + dl3_data.location = "not-a-location" - dl3_writer.location = None - assert dl3_writer.location is None + with pytest.raises(ValueError): + dl3_data.location = None @pytest.mark.parametrize("setter", ["aeff", "psf", "edisp", "bkg"]) - def test_irf_setter_validation(self, dl3_writer, setter): + def test_irf_setter_validation(self, dl3_writer, dl3_data, setter): with pytest.raises(TypeError): - setattr(dl3_writer, setter, "not-an-hdu") + setattr(dl3_data, setter, "not-an-hdu") - def test_telescope_information_setter_validation(self, dl3_writer): + def test_telescope_information_setter_validation(self, dl3_writer, dl3_data): with pytest.raises(TypeError): - dl3_writer.telescope_information = "not-a-mapping" + dl3_data.telescope_information = "not-a-mapping" with pytest.raises(ValueError, match="missing keys"): - dl3_writer.telescope_information = {"organisation": "CTAO"} + dl3_data.telescope_information = {"organisation": "CTAO"} - def test_target_information_setter_validation(self, dl3_writer): + def test_target_information_setter_validation(self, dl3_writer, dl3_data): with pytest.raises(TypeError): - dl3_writer.target_information = "not-a-mapping" + dl3_data.target_information = "not-a-mapping" with pytest.raises(ValueError, match="missing keys"): - dl3_writer.target_information = {"observer": "UNKNOWN"} + dl3_data.target_information = {"observer": "UNKNOWN"} with pytest.raises(TypeError): - dl3_writer.target_information = { + dl3_data.target_information = { "observer": "UNKNOWN", "object_name": "UNKNOWN", "object_coordinate": object(), } - def test_software_information_setter_validation(self, dl3_writer): + def test_software_information_setter_validation(self, dl3_writer, dl3_data): with pytest.raises(TypeError): - dl3_writer.software_information = "not-a-mapping" + dl3_data.software_information = "not-a-mapping" with pytest.raises(ValueError, match="missing keys"): - dl3_writer.software_information = {"analysis_version": "ctapipe X"} + dl3_data.software_information = {"analysis_version": "ctapipe X"} - def test_hdu_header_obs_info(self, dl3_writer, dl2_meta_for_dl3): + def test_hdu_header_obs_info(self, dl3_writer, dl3_data, dl2_meta_for_dl3): obs_only = dl3_writer.get_hdu_header_base_observation_information( - obs_id_only=True + dl3_data, obs_id_only=True ) - assert obs_only["OBS_ID"] == dl3_writer.obs_id + assert obs_only["OBS_ID"] == dl3_data.obs_id assert len(obs_only) == 1 full_header = dl3_writer.get_hdu_header_base_observation_information( - obs_id_only=False + dl3_data, obs_id_only=False ) - assert full_header["OBS_ID"] == dl3_writer.obs_id + assert full_header["OBS_ID"] == dl3_data.obs_id target = dl2_meta_for_dl3["target"] assert full_header["OBSERVER"] == target["observer"] assert full_header["OBJECT"] == target["object_name"] - def test_hdu_header_obs_info_missing_obs_id(self, dl3_writer): - dl3_writer._obs_id = None + def test_hdu_header_obs_info_missing_obs_id(self, dl3_writer, dl3_data): + object.__setattr__(dl3_data, "obs_id", None) with pytest.raises(ValueError): - dl3_writer.get_hdu_header_base_observation_information(obs_id_only=True) + dl3_writer.get_hdu_header_base_observation_information( + dl3_data, obs_id_only=True + ) with pytest.raises(ValueError): - dl3_writer.get_hdu_header_base_observation_information(obs_id_only=False) + dl3_writer.get_hdu_header_base_observation_information( + dl3_data, obs_id_only=False + ) - def test_hdu_header_subarray_info(self, dl3_writer, dl2_meta_for_dl3): - header = dl3_writer.get_hdu_header_base_subarray_information() + def test_hdu_header_subarray_info(self, dl3_writer, dl3_data, dl2_meta_for_dl3): + header = dl3_writer.get_hdu_header_base_subarray_information(dl3_data) tel_info = dl2_meta_for_dl3["telescope_information"] assert header["ORIGIN"] == tel_info["organisation"] @@ -464,19 +520,19 @@ def test_hdu_header_subarray_info(self, dl3_writer, dl2_meta_for_dl3): assert header["TELLIST"] == str(tel_info["telescope_list"]) assert header["N_TELS"] == len(tel_info["telescope_list"]) - def test_hdu_header_software_info(self, dl3_writer, dl2_meta_for_dl3): - header = dl3_writer.get_hdu_header_base_software_information() + def test_hdu_header_software_info(self, dl3_writer, dl3_data, dl2_meta_for_dl3): + header = dl3_writer.get_hdu_header_base_software_information(dl3_data) soft = dl2_meta_for_dl3["software_version"] assert header["DST_VER"] == soft["dst_version"] assert header["ANA_VER"] == soft["analysis_version"] assert header["CAL_VER"] == soft["calibration_version"] - dl3_writer._software_information = None - header = dl3_writer.get_hdu_header_base_software_information() + object.__setattr__(dl3_data, "software_information", None) + header = dl3_writer.get_hdu_header_base_software_information(dl3_data) assert len(header) == 0 - def test_hdu_header_pointing(self, dl3_writer, dl2_meta_for_dl3): - header = dl3_writer.get_hdu_header_base_pointing() + def test_hdu_header_pointing(self, dl3_writer, dl3_data, dl2_meta_for_dl3): + header = dl3_writer.get_hdu_header_base_pointing(dl3_data) assert header["RADESYS"] == "ICRS" assert header["RADECSYS"] == "ICRS" @@ -494,39 +550,39 @@ def test_hdu_header_pointing(self, dl3_writer, dl2_meta_for_dl3): assert header["OBSGEO-Y"] == pytest.approx(loc.y.to_value(u.m)) assert header["OBSGEO-Z"] == pytest.approx(loc.z.to_value(u.m)) - def test_hdu_header_pointing_track_mode_regression(self, dl3_writer): - dl3_writer.pointing_mode = "TRACK" - header = dl3_writer.get_hdu_header_base_pointing() + def test_hdu_header_pointing_track_mode_regression(self, dl3_writer, dl3_data): + dl3_data.pointing_mode = "TRACK" + header = dl3_writer.get_hdu_header_base_pointing(dl3_data) assert header["OBS_MODE"] == "POINTING" for key in ["RA_PNT", "DEC_PNT", "ALT_PNT", "AZ_PNT"]: assert np.isfinite(header[key]) - def test_hdu_header_pointing_drift_mode_regression(self, dl3_writer): - dl3_writer.pointing_mode = "DRIFT" - header = dl3_writer.get_hdu_header_base_pointing() + def test_hdu_header_pointing_drift_mode_regression(self, dl3_writer, dl3_data): + dl3_data.pointing_mode = "DRIFT" + header = dl3_writer.get_hdu_header_base_pointing(dl3_data) assert header["OBS_MODE"] == "DRIFT" for key in ["RA_PNT", "DEC_PNT", "ALT_PNT", "AZ_PNT"]: assert np.isfinite(header[key]) - def test_hdu_header_pointing_missing_pointing(self, dl3_writer): - dl3_writer._pointing = None + def test_hdu_header_pointing_missing_pointing(self, dl3_writer, dl3_data): + object.__setattr__(dl3_data, "pointing", None) with pytest.raises(ValueError): - dl3_writer.get_hdu_header_base_pointing() + dl3_writer.get_hdu_header_base_pointing(dl3_data) - def test_hdu_header_pointing_missing_pointing_mode(self, dl3_writer): - dl3_writer._pointing_mode = None + def test_hdu_header_pointing_missing_pointing_mode(self, dl3_writer, dl3_data): + object.__setattr__(dl3_data, "pointing_mode", None) with pytest.raises(ValueError): - dl3_writer.get_hdu_header_base_pointing() + dl3_writer.get_hdu_header_base_pointing(dl3_data) - def test_hdu_header_pointing_missing_location(self, dl3_writer): - dl3_writer._location = None + def test_hdu_header_pointing_missing_location(self, dl3_writer, dl3_data): + object.__setattr__(dl3_data, "location", None) with pytest.raises(ValueError): - dl3_writer.get_hdu_header_base_pointing() + dl3_writer.get_hdu_header_base_pointing(dl3_data) - def test_hdu_header_events_hdu(self, dl3_writer): - header = dl3_writer.get_hdu_header_events() + def test_hdu_header_events_hdu(self, dl3_writer, dl3_data): + header = dl3_writer.get_hdu_header_events(dl3_data) assert header["HDUCLASS"] == "GADF" assert header["HDUCLAS1"] == "EVENTS" @@ -557,8 +613,8 @@ def test_hdu_header_events_hdu(self, dl3_writer): ]: assert key in header - def test_hdu_header_gti_hdu(self, dl3_writer): - header = dl3_writer.get_hdu_header_gti() + def test_hdu_header_gti_hdu(self, dl3_writer, dl3_data): + header = dl3_writer.get_hdu_header_gti(dl3_data) for key in [ "MJDREFI", @@ -581,8 +637,8 @@ def test_hdu_header_gti_hdu(self, dl3_writer): assert header["HDUCLASS"] == "GADF" assert header["HDUCLAS1"] == "GTI" - def test_hdu_header_pointing_hdu(self, dl3_writer): - header = dl3_writer.get_hdu_header_pointing() + def test_hdu_header_pointing_hdu(self, dl3_writer, dl3_data): + header = dl3_writer.get_hdu_header_pointing(dl3_data) assert header["HDUCLASS"] == "GADF" assert header["HDUCLAS1"] == "POINTING" @@ -598,8 +654,8 @@ def test_hdu_header_pointing_hdu(self, dl3_writer): for key in ["RA_PNT", "DEC_PNT", "ALT_PNT", "AZ_PNT", "OBS_ID"]: assert key in header - def test_column_renaming(self, dl3_writer): - events = dl3_writer.events + def test_column_renaming(self, dl3_writer, dl3_data): + events = dl3_data.events renamed = dl3_writer.transform_events_columns_for_gadf_format(events) assert renamed.colnames == ["EVENT_ID", "TIME", "RA", "DEC", "ENERGY"] @@ -619,14 +675,14 @@ def test_column_renaming(self, dl3_writer): with pytest.raises(ValueError, match="Required column reco_energy is missing"): dl3_writer.transform_events_columns_for_gadf_format(bad_events) - def test_gti_table(self, dl3_writer, dl2_meta_for_dl3): - gti_table = dl3_writer.create_gti_table() + def test_gti_table(self, dl3_writer, dl3_data, dl2_meta_for_dl3): + gti_table = dl3_writer.create_gti_table(dl3_data) assert gti_table.colnames == ["START", "STOP"] assert len(gti_table) == len(dl2_meta_for_dl3["gti"]) - def test_pointing_table(self, dl3_writer): - pointing_table = dl3_writer.create_pointing_table() + def test_pointing_table(self, dl3_writer, dl3_data): + pointing_table = dl3_writer.create_pointing_table(dl3_data) assert pointing_table.colnames == [ "TIME", @@ -650,19 +706,19 @@ def test_pointing_table(self, dl3_writer): ) assert np.all(np.isfinite(pointing_table["RA_PNT"].to_value(u.deg))) - def test_pointing_table_missing_pointing(self, dl3_writer): - dl3_writer._pointing = None + def test_pointing_table_missing_pointing(self, dl3_writer, dl3_data): + object.__setattr__(dl3_data, "pointing", None) with pytest.raises(ValueError): - dl3_writer.create_pointing_table() + dl3_writer.create_pointing_table(dl3_data) - def test_pointing_table_missing_location(self, dl3_writer): - dl3_writer._location = None + def test_pointing_table_missing_location(self, dl3_writer, dl3_data): + object.__setattr__(dl3_data, "location", None) with pytest.raises(ValueError): - dl3_writer.create_pointing_table() + dl3_writer.create_pointing_table(dl3_data) - def test_gti_table_is_sorted(self, dl3_writer, dl2_meta_for_dl3): + def test_gti_table_is_sorted(self, dl3_writer, dl3_data, dl2_meta_for_dl3): """Regression test: GTI table must be sorted by START (bug #1.3).""" - original_gti = dl3_writer.gti + original_gti = dl3_data.gti # Build GTI intervals in reverse chronological order ref = Time("2020-06-01T00:00:00", scale="tai") @@ -671,13 +727,13 @@ def test_gti_table_is_sorted(self, dl3_writer, dl2_meta_for_dl3): (ref + 100 * u.s, ref + 200 * u.s), (ref + 0 * u.s, ref + 100 * u.s), ] - dl3_writer.gti = reversed_gti + dl3_data.gti = reversed_gti - gti_table = dl3_writer.create_gti_table() + gti_table = dl3_writer.create_gti_table(dl3_data) start_values = gti_table["START"].to_value(u.s) assert np.all(np.diff(start_values) >= 0), ( "GTI START column must be sorted in ascending order" ) # Restore original GTI - dl3_writer.gti = original_gti + dl3_data.gti = original_gti