Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 7 additions & 1 deletion docs/changelog.md
Original file line number Diff line number Diff line change
Expand Up @@ -71,7 +71,13 @@ Added:
and photon line if the flag `deconvolve` is set to True.
summary about a pulse pattern (!435).
- `StepTimer` class for measuring performance within functions.

- [Timepix3.select_trains][extra.components.Timepix3.select_trains] and
[Timepix3.split_trains][extra.components.Timepix3.split_trains]
- [XGM.select_trains][extra.components.XGM.select_trains] and
[XGM.split_trains][extra.components.XGM.split_trains]
- [DelayLineDetector.split_trains][extra.components.DelayLineDetector.split_trains]
- `split_trains()` methods for the pulse pattern components.

Fixed:

- [fit_gaussian()][extra.utils.fit_gaussian] would sometimes return a negative
Expand Down
6 changes: 6 additions & 0 deletions src/extra/components/dld.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
import pandas as pd

from extra.data import KeyData, PropertyNameError
from extra_data.read_machinery import split_trains
from .utils import _isinstance_no_import


Expand Down Expand Up @@ -283,6 +284,11 @@ def select_trains(self, trains):

return new_self

def split_trains(self, parts=None, trains_per_part=None):
    """Iterate over this component in chunks of trains.

    The number of trains in the instrument source is handed to the
    `split_trains` helper from `extra_data.read_machinery` together with
    `parts` and `trains_per_part`, which are forwarded unchanged. Every
    selection produced by that helper is applied via `select_trains()`.

    Args:
        parts: Forwarded to the EXtra-data helper.
        trains_per_part: Forwarded to the EXtra-data helper.

    Yields:
        The object returned by `select_trains()` for each chunk, in the
        order the helper produces them.
    """
    total_trains = len(self._instrument_src.train_ids)
    selections = split_trains(
        total_trains, parts=parts, trains_per_part=trains_per_part
    )
    for selection in selections:
        yield self.select_trains(selection)

def pulses(self, **kwargs):
"""Get pulse object based on internal triggers.

Expand Down
13 changes: 13 additions & 0 deletions src/extra/components/pulses.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
PPL_BITS, DESTINATION_TLD, DESTINATION_T4D, DESTINATION_T5D, \
PHOTON_LINE_DEFLECTION
from extra_data import DataCollection, SourceData, KeyData, by_index, by_id
from extra_data.read_machinery import split_trains

from .utils import identify_sase, _instrument_to_sase

Expand Down Expand Up @@ -360,6 +361,18 @@ def select_trains(self, trains):

return res

def split_trains(self, parts=None, trains_per_part=None):
    """Iterate over this pulse pattern in chunks of trains.

    `parts` and `trains_per_part` are forwarded unchanged to the
    `split_trains` helper from `extra_data.read_machinery`; each selection
    it produces is applied via `select_trains()`.

    Args:
        parts: Forwarded to the EXtra-data helper.
        trains_per_part: Forwarded to the EXtra-data helper.

    Yields:
        The object returned by `select_trains()` for each chunk.
    """
    # Prefer train IDs taken from an EXtra-data object, so the split lines
    # up with other things created from the same data.
    source = self._source
    if source is None:
        key = self._key
        if key is None:
            # No EXtra-data object available (used for ManualPulses).
            train_ids = self._get_train_ids()
        else:
            train_ids = key.train_ids
    else:
        train_ids = source.train_ids

    selections = split_trains(
        len(train_ids), parts=parts, trains_per_part=trains_per_part
    )
    for selection in selections:
        yield self.select_trains(selection)

def select_pulses(self, pulse_sel, train_sel=None,
reset_index=True) -> ManualPulses:
"""Select a subset of pulses.
Expand Down
40 changes: 40 additions & 0 deletions src/extra/components/timepix.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@

from collections import defaultdict
from copy import copy
from io import IOBase
from os import PathLike
import re
Expand All @@ -8,6 +9,7 @@
import pandas as pd

from extra_data import by_id
from extra_data.read_machinery import split_trains


# Pulse indices used internally for virtual leading and trailing pulses.
Expand Down Expand Up @@ -56,6 +58,7 @@ def __init__(self, data, detector=None, pulses=None, **kwargs):

self._raw_control_src = None
self._raw_instrument_src = None
self._raw_instrument_dc = None
self._centroids_control_src = None
self._centroids_instrument_src = None

Expand Down Expand Up @@ -108,6 +111,36 @@ def __repr__(self):
return "<{} {}: {}>".format(type(self).__name__, self._detector_name,
', '.join(data_labels))

