From 2d0730064f906a1783452a8b85cde374229ba817 Mon Sep 17 00:00:00 2001 From: mgrapotte Date: Wed, 13 Mar 2024 10:05:24 +0100 Subject: [PATCH] adding tests for the loading with split method in csv_loader --- bin/src/data/csv.py | 8 ++++---- bin/tests/test_csv.py | 13 +++++++++++++ 2 files changed, 17 insertions(+), 4 deletions(-) diff --git a/bin/src/data/csv.py b/bin/src/data/csv.py index edf63af..6469199 100644 --- a/bin/src/data/csv.py +++ b/bin/src/data/csv.py @@ -35,7 +35,7 @@ class CsvLoader(CsvHandler): # change to CsvHandler def __init__(self, experiment: Any, csv_path: str, split: Union[int, None] = None) -> None: super().__init__(experiment, csv_path) - if split: + if split is not None: # if split is present, we defined the prefered load method to be the load_csv_per_split method with default argument split prefered_load_method = partial(self.load_csv_per_split, split=split) else: @@ -55,10 +55,10 @@ def load_csv_per_split(self, csv_path: str, split: int) -> pl.DataFrame: """ data = pl.read_csv(csv_path) # check that the selected split value is present in the column split:meta:int - if split not in data.column("split:meta:int").unique().to_list(): - raise ValueError(f"The split value {split} is not present in the column split:meta:int. The available values are {data.column('split:meta:int').unique().to_list()}") + if split not in data["split:meta:int"].unique().to_list(): + raise ValueError(f"The split value {split} is not present in the column split:meta:int. The available values are {data['split:meta:int'].unique().to_list()}") - return data.filter(data.column("split:meta:int") == split) + return data.filter(data["split:meta:int"] == split) def get_and_encode(self, dictionary: dict, idx: Any) -> dict: """ diff --git a/bin/tests/test_csv.py b/bin/tests/test_csv.py index cd3edba..011ecf1 100644 --- a/bin/tests/test_csv.py +++ b/bin/tests/test_csv.py @@ -7,6 +7,7 @@ class TestDnaToFloatCsvLoader(unittest.TestCase): def setUp(self): self.csv_loader = CsvLoader(DnaToFloatExperiment(), os.path.abspath("bin/tests/test_data/test.csv")) + self.csv_loader_split = CsvLoader(DnaToFloatExperiment(), os.path.abspath("bin/tests/test_data/test_with_split.csv"), split=0) def test_get_encoded_item_unique(self): """ @@ -69,4 +70,16 @@ def test_get_encoded_item_multiple(self): def test_len(self): self.assertEqual(len(self.csv_loader), 2) + + def test_load_with_split(self): + # try loading with different split values, should run with 0,1 and 2 and raise an error for other values + self.csv_loader_split = CsvLoader(DnaToFloatExperiment(), os.path.abspath("bin/tests/test_data/test_with_split.csv"), split=0) + # self.csv_loader_split.input['hello'] should have only one value + self.assertEqual(len(self.csv_loader_split.input['hello:dna']), 1) + self.csv_loader_split = CsvLoader(DnaToFloatExperiment(), os.path.abspath("bin/tests/test_data/test_with_split.csv"), split=1) + self.csv_loader_split = CsvLoader(DnaToFloatExperiment(), os.path.abspath("bin/tests/test_data/test_with_split.csv"), split=2) + with self.assertRaises(ValueError): # should raise an error + self.csv_loader_split = CsvLoader(DnaToFloatExperiment(), os.path.abspath("bin/tests/test_data/test_with_split.csv"), split=3) + + \ No newline at end of file