Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
28 changes: 16 additions & 12 deletions magicctapipe/io/io.py
Original file line number Diff line number Diff line change
Expand Up @@ -783,7 +783,7 @@ def load_magic_dl1_data_files(input_dir, config):


def load_train_data_files(
input_dir, offaxis_min=None, offaxis_max=None, true_event_class=None
input_dir, config, offaxis_min=None, offaxis_max=None, true_event_class=None
):
"""
Loads DL1-stereo data files and separates the shower events per
Expand All @@ -793,6 +793,8 @@ def load_train_data_files(
----------
input_dir : str
Path to a directory where input DL1-stereo files are stored
config : dict
Yaml file with information about the telescope IDs.
offaxis_min : str, optional
Minimum shower off-axis angle allowed, whose format should be
acceptable by `astropy.units.quantity.Quantity`
Expand All @@ -814,12 +816,8 @@ def load_train_data_files(
If any DL1-stereo data files are not found in the input
directory
"""
TEL_COMBINATIONS = {
"LST1_M1": [1, 2], # combo_type = 0
"LST1_M1_M2": [1, 2, 3], # combo_type = 1
"LST1_M2": [1, 3], # combo_type = 2
"M1_M2": [2, 3], # combo_type = 3
} # TODO: REMOVE WHEN SWITCHING TO THE NEW RFs IMPLEMENTTATION (1 RF PER TELESCOPE)
_, TEL_COMBINATIONS = telescope_combinations(config)
# TEL_COMBINATIONS = {k.replace("MAGIC-II", "M2").replace("MAGIC-I","M1").replace("LST-","LST"): v for k,v in TEL_COMBINATIONS.items()}