def select_trains(self, trains):
    """Return a shallow copy of this component restricted to some trains.

    This method accepts the same type of arguments as
    [DataCollection.select_trains][extra_data.DataCollection.select_trains].

    Every source that is currently set (not None) is replaced on the copy
    by its own `select_trains(trains)` result, and the pulses object is
    narrowed the same way. The original object is left untouched.
    """
    subset = copy(self)

    raw_control = self._raw_control_src
    if raw_control is not None:
        subset._raw_control_src = raw_control.select_trains(trains)

    raw_instrument = self._raw_instrument_src
    if raw_instrument is not None:
        subset._raw_instrument_src = raw_instrument.select_trains(trains)

    raw_dc = self._raw_instrument_dc
    if raw_dc is not None:
        # Deliberately evaluated on the copy, after its raw sources have
        # been narrowed above: the collection is then selected by the IDs
        # of the trains remaining once empty ones are dropped.
        non_empty_ids = subset.raw_size_key.drop_empty_trains().train_ids
        subset._raw_instrument_dc = raw_dc.select_trains(by_id[non_empty_ids])

    centroids_control = self._centroids_control_src
    if centroids_control is not None:
        subset._centroids_control_src = centroids_control.select_trains(trains)

    centroids_instrument = self._centroids_instrument_src
    if centroids_instrument is not None:
        subset._centroids_instrument_src = \
            centroids_instrument.select_trains(trains)

    subset._pulses = self._pulses.select_trains(trains)

    return subset

def split_trains(self, parts=None, trains_per_part=None):
    """Iterate over this component in chunks of trains.

    The length of `self.train_ids` is handed to the `split_trains` helper
    from `extra_data.read_machinery` together with `parts` and
    `trains_per_part`, which are forwarded unchanged.

    Args:
        parts: Forwarded to the EXtra-data helper.
        trains_per_part: Forwarded to the EXtra-data helper.

    Yields:
        The object returned by `select_trains()` for each chunk.
    """
    total_trains = len(self.train_ids)
    selections = split_trains(
        total_trains, parts=parts, trains_per_part=trains_per_part
    )
    for selection in selections:
        yield self.select_trains(selection)

def _has_centroid_labels(self):
"""Whether centroid labels are available."""
return (self._centroids_instrument_src is not None and
Expand Down Expand Up @@ -402,6 +435,13 @@ def centroids_instrument_src(self):

return self._centroids_instrument_src

@property
def train_ids(self):
    """Train IDs of the raw instrument source, else the centroids one."""
    source = self._raw_instrument_src
    if source is None:
        # At least one of the two instrument sources must be present, so
        # fall back to the centroids source.
        source = self._centroids_instrument_src
    return source.train_ids

@property
def centroids_key(self):
return self.centroids_instrument_src['data.centroids']
Expand Down
30 changes: 30 additions & 0 deletions src/extra/components/xgm.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,12 @@
from copy import copy
from enum import Enum
from warnings import warn
from textwrap import dedent

import numpy as np

from extra_data import SourceData, KeyData, MultiRunError
from extra_data.read_machinery import split_trains

from .. import ureg
from .utils import SASE_TOPICS, identify_sase
Expand Down Expand Up @@ -216,6 +218,34 @@ def instrument_source(self) -> SourceData:
(e.g. `SA2_XTD1_XGM/XGM/DOOCS:output`)."""
return self._instrument_source

def select_trains(self, trains):
    """Return a shallow copy of this component restricted to some trains.

    This method accepts the same type of arguments as
    [DataCollection.select_trains][extra_data.DataCollection.select_trains].

    Both the control and the instrument source are narrowed with their own
    `select_trains(trains)`; the original object is left untouched.
    """
    subset = copy(self)
    subset._control_source = self._control_source.select_trains(trains)
    subset._instrument_source = self._instrument_source.select_trains(trains)

    # Reset the caches that depend on which trains are selected. Cached
    # single values are kept, as they would be the same in the subset.
    for attr in ('_wavelength_by_train',
                 '_photon_energy_by_train',
                 '_photon_flux'):
        setattr(subset, attr, None)

    # Each dict cache gets its own fresh, empty dict.
    for attr in ('_pulse_energy',
                 '_slow_train_energy',
                 '_pulse_counts',
                 '_slow_pulse_counts',
                 '_max_pulses'):
        setattr(subset, attr, {})

    return subset

def split_trains(self, parts=None, trains_per_part=None):
    """Iterate over this component in chunks of trains.

    The number of trains in the control source is handed to the
    `split_trains` helper from `extra_data.read_machinery` together with
    `parts` and `trains_per_part`, which are forwarded unchanged.

    Args:
        parts: Forwarded to the EXtra-data helper.
        trains_per_part: Forwarded to the EXtra-data helper.

    Yields:
        The object returned by `select_trains()` for each chunk.
    """
    total_trains = len(self._control_source.train_ids)
    selections = split_trains(
        total_trains, parts=parts, trains_per_part=trains_per_part
    )
    for selection in selections:
        yield self.select_trains(selection)

def wavelength(self, with_units=True):
"""The nominal wavelength of the X-rays in nanometers.

Expand Down
23 changes: 18 additions & 5 deletions tests/mockdata/dld.py
Original file line number Diff line number Diff line change
Expand Up @@ -73,7 +73,9 @@ def write_instrument(self, f):
for key, dtype, entry_shape in keys:
f.create_dataset(f'INSTRUMENT/{source}/{index_group}/{key}',
maxshape=(None,) + entry_shape, dtype=dtype,
shape=(num_entries,) + entry_shape)
shape=(num_entries,) + entry_shape,
chunks=(min(num_entries, 50),) + entry_shape,
)

