3434from astropy .stats import sigma_clip
3535from astropy .table import Table
3636from hist import Hist
37+ from traitlets import TraitError
3738
3839from ..containers import ChunkStatisticsContainer
3940from ..core import Component
40- from ..core .traits import AstroQuantity , Bool , ComponentName , Enum , Int
41+ from ..core .traits import AstroQuantity , Bool , ComponentName , Dict , Enum , Int , List
4142
4243
4344class BaseChunking (Component , metaclass = ABCMeta ):
@@ -400,39 +401,87 @@ class HistogramsAggregator(BaseAggregator):
400401 Aggregation is performed along axis=0 (the event dimension) for any N-dimensional data.
401402 """
402403
403- def __init__ (self , hist_axis , config = None , parent = None , ** kwargs ):
404+ hist_axis_list = List (
405+ trait = Dict (),
406+ allow_none = False ,
407+ help = (
408+ "List of histogram axis definitions. Each entry must contain "
409+ "``axis_class_name`` and ``kwargs`` and is used to construct a "
410+ "``hist.axis.<axis_class_name>(**kwargs)`` instance. If a single "
411+ "axis is provided, it is applied to all pixels/channels. If multiple "
412+ "axes are provided, they are applied per gain channel or first data "
413+ "dimension. In multi-axis mode, all axes must have the same number of "
414+ "bins. E.g. ``[{'axis_class_name': 'Regular', 'kwargs': {'bins': 40, 'start': 20.0, 'stop': 80.0}}]``."
415+ ),
416+ ).tag (config = True )
417+
418+ def _axis_from_dict (self , axis_config , entry_index ):
419+ """Create a hist axis from one dict in ``hist_axis_list``."""
420+ missing_keys = {"axis_class_name" , "kwargs" } - axis_config .keys ()
421+ if missing_keys :
422+ raise TraitError (
423+ f"Entry '{ entry_index } ' in the ``hist_axis_list`` trait "
424+ f"is missing required key(s): { ', ' .join (sorted (missing_keys ))} "
425+ )
426+
427+ axis_kwargs = axis_config ["kwargs" ]
428+ if not isinstance (axis_kwargs , dict ):
429+ raise TraitError (
430+ f"Entry '{ entry_index } ' in the ``hist_axis_list`` trait has "
431+ "a non-dict 'kwargs' value."
432+ )
433+
434+ axis_class_name = axis_config ["axis_class_name" ]
435+ axis_class = getattr (hist .axis , axis_class_name , None )
436+ if axis_class is None or not callable (axis_class ):
437+ raise TraitError (
438+ f"Entry '{ entry_index } ' in the ``hist_axis_list`` trait has "
439+ f"unknown axis_class_name '{ axis_class_name } '."
440+ )
441+
442+ try :
443+ return axis_class (** axis_kwargs )
444+ except TypeError as err :
445+ raise TraitError (
446+ f"Failed to initialize hist.axis.{ axis_class_name } for entry "
447+ f"'{ entry_index } ' with kwargs={ axis_kwargs } : { err } "
448+ ) from err
449+
450+ def __init__ (self , config = None , parent = None , ** kwargs ):
404451 """
405452 Parameters
406453 ----------
407- hist_axis : hist.axis or hist.Hist or list[hist.Hist]
408- Histogram definition for aggregation.
409- If a `hist.axis` is passed, one histogram is used for all channels.
410- If a `hist.Hist` or list of `hist.Hist` is passed, the first axis of
411- each hist defines the value axis. A list with length > 1 must match
412- the first dimension in `data.shape[1:]` (e.g. gain channels).
454+ hist_axis_list : list[dict]
455+ List of axis definitions. Each entry must contain
456+ ``axis_class_name`` and ``kwargs`` and is used to construct a
457+ ``hist.axis`` via ``hist.axis.<axis_class_name>(**kwargs)``.
413458 config : traitlets.loader.Config
414459 Configuration specified by config file or cmdline arguments
415460 parent : ctapipe.core.Component or ctapipe.core.Tool
416461 Parent of this component in the configuration hierarchy
417462 """
418463 super ().__init__ (config = config , parent = parent , ** kwargs )
419464
420- self .hist_axis = None
421- self .hist_templates = None
465+ axis_list = [
466+ self ._axis_from_dict (axis_config , index )
467+ for index , axis_config in enumerate (self .hist_axis_list )
468+ ]
469+ if len (axis_list ) == 0 :
470+ raise TraitError ("``hist_axis_list`` must contain at least one axis." )
422471
423- if isinstance (hist_axis , list ):
424- if len (hist_axis ) == 0 :
425- raise ValueError ("hist_axis list must not be empty" )
426- if not all (isinstance (h , Hist ) for h in hist_axis ):
427- raise TypeError ("All elements of hist_axis list must be hist.Hist" )
428- self .hist_templates = hist_axis
429- elif isinstance (hist_axis , Hist ):
430- self .hist_templates = [hist_axis ]
431- else :
432- self .hist_axis = hist_axis
472+ self .hist_axis = axis_list [0 ]
473+ self .hist_templates = None
474+ if len (axis_list ) > 1 :
475+ self .hist_templates = [Hist (axis ) for axis in axis_list ]
433476
434477 def _get_hist_templates_for_shape (self , spatial_shape ):
435- """Return one hist template per channel and validate compatibility."""
478+ """
479+ Return one histogram template per gain channel or first spatial dimension.
480+
481+ A single configured axis is reused for all channels. When multiple axes
482+ are configured, the number of axes must match the first data dimension
483+ and all axes must have the same bin count.
484+ """
436485 if len (spatial_shape ) == 0 :
437486 n_channels = 1
438487 else :
@@ -490,6 +539,22 @@ def _build_data_mask(self, data, masked_elements_of_sample):
490539 invalid = ~ np .isfinite (data )
491540 return mask | invalid
492541
542+ def _iter_channel_views (self , data , mask , n_events , spatial_shape ):
543+ """Yield per-channel data and mask views for histogram filling."""
544+ if len (spatial_shape ) == 0 :
545+ yield data , mask
546+ return
547+
548+ for channel in range (spatial_shape [0 ]):
549+ yield data [:, channel , ...], mask [:, channel , ...]
550+
551+ def _combine_edges (self , edges_per_channel ):
552+ """Return either a single edge array or a stacked per-channel array."""
553+ if len (edges_per_channel ) == 1 :
554+ return edges_per_channel [0 ]
555+
556+ return np .stack (edges_per_channel , axis = 0 )
557+
493558 def _compute_single_histos (self , data , mask , n_events , spatial_shape ):
494559 """Compute histograms using one value axis for all pixels/channels."""
495560 n_pixels = int (np .prod (spatial_shape ))
@@ -519,7 +584,6 @@ def _compute_single_histos(self, data, mask, n_events, spatial_shape):
519584 def _compute_multi_histos (self , data , mask , n_events , spatial_shape ):
520585 """Compute histograms with one template per channel."""
521586 templates = self ._get_hist_templates_for_shape (spatial_shape )
522- n_channels = 1 if len (spatial_shape ) == 0 else spatial_shape [0 ]
523587 channel_shape = spatial_shape [1 :] if len (spatial_shape ) > 1 else ()
524588 n_pixels_per_channel = (
525589 int (np .prod (channel_shape )) if len (channel_shape ) > 0 else 1
@@ -530,9 +594,10 @@ def _compute_multi_histos(self, data, mask, n_events, spatial_shape):
530594 n_events_valid = np .zeros (spatial_shape , dtype = int )
531595 edges_per_channel = []
532596 hist_objects = []
533- is_scalar = len (spatial_shape ) == 0
534597
535- for channel in range (n_channels ):
598+ for channel , (channel_data , channel_mask ) in enumerate (
599+ self ._iter_channel_views (data , mask , n_events , spatial_shape )
600+ ):
536601 template = templates [channel ]
537602 channel_hist = Hist (
538603 template .axes [0 ],
@@ -542,13 +607,6 @@ def _compute_multi_histos(self, data, mask, n_events, spatial_shape):
542607 hist_objects .append (channel_hist )
543608 edges_per_channel .append (channel_hist .axes [0 ].edges )
544609
545- if is_scalar :
546- channel_data = data
547- channel_mask = mask
548- else :
549- channel_data = data [:, channel , ...]
550- channel_mask = mask [:, channel , ...]
551-
552610 flat_channel_data = channel_data .reshape (n_events , n_pixels_per_channel )
553611 flat_channel_mask = channel_mask .reshape (n_events , n_pixels_per_channel )
554612
@@ -563,18 +621,10 @@ def _compute_multi_histos(self, data, mask, n_events, spatial_shape):
563621 channel_counts = channel_hist .values ().reshape ((n_bins ,) + channel_shape )
564622 valid_events = np .sum (~ flat_channel_mask , axis = 0 ).reshape (channel_shape )
565623
566- if is_scalar :
567- hist_counts [...] = channel_counts
568- n_events_valid [...] = valid_events
569- else :
570- hist_counts [:, channel , ...] = channel_counts
571- n_events_valid [channel , ...] = valid_events
624+ hist_counts [:, channel , ...] = channel_counts
625+ n_events_valid [channel , ...] = valid_events
572626
573- edges = (
574- np .stack (edges_per_channel , axis = 0 )
575- if len (edges_per_channel ) > 1
576- else edges_per_channel [0 ]
577- )
627+ edges = self ._combine_edges (edges_per_channel )
578628 return hist_objects , hist_counts , edges , n_events_valid
579629
580630 def compute_histos (
0 commit comments