Skip to content

Commit

Permalink
adding tests for the loading with split method in csv_loader
Browse files Browse the repository at this point in the history
  • Loading branch information
mathysgrapotte committed Mar 13, 2024
1 parent bdca5f4 commit 2d07300
Show file tree
Hide file tree
Showing 2 changed files with 17 additions and 4 deletions.
8 changes: 4 additions & 4 deletions bin/src/data/csv.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@ class CsvLoader(CsvHandler): # change to CsvHandler

def __init__(self, experiment: Any, csv_path: str, split: Union[int, None] = None) -> None:
super().__init__(experiment, csv_path)
if split:
if split is not None:
# if split is present, we define the preferred load method to be the load_csv_per_split method with default argument split
prefered_load_method = partial(self.load_csv_per_split, split=split)
else:
Expand All @@ -55,10 +55,10 @@ def load_csv_per_split(self, csv_path: str, split: int) -> pl.DataFrame:
"""
data = pl.read_csv(csv_path)
# check that the selected split value is present in the column split:meta:int
if split not in data.column("split:meta:int").unique().to_list():
raise ValueError(f"The split value {split} is not present in the column split:meta:int. The available values are {data.column('split:meta:int').unique().to_list()}")
if split not in data["split:meta:int"].unique().to_list():
raise ValueError(f"The split value {split} is not present in the column split:meta:int. The available values are {data['split:meta:int'].unique().to_list()}")

return data.filter(data.column("split:meta:int") == split)
return data.filter(data["split:meta:int"] == split)

def get_and_encode(self, dictionary: dict, idx: Any) -> dict:
"""
Expand Down
13 changes: 13 additions & 0 deletions bin/tests/test_csv.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ class TestDnaToFloatCsvLoader(unittest.TestCase):

def setUp(self):
    """Create the loaders shared by the tests in this class.

    ``csv_loader`` reads ``test.csv`` with no split argument, while
    ``csv_loader_split`` reads ``test_with_split.csv`` with ``split=0``.
    """
    plain_csv = os.path.abspath("bin/tests/test_data/test.csv")
    split_csv = os.path.abspath("bin/tests/test_data/test_with_split.csv")

    self.csv_loader = CsvLoader(DnaToFloatExperiment(), plain_csv)
    self.csv_loader_split = CsvLoader(DnaToFloatExperiment(), split_csv, split=0)

def test_get_encoded_item_unique(self):
"""
Expand Down Expand Up @@ -69,4 +70,16 @@ def test_get_encoded_item_multiple(self):

def test_len(self):
    """The loader created without a split must report a length of 2."""
    loader_size = len(self.csv_loader)
    self.assertEqual(loader_size, 2)

def test_load_with_split(self):
    """Loading with split 0, 1 or 2 must succeed; split 3 must raise ValueError.

    For split 0 the loader's ``input['hello:dna']`` is additionally checked
    to contain exactly one value.
    """
    split_csv = os.path.abspath("bin/tests/test_data/test_with_split.csv")

    # split 0: loads and exposes a single 'hello:dna' entry
    self.csv_loader_split = CsvLoader(DnaToFloatExperiment(), split_csv, split=0)
    hello_values = self.csv_loader_split.input['hello:dna']
    self.assertEqual(len(hello_values), 1)

    # the remaining valid split values only need to load without error
    for valid_split in (1, 2):
        self.csv_loader_split = CsvLoader(DnaToFloatExperiment(), split_csv, split=valid_split)

    # any other split value is rejected by the loader
    with self.assertRaises(ValueError):
        self.csv_loader_split = CsvLoader(DnaToFloatExperiment(), split_csv, split=3)



0 comments on commit 2d07300

Please sign in to comment.