data_root = f[f'INSTRUMENT/{source}']

Expand All @@ -88,10 +90,21 @@ def write_instrument(self, f):
triggers['ppl'][:num_entries] = (pulse_ids % 2) == 0

data_root['raw/triggers'][:num_entries] = triggers
data_root['raw/edges'][:num_entries] = np.nan
data_root['raw/amplitudes'][:num_entries] = np.nan
data_root['rec/signals'][:num_entries] = np.nan
data_root['rec/hits'][:num_entries] = (np.nan, np.nan, np.nan, -1)
# Workaround: h5py's guess for how to broadcast a single value is
# abysmally slow for these dataset shapes. Make an array the same shape
# first to bypass it.
data_root['raw/edges'][:num_entries] = np.full(
(num_entries, 7, self.max_rows), np.nan
)
data_root['raw/amplitudes'][:num_entries] = np.full(
(num_entries, 7, self.max_rows), np.nan
)
data_root['rec/signals'][:num_entries] = np.full(
(num_entries, self.max_rows), np.nan, dtype=signal_dt
)
data_root['rec/hits'][:num_entries] = np.full(
(num_entries, self.max_rows), np.array((np.nan,)*3 + (-1,), dtype=hit_dt)
)

for i, ch in product(range(5), range(7)):
# Alternate between 1-channel edges per pulse with random
Expand Down
6 changes: 6 additions & 0 deletions tests/test_components_dld.py
Original file line number Diff line number Diff line change
Expand Up @@ -154,3 +154,9 @@ def test_dld_max_method(mock_sqs_remi_run):
all_hits[mask], dld.hits(max_method=m))
pd.testing.assert_frame_equal(
all_signals[mask], dld.signals(max_method=m))


def test_split_trains(mock_sqs_remi_run):
    """split_trains(trains_per_part=30) should give four 25-train chunks."""
    detector = DelayLineDetector(mock_sqs_remi_run, 'SQS_REMI_DLD6/DET/TOP')
    parts = list(detector.split_trains(trains_per_part=30))
    trains_in_each = [len(part.instrument_source.train_ids) for part in parts]
    assert trains_in_each == [25] * 4
6 changes: 6 additions & 0 deletions tests/test_components_pulses.py
Original file line number Diff line number Diff line change
Expand Up @@ -133,6 +133,12 @@ def test_select_trains(mock_spb_aux_run):
pulses.select_trains(np.s_[:20])


def test_split_trains(mock_spb_aux_run):
    """split_trains(trains_per_part=30) should give four 25-train chunks."""
    xray_pulses = XrayPulses(mock_spb_aux_run.select('SPB*'))
    parts = list(xray_pulses.split_trains(trains_per_part=30))
    trains_in_each = [len(part.source.train_ids) for part in parts]
    assert trains_in_each == [25] * 4


@pytest.mark.parametrize('source', **pattern_sources)
def test_pulse_mask(mock_spb_aux_run, source):
run = mock_spb_aux_run
Expand Down
6 changes: 6 additions & 0 deletions tests/test_components_timepix.py
Original file line number Diff line number Diff line change
Expand Up @@ -247,3 +247,9 @@ def test_timepix3_centroid_events(mock_sqs_timepix_run):
np.testing.assert_array_equal(
d['centroid_size'], np.linspace(1, 10, N, dtype=np.int16))
np.testing.assert_array_equal(d['label'], np.arange(N))


def test_split_trains(mock_sqs_timepix_run):
    """split_trains(trains_per_part=30) should give four 25-train chunks."""
    timepix = Timepix3(mock_sqs_timepix_run.deselect('SQS_EXTRA*'))
    parts = list(timepix.split_trains(trains_per_part=30))
    trains_in_each = [len(part.raw_control_src.train_ids) for part in parts]
    assert trains_in_each == [25] * 4
6 changes: 6 additions & 0 deletions tests/test_components_xgm.py
Original file line number Diff line number Diff line change
Expand Up @@ -178,3 +178,9 @@ def test_wrong_pulse_counts(mock_spb_aux_run):
# Otherwise it should return the slow data counts which should have
# a name.
assert xgm.pulse_counts(force_slow_data=True).name == mock_pulse_counts.name


def test_split_trains(mock_spb_aux_run):
    """split_trains(trains_per_part=30) should give four 25-train chunks."""
    monitor = XGM(mock_spb_aux_run)
    parts = list(monitor.split_trains(trains_per_part=30))
    trains_in_each = [len(part.instrument_source.train_ids) for part in parts]
    assert trains_in_each == [25] * 4
Loading