Skip to content

Commit b1761d8

Browse files
committed
Remove one level of indirection from DVR classes
1 parent 7e40bd2 commit b1761d8

1 file changed

Lines changed: 24 additions & 45 deletions

File tree

src/ctapipe/image/reducer.py

Lines changed: 24 additions & 45 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
"""
44

55
from abc import abstractmethod
6+
from typing import override
67

78
import numpy as np
89

@@ -18,7 +19,11 @@
1819
from ctapipe.image.cleaning import dilate
1920
from ctapipe.image.extractor import ImageExtractor
2021

21-
__all__ = ["DataVolumeReducer", "NullDataVolumeReducer", "TailCutsDataVolumeReducer"]
22+
__all__ = [
23+
"DataVolumeReducer",
24+
"NullDataVolumeReducer",
25+
"TailCutsDataVolumeReducer",
26+
]
2227

2328

2429
class DataVolumeReducer(TelescopeComponent):
@@ -41,39 +46,12 @@ def __init__(self, subarray, config=None, parent=None, **kwargs):
4146
self.subarray = subarray
4247
super().__init__(config=config, parent=parent, subarray=subarray, **kwargs)
4348

44-
def __call__(self, waveforms, tel_id=None, selected_gain_channel=None):
45-
"""
46-
Call the relevant functions to perform data volume reduction on the
47-
waveforms.
48-
49-
Parameters
50-
----------
51-
waveforms: ndarray
52-
Waveforms stored in a numpy array of shape
53-
(n_pix, n_samples).
54-
tel_id: int
55-
The telescope id. Required for the 'image_extractor' and
56-
'camera.geometry' in 'TailCutsDataVolumeReducer'.
57-
selected_gain_channel: ndarray
58-
The channel selected in the gain selection, per pixel. Required for
59-
the 'image_extractor' in 'TailCutsDataVolumeReducer'.
60-
extraction.
61-
62-
Returns
63-
-------
64-
mask: array
65-
Mask of selected pixels.
66-
"""
67-
mask = self.select_pixels(
68-
waveforms, tel_id=tel_id, selected_gain_channel=selected_gain_channel
69-
)
70-
return mask
71-
7249
@abstractmethod
73-
def select_pixels(self, waveforms, tel_id=None, selected_gain_channel=None):
50+
def __call__(
51+
self, waveforms, tel_id: int, selected_gain_channel=None
52+
) -> np.ndarray[bool]:
7453
"""
75-
Abstract method to be defined by a DataVolumeReducer subclass.
76-
Call the relevant functions for the required pixel selection.
54+
Select pixels of which to keep the waveform data.
7755
7856
Parameters
7957
----------
@@ -99,7 +77,8 @@ class NullDataVolumeReducer(DataVolumeReducer):
9977
Perform no data volume reduction
10078
"""
10179

102-
def select_pixels(self, waveforms, tel_id=None, selected_gain_channel=None):
80+
@override
81+
def __call__(self, waveforms, tel_id: int, selected_gain_channel=None):
10382
n_pixels = waveforms.shape[-2]
10483
return np.ones(n_pixels, dtype=bool)
10584

@@ -122,6 +101,16 @@ class TailCutsDataVolumeReducer(DataVolumeReducer):
122101
do_boundary_dilation: BoolTelescopeParameter
123102
If set to 'False', the iteration steps in 2) are skipped and
124103
normal TailcutCleaning is used.
104+
105+
Parameters
106+
----------
107+
subarray: ctapipe.instrument.SubarrayDescription
108+
Description of the subarray
109+
config: traitlets.loader.Config
110+
Configuration specified by config file or cmdline arguments.
111+
Used to set traitlet values.
112+
Set to None if no configuration to pass.
113+
kwargs
125114
"""
126115

127116
image_extractor_type = TelescopeParameter(
@@ -149,17 +138,6 @@ def __init__(
149138
image_extractor=None,
150139
**kwargs,
151140
):
152-
"""
153-
Parameters
154-
----------
155-
subarray: ctapipe.instrument.SubarrayDescription
156-
Description of the subarray
157-
config: traitlets.loader.Config
158-
Configuration specified by config file or cmdline arguments.
159-
Used to set traitlet values.
160-
Set to None if no configuration to pass.
161-
kwargs
162-
"""
163141
super().__init__(config=config, parent=parent, subarray=subarray, **kwargs)
164142

165143
if cleaner is None:
@@ -178,7 +156,8 @@ def __init__(
178156
self.image_extractor_type = [("type", "*", name)]
179157
self.image_extractors[name] = image_extractor
180158

181-
def select_pixels(self, waveforms, tel_id=None, selected_gain_channel=None):
159+
@override
160+
def __call__(self, waveforms, tel_id, selected_gain_channel=None):
182161
camera_geom = self.subarray.tel[tel_id].camera.geometry
183162
# Pulse-integrate waveforms
184163
extractor = self.image_extractors[self.image_extractor_type.tel[tel_id]]

0 commit comments

Comments
 (0)