# Find the input files
file_mask = f"{input_dir}/dl1_stereo_*.h5"
Expand Down Expand Up @@ -858,7 +856,7 @@ def load_train_data_files(
if true_event_class is not None:
event_data["true_event_class"] = true_event_class

event_data = get_stereo_events_old(event_data, group_index=GROUP_INDEX_TRAIN)
event_data = get_stereo_events(event_data, config, group_index=GROUP_INDEX_TRAIN)

data_train = {}

Expand Down Expand Up @@ -959,14 +957,18 @@ def load_train_data_files_tel(
return data_train


def load_mc_dl2_data_file(input_file, quality_cuts, event_type, weight_type_dl2):
def load_mc_dl2_data_file(
input_file, config, quality_cuts, event_type, weight_type_dl2
):
"""
Loads a MC DL2 data file for creating the IRFs.

Parameters
----------
input_file : str
Path to an input MC DL2 data file
config : dict
Yaml file with information about the telescope IDs.
quality_cuts : str
Quality cuts applied to the input events
event_type : str
Expand Down Expand Up @@ -999,7 +1001,7 @@ def load_mc_dl2_data_file(input_file, quality_cuts, event_type, weight_type_dl2)
df_events.set_index(["obs_id", "event_id", "tel_id"], inplace=True)
df_events.sort_index(inplace=True)

df_events = get_stereo_events_old(df_events, quality_cuts)
df_events = get_stereo_events(df_events, config, quality_cuts)

logger.info(f"\nExtracting the events of the '{event_type}' type...")

Expand Down Expand Up @@ -1090,14 +1092,16 @@ def load_mc_dl2_data_file(input_file, quality_cuts, event_type, weight_type_dl2)
return event_table, pointing, sim_info


def load_dl2_data_file(input_file, quality_cuts, event_type, weight_type_dl2):
def load_dl2_data_file(input_file, config, quality_cuts, event_type, weight_type_dl2):
"""
Loads a DL2 data file for processing to DL3.

Parameters
----------
input_file : str
Path to an input DL2 data file
config : dict
Yaml file with information about the telescope IDs.
quality_cuts : str
Quality cuts applied to the input events
event_type : str
Expand Down Expand Up @@ -1130,7 +1134,7 @@ def load_dl2_data_file(input_file, quality_cuts, event_type, weight_type_dl2):
event_data.set_index(["obs_id", "event_id", "tel_id"], inplace=True)
event_data.sort_index(inplace=True)

event_data = get_stereo_events_old(event_data, quality_cuts)
event_data = get_stereo_events(event_data, config, quality_cuts)

logger.info(f"\nExtracting the events of the '{event_type}' type...")

Expand Down
49 changes: 27 additions & 22 deletions magicctapipe/io/tests/test_io.py
Original file line number Diff line number Diff line change
Expand Up @@ -313,32 +313,35 @@ def test_get_stereo_events_mc_cut(self, gamma_stereo, p_stereo, config_gen):
assert np.all(data["intensity"] > 50)
assert len(data) > 0

def test_load_train_data_files(self, p_stereo, gamma_stereo):
def test_load_train_data_files(self, p_stereo, gamma_stereo, config_gen):
"""
Check dictionary of the combo types
"""

for stereo in [p_stereo, gamma_stereo]:
events = load_train_data_files(str(stereo[0]))
events = load_train_data_files(str(stereo[0]), config_gen)
assert list(events.keys()) == ["LST1_M1", "LST1_M1_M2", "LST1_M2", "M1_M2"]
data = events["LST1_M1"]
assert np.all(data["combo_type"] == 0)
assert "off_axis" in data.columns
assert "true_event_class" not in data.columns

def test_load_train_data_files_off(self, gamma_stereo):
def test_load_train_data_files_off(self, gamma_stereo, config_gen):
"""
Check off-axis cut
"""
events = load_train_data_files(
str(gamma_stereo[0]), offaxis_min="0.2 deg", offaxis_max="0.5 deg"
str(gamma_stereo[0]),
config_gen,
offaxis_min="0.2 deg",
offaxis_max="0.5 deg",
)
data = events["LST1_M1"]
assert np.all(data["off_axis"] >= 0.2)
assert np.all(data["off_axis"] <= 0.5)
assert len(data) > 0

def test_load_train_data_files_exc(self, temp_train_exc):
def test_load_train_data_files_exc(self, temp_train_exc, config_gen):
"""
Check on exceptions
"""
Expand All @@ -347,7 +350,7 @@ def test_load_train_data_files_exc(self, temp_train_exc):
FileNotFoundError,
match="Could not find any DL1-stereo data files in the input directory.",
):
_ = load_train_data_files(str(temp_train_exc))
_ = load_train_data_files(str(temp_train_exc), config_gen)

def test_load_train_data_files_tel(self, p_stereo, gamma_stereo, config_gen):
"""
Expand Down Expand Up @@ -408,14 +411,14 @@ def test_exist_dl2_mc(p_dl2, gamma_dl2):

@pytest.mark.dependency(depends=["test_exist_dl2_mc"])
class TestDL2MC:
def test_load_mc_dl2_data_file(self, p_dl2, gamma_dl2):
def test_load_mc_dl2_data_file(self, p_dl2, gamma_dl2, config_gen):
"""
Checks on default loading
"""
dl2_mc = [p for p in gamma_dl2.glob("*")] + [p for p in p_dl2.glob("*")]
for file in dl2_mc:
data, point, _ = load_mc_dl2_data_file(
str(file), "width>0", "software", "simple"
str(file), config_gen, "width>0", "software", "simple"
)
assert "pointing_alt" in data.colnames
assert "theta" in data.colnames
Expand All @@ -424,31 +427,31 @@ def test_load_mc_dl2_data_file(self, p_dl2, gamma_dl2):
assert point[0] >= 0
assert point[0] <= 90

def test_load_mc_dl2_data_file_cut(self, p_dl2, gamma_dl2):
def test_load_mc_dl2_data_file_cut(self, config_gen, p_dl2, gamma_dl2):
"""
Check on quality cuts
"""
dl2_mc = [p for p in gamma_dl2.glob("*")] + [p for p in p_dl2.glob("*")]
for file in dl2_mc:
data, _, _ = load_mc_dl2_data_file(
str(file), "gammaness>0.1", "software", "simple"
str(file), config_gen, "gammaness>0.1", "software", "simple"
)
assert np.all(data["gammaness"] > 0.1)
assert len(data) > 0

def test_load_mc_dl2_data_file_opt(self, p_dl2, gamma_dl2):
def test_load_mc_dl2_data_file_opt(self, p_dl2, gamma_dl2, config_gen):
"""
Check on event_type
"""
dl2_mc = [p for p in gamma_dl2.glob("*")] + [p for p in p_dl2.glob("*")]
for file in dl2_mc:
data_s, _, _ = load_mc_dl2_data_file(
str(file), "width>0", "software", "simple"
str(file), config_gen, "width>0", "software", "simple"
)
assert np.all(data_s["combo_type"] < 3)
assert len(data_s) > 0

def test_load_mc_dl2_data_file_exc(self, p_dl2, gamma_dl2):
def test_load_mc_dl2_data_file_exc(self, p_dl2, gamma_dl2, config_gen):
"""
Check on event_type exceptions
"""
Expand All @@ -460,7 +463,7 @@ def test_load_mc_dl2_data_file_exc(self, p_dl2, gamma_dl2):
match=f"Unknown event type '{event_type}'.",
):
_, _, _ = load_mc_dl2_data_file(
str(file), "width>0", event_type, "simple"
str(file), config_gen, "width>0", event_type, "simple"
)

def test_get_dl2_mean_mc(self, p_dl2, gamma_dl2):
Expand Down Expand Up @@ -733,13 +736,13 @@ def test_exist_dl2(real_dl2):

@pytest.mark.dependency(depends=["test_exist_dl2"])
class TestDL2Data:
def test_load_dl2_data_file(self, real_dl2):
def test_load_dl2_data_file(self, real_dl2, config_gen):
"""
Checks on default loading
"""
for file in real_dl2.glob("*"):
data, on, dead = load_dl2_data_file(
str(file), "width>0", "software", "simple"
str(file), config_gen, "width>0", "software", "simple"
)
assert "pointing_alt" in data.colnames
assert "timestamp" in data.colnames
Expand All @@ -748,29 +751,29 @@ def test_load_dl2_data_file(self, real_dl2):
assert on > 0
assert dead > 0

def test_load_dl2_data_file_cut(self, real_dl2):
def test_load_dl2_data_file_cut(self, real_dl2, config_gen):
"""
Check on quality cuts
"""
for file in real_dl2.glob("*"):
data, _, _ = load_dl2_data_file(
str(file), "gammaness<0.9", "software", "simple"
str(file), config_gen, "gammaness<0.9", "software", "simple"
)
assert np.all(data["gammaness"] < 0.9)
assert len(data) > 0

def test_load_dl2_data_file_opt(self, real_dl2):
def test_load_dl2_data_file_opt(self, real_dl2, config_gen):
"""
Check on event_type
"""
for file in real_dl2.glob("*"):
data_s, _, _ = load_dl2_data_file(
str(file), "width>0", "software", "simple"
str(file), config_gen, "width>0", "software", "simple"
)
assert np.all(data_s["combo_type"] < 3)
assert len(data_s) > 0

def test_load_dl2_data_file_exc(self, real_dl2):
def test_load_dl2_data_file_exc(self, real_dl2, config_gen):
"""
Check on event_type exceptions
"""
Expand All @@ -780,7 +783,9 @@ def test_load_dl2_data_file_exc(self, real_dl2):
ValueError,
match=f"Unknown event type '{event_type}'.",
):
_, _, _ = load_dl2_data_file(str(file), "width>0", event_type, "simple")
_, _, _ = load_dl2_data_file(
str(file), config_gen, "width>0", event_type, "simple"
)

def test_get_dl2_mean_real(self, real_dl2):
"""
Expand Down
Loading
Loading