Skip to content

Commit 33c7e35

Browse files
committed
make list of hist.axis object configurable
1 parent a613d14 commit 33c7e35

3 files changed

Lines changed: 237 additions & 76 deletions

File tree

examples/tutorials/histograms_aggregation.py

Lines changed: 35 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -8,11 +8,11 @@
88
2. Configure and run HistogramsAggregator in chunks.
99
3. Access counts, bin edges, and valid-event counts (n_events).
1010
4. Plot one pixel histogram from the selected chunks and both gain channels for both image and peak_time columns.
11+
5. Show how the plotted histograms follow the configured ``hist.axis`` definitions.
1112
"""
1213

1314
import matplotlib.pyplot as plt
1415
import numpy as np
15-
import hist
1616
from astropy.table import Table
1717
from astropy.time import Time
1818
from traitlets.config import Config
@@ -65,17 +65,31 @@
6565
{
6666
"HistogramsAggregator": {
6767
"chunking_type": "SizeChunking",
68+
"hist_axis_list": [
69+
{
70+
"axis_class_name": "Regular",
71+
"kwargs": {
72+
"bins": 50,
73+
"start": 40.0,
74+
"stop": 110.0,
75+
"name": "value",
76+
},
77+
},
78+
{
79+
"axis_class_name": "Regular",
80+
"kwargs": {
81+
"bins": 50,
82+
"start": 20.0,
83+
"stop": 90.0,
84+
"name": "value",
85+
},
86+
},
87+
],
6888
},
6989
"SizeChunking": {"chunk_size": 1000},
7090
}
7191
)
72-
73-
image_hist_templates = [
74-
hist.Hist(hist.axis.Regular(50, 40.0, 110.0, name="value")),
75-
hist.Hist(hist.axis.Regular(50, 20.0, 90.0, name="value")),
76-
]
77-
78-
aggregator_image = HistogramsAggregator(image_hist_templates, config=config_image)
92+
aggregator_image = HistogramsAggregator(config=config_image)
7993
result = aggregator_image(
8094
table=table,
8195
col_name="image",
@@ -86,15 +100,23 @@
86100
{
87101
"HistogramsAggregator": {
88102
"chunking_type": "SizeChunking",
103+
"hist_axis_list": [
104+
{
105+
"axis_class_name": "Regular",
106+
"kwargs": {
107+
"bins": 50,
108+
"start": 2.0,
109+
"stop": 38.0,
110+
"name": "value",
111+
},
112+
}
113+
],
89114
},
90115
"SizeChunking": {"chunk_size": 1000},
91116
}
92117
)
93118

94-
aggregator_peak_time = HistogramsAggregator(
95-
hist.axis.Regular(50, 2.0, 38.0, name="value"),
96-
config=config_peak_time,
97-
)
119+
aggregator_peak_time = HistogramsAggregator(config=config_peak_time)
98120
result_peak_time = aggregator_peak_time(
99121
table=table,
100122
col_name="peak_time",
@@ -110,6 +132,7 @@
110132
# -------------------------------------------------------------------
111133
# Inspect one pixel histogram in both chunks and both gain channels
112134
# -------------------------------------------------------------------
135+
# The plotted curves use the axes configured above via ``hist_axis_list``.
113136
pixel_index = 10
114137
gain_label = {0: "High Gain", 1: "Low Gain"}
115138

src/ctapipe/monitoring/aggregator.py

Lines changed: 92 additions & 42 deletions
Original file line numberDiff line numberDiff line change
@@ -34,10 +34,11 @@
3434
from astropy.stats import sigma_clip
3535
from astropy.table import Table
3636
from hist import Hist
37+
from traitlets import TraitError
3738

3839
from ..containers import ChunkStatisticsContainer
3940
from ..core import Component
40-
from ..core.traits import AstroQuantity, Bool, ComponentName, Enum, Int
41+
from ..core.traits import AstroQuantity, Bool, ComponentName, Dict, Enum, Int, List
4142

4243

4344
class BaseChunking(Component, metaclass=ABCMeta):
@@ -400,39 +401,87 @@ class HistogramsAggregator(BaseAggregator):
400401
Aggregation is performed along axis=0 (the event dimension) for any N-dimensional data.
401402
"""
402403

403-
def __init__(self, hist_axis, config=None, parent=None, **kwargs):
404+
hist_axis_list = List(
405+
trait=Dict(),
406+
allow_none=False,
407+
help=(
408+
"List of histogram axis definitions. Each entry must contain "
409+
"``axis_class_name`` and ``kwargs`` and is used to construct a "
410+
"``hist.axis.<axis_class_name>(**kwargs)`` instance. If a single "
411+
"axis is provided, it is applied to all pixels/channels. If multiple "
412+
"axes are provided, they are applied per gain channel or first data "
413+
"dimension. In multi-axis mode, all axes must have the same number of "
414+
"bins. E.g. ``[{'axis_class_name': 'Regular', 'kwargs': {'bins': 40, 'start': 20.0, 'stop': 80.0}}]``."
415+
),
416+
).tag(config=True)
417+
418+
def _axis_from_dict(self, axis_config, entry_index):
419+
"""Create a hist axis from one dict in ``hist_axis_list``."""
420+
missing_keys = {"axis_class_name", "kwargs"} - axis_config.keys()
421+
if missing_keys:
422+
raise TraitError(
423+
f"Entry '{entry_index}' in the ``hist_axis_list`` trait "
424+
f"is missing required key(s): {', '.join(sorted(missing_keys))}"
425+
)
426+
427+
axis_kwargs = axis_config["kwargs"]
428+
if not isinstance(axis_kwargs, dict):
429+
raise TraitError(
430+
f"Entry '{entry_index}' in the ``hist_axis_list`` trait has "
431+
"a non-dict 'kwargs' value."
432+
)
433+
434+
axis_class_name = axis_config["axis_class_name"]
435+
axis_class = getattr(hist.axis, axis_class_name, None)
436+
if axis_class is None or not callable(axis_class):
437+
raise TraitError(
438+
f"Entry '{entry_index}' in the ``hist_axis_list`` trait has "
439+
f"unknown axis_class_name '{axis_class_name}'."
440+
)
441+
442+
try:
443+
return axis_class(**axis_kwargs)
444+
except TypeError as err:
445+
raise TraitError(
446+
f"Failed to initialize hist.axis.{axis_class_name} for entry "
447+
f"'{entry_index}' with kwargs={axis_kwargs}: {err}"
448+
) from err
449+
450+
def __init__(self, config=None, parent=None, **kwargs):
404451
"""
405452
Parameters
406453
----------
407-
hist_axis : hist.axis or hist.Hist or list[hist.Hist]
408-
Histogram definition for aggregation.
409-
If a `hist.axis` is passed, one histogram is used for all channels.
410-
If a `hist.Hist` or list of `hist.Hist` is passed, the first axis of
411-
each hist defines the value axis. A list with length > 1 must match
412-
the first dimension in `data.shape[1:]` (e.g. gain channels).
454+
hist_axis_list : list[dict]
455+
List of axis definitions. Each entry must contain
456+
``axis_class_name`` and ``kwargs`` and is used to construct a
457+
``hist.axis`` via ``hist.axis.<axis_class_name>(**kwargs)``.
413458
config : traitlets.loader.Config
414459
Configuration specified by config file or cmdline arguments
415460
parent : ctapipe.core.Component or ctapipe.core.Tool
416461
Parent of this component in the configuration hierarchy
417462
"""
418463
super().__init__(config=config, parent=parent, **kwargs)
419464

420-
self.hist_axis = None
421-
self.hist_templates = None
465+
axis_list = [
466+
self._axis_from_dict(axis_config, index)
467+
for index, axis_config in enumerate(self.hist_axis_list)
468+
]
469+
if len(axis_list) == 0:
470+
raise TraitError("``hist_axis_list`` must contain at least one axis.")
422471

423-
if isinstance(hist_axis, list):
424-
if len(hist_axis) == 0:
425-
raise ValueError("hist_axis list must not be empty")
426-
if not all(isinstance(h, Hist) for h in hist_axis):
427-
raise TypeError("All elements of hist_axis list must be hist.Hist")
428-
self.hist_templates = hist_axis
429-
elif isinstance(hist_axis, Hist):
430-
self.hist_templates = [hist_axis]
431-
else:
432-
self.hist_axis = hist_axis
472+
self.hist_axis = axis_list[0]
473+
self.hist_templates = None
474+
if len(axis_list) > 1:
475+
self.hist_templates = [Hist(axis) for axis in axis_list]
433476

434477
def _get_hist_templates_for_shape(self, spatial_shape):
435-
"""Return one hist template per channel and validate compatibility."""
478+
"""
479+
Return one histogram template per gain channel or first spatial dimension.
480+
481+
A single configured axis is reused for all channels. When multiple axes
482+
are configured, the number of axes must match the first data dimension
483+
and all axes must have the same bin count.
484+
"""
436485
if len(spatial_shape) == 0:
437486
n_channels = 1
438487
else:
@@ -490,6 +539,22 @@ def _build_data_mask(self, data, masked_elements_of_sample):
490539
invalid = ~np.isfinite(data)
491540
return mask | invalid
492541

542+
def _iter_channel_views(self, data, mask, n_events, spatial_shape):
543+
"""Yield per-channel data and mask views for histogram filling."""
544+
if len(spatial_shape) == 0:
545+
yield data, mask
546+
return
547+
548+
for channel in range(spatial_shape[0]):
549+
yield data[:, channel, ...], mask[:, channel, ...]
550+
551+
def _combine_edges(self, edges_per_channel):
552+
"""Return either a single edge array or a stacked per-channel array."""
553+
if len(edges_per_channel) == 1:
554+
return edges_per_channel[0]
555+
556+
return np.stack(edges_per_channel, axis=0)
557+
493558
def _compute_single_histos(self, data, mask, n_events, spatial_shape):
494559
"""Compute histograms using one value axis for all pixels/channels."""
495560
n_pixels = int(np.prod(spatial_shape))
@@ -519,7 +584,6 @@ def _compute_single_histos(self, data, mask, n_events, spatial_shape):
519584
def _compute_multi_histos(self, data, mask, n_events, spatial_shape):
520585
"""Compute histograms with one template per channel."""
521586
templates = self._get_hist_templates_for_shape(spatial_shape)
522-
n_channels = 1 if len(spatial_shape) == 0 else spatial_shape[0]
523587
channel_shape = spatial_shape[1:] if len(spatial_shape) > 1 else ()
524588
n_pixels_per_channel = (
525589
int(np.prod(channel_shape)) if len(channel_shape) > 0 else 1
@@ -530,9 +594,10 @@ def _compute_multi_histos(self, data, mask, n_events, spatial_shape):
530594
n_events_valid = np.zeros(spatial_shape, dtype=int)
531595
edges_per_channel = []
532596
hist_objects = []
533-
is_scalar = len(spatial_shape) == 0
534597

535-
for channel in range(n_channels):
598+
for channel, (channel_data, channel_mask) in enumerate(
599+
self._iter_channel_views(data, mask, n_events, spatial_shape)
600+
):
536601
template = templates[channel]
537602
channel_hist = Hist(
538603
template.axes[0],
@@ -542,13 +607,6 @@ def _compute_multi_histos(self, data, mask, n_events, spatial_shape):
542607
hist_objects.append(channel_hist)
543608
edges_per_channel.append(channel_hist.axes[0].edges)
544609

545-
if is_scalar:
546-
channel_data = data
547-
channel_mask = mask
548-
else:
549-
channel_data = data[:, channel, ...]
550-
channel_mask = mask[:, channel, ...]
551-
552610
flat_channel_data = channel_data.reshape(n_events, n_pixels_per_channel)
553611
flat_channel_mask = channel_mask.reshape(n_events, n_pixels_per_channel)
554612

@@ -563,18 +621,10 @@ def _compute_multi_histos(self, data, mask, n_events, spatial_shape):
563621
channel_counts = channel_hist.values().reshape((n_bins,) + channel_shape)
564622
valid_events = np.sum(~flat_channel_mask, axis=0).reshape(channel_shape)
565623

566-
if is_scalar:
567-
hist_counts[...] = channel_counts
568-
n_events_valid[...] = valid_events
569-
else:
570-
hist_counts[:, channel, ...] = channel_counts
571-
n_events_valid[channel, ...] = valid_events
624+
hist_counts[:, channel, ...] = channel_counts
625+
n_events_valid[channel, ...] = valid_events
572626

573-
edges = (
574-
np.stack(edges_per_channel, axis=0)
575-
if len(edges_per_channel) > 1
576-
else edges_per_channel[0]
577-
)
627+
edges = self._combine_edges(edges_per_channel)
578628
return hist_objects, hist_counts, edges, n_events_valid
579629

580630
def compute_histos(

0 commit comments

Comments
 (0)