Skip to content

Commit c182af8

Browse files
authored
Merge pull request #30 from mathysgrapotte/csvhandlerpolar
Csvhandlerpolar
2 parents 18ea499 + 1031311 commit c182af8

File tree

7 files changed

+121
-75
lines changed

7 files changed

+121
-75
lines changed

bin/requirements.txt

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,4 +3,4 @@
33
numpy==1.26.0
44
pytorch-lightning==2.0.1
55
scikit-learn==1.3.0
6-
pandas==2.0.3
6+
polars==0.20.15

bin/src/data/csv_parser.py renamed to bin/src/data/csv.py

Lines changed: 82 additions & 52 deletions
Original file line numberDiff line numberDiff line change
@@ -11,11 +11,20 @@
1111
The parser is a class that takes as input a CSV file and an experiment class that defines data types to be used, noising procedures, splitting etc.
1212
"""
1313

14-
import pandas as pd
15-
from typing import Any, Tuple
14+
import polars as pl
15+
from typing import Any, Tuple, Union
16+
from functools import partial
1617

18+
class CsvHandler:
19+
"""
20+
Class for handling CSV files. #TODO add extensive description
21+
"""
1722

18-
class CSVParser: # change to CsvHandler
23+
def __init__(self, experiment: Any, csv_path: str) -> None:
24+
self.experiment = experiment
25+
self.csv_path = csv_path
26+
27+
class CsvLoader(CsvHandler): # change to CsvHandler
1928
"""
2029
Class for parsing CSV files.
2130
@@ -24,46 +33,32 @@ class CSVParser: # change to CsvHandler
2433
Then, one can get one or many items from the data, encoded.
2534
"""
2635

27-
def __init__(self, experiment: Any, csv_path: str) -> None:
28-
self.experiment = experiment
29-
self.csv_path = csv_path
30-
self.input, self.label, self.meta = self.parse_csv_to_input_label_meta(self.csv_path)
31-
self.padding_value = self.find_padding_value(self.input)
32-
33-
def parse_csv_to_input_label_meta(self, csv_path: str) -> Tuple[dict, dict, dict]:
36+
def __init__(self, experiment: Any, csv_path: str, split: Union[int, None] = None) -> None:
37+
super().__init__(experiment, csv_path)
38+
if split is not None:
39+
# if split is given, the preferred load method is load_csv_per_split with its split argument bound to that value
40+
prefered_load_method = partial(self.load_csv_per_split, split=split)
41+
else:
42+
prefered_load_method = self.load_all_csv
43+
self.input, self.label, self.meta = self.parse_csv_to_input_label_meta(self.csv_path, prefered_load_method)
44+
45+
def load_all_csv(self, csv_path: str) -> pl.DataFrame:
3446
"""
35-
This function reads the csv file into a dictionary,
36-
and then parses each key with the form name:category:type
37-
into three dictionaries, one for each category [input, label, meta].
38-
The keys of each new dictionary are in this form name:type.
47+
Loads the csv file into a polars dataframe.
3948
"""
40-
# read csv file into a dictionary of lists
41-
# the keys of the dictionary are the column names and the values are the column values
42-
data = pd.read_csv(csv_path, dtype=str).to_dict(orient="list")
43-
44-
# parse the dictionary into three dictionaries, one for each category [input, label, meta]
45-
input_data, label_data, meta_data = {}, {}, {}
46-
for key in data:
47-
name, category, data_type = key.split(":")
48-
if category.lower() == "input":
49-
input_data[f"{name}:{data_type}"] = data[key]
50-
elif category.lower() == "label":
51-
label_data[f"{name}:{data_type}"] = data[key]
52-
elif category.lower() == "meta":
53-
meta_data[f"{name}:{data_type}"] = data[key]
54-
else:
55-
raise ValueError(f"Unknown category {category}, category (the second element of the csv column, seperated by ':') should be input, label or meta. The specified csv column is {key}.")
56-
return input_data, label_data, meta_data
57-
58-
def find_padding_value(self, data: dict) -> int:
49+
return pl.read_csv(csv_path)
50+
51+
def load_csv_per_split(self, csv_path: str, split: int) -> pl.DataFrame:
5952
"""
60-
Find an integer that is not present in any of the lists of the data dictionary
53+
Split is the index of the split to load: 0 is train, 1 is validation, 2 is test.
54+
This is accessed through the column named "split:meta:int"
6155
"""
62-
i = 0
63-
while True:
64-
if i not in [item for sublist in data.values() for item in sublist]:
65-
return i
66-
i += 1
56+
data = pl.read_csv(csv_path)
57+
# check that the selected split value is present in the column split:meta:int
58+
if split not in data["split:meta:int"].unique().to_list():
59+
raise ValueError(f"The split value {split} is not present in the column split:meta:int. The available values are {data['split:meta:int'].unique().to_list()}")
60+
61+
return data.filter(data["split:meta:int"] == split)
6762

6863
def get_and_encode(self, dictionary: dict, idx: Any) -> dict:
6964
"""
@@ -79,7 +74,7 @@ def get_and_encode(self, dictionary: dict, idx: Any) -> dict:
7974
"""
8075
output = {}
8176
for key in dictionary: # processing each column
82-
77+
8378
# get the name and data_type
8479
name = key.split(":")[0]
8580
data_type = key.split(":")[1]
@@ -97,31 +92,66 @@ def get_and_encode(self, dictionary: dict, idx: Any) -> dict:
9792

9893
# encode the data at given index
9994
# For that, it first retrieves the data object and then calls the encode_all method to encode the data
100-
101-
10295
output[name] = self.experiment.get_encoding_all(data_type)(dictionary[key][idx])
10396

104-
10597
return output
10698

107-
def get_encoded_item(self, idx: Any) -> Tuple[dict, dict, dict]:
99+
def __len__(self) -> int:
100+
"""
101+
returns the length of the first list in input, assumes that all are the same length
102+
"""
103+
return len(list(self.input.values())[0])
104+
105+
def parse_csv_to_input_label_meta(self, csv_path: str, load_method: Any) -> Tuple[dict, dict, dict]:
106+
"""
107+
This function reads the csv file into a dictionary,
108+
and then parses each key with the form name:category:type
109+
into three dictionaries, one for each category [input, label, meta].
110+
The keys of each new dictionary are in this form name:type.
111+
"""
112+
# read csv file into a dictionary of lists
113+
# the keys of the dictionary are the column names and the values are the column values
114+
data = load_method(csv_path).to_dict(as_series=False)
115+
116+
# parse the dictionary into three dictionaries, one for each category [input, label, meta]
117+
input_data, label_data, meta_data = {}, {}, {}
118+
for key in data:
119+
name, category, data_type = key.split(":")
120+
if category.lower() == "input":
121+
input_data[f"{name}:{data_type}"] = data[key]
122+
elif category.lower() == "label":
123+
label_data[f"{name}:{data_type}"] = data[key]
124+
elif category.lower() == "meta":
125+
meta_data[f"{name}:{data_type}"] = data[key]
126+
else:
127+
raise ValueError(f"Unknown category {category}, category (the second element of the csv column, seperated by ':') should be input, label or meta. The specified csv column is {key}.")
128+
return input_data, label_data, meta_data
129+
130+
def __getitem__(self, idx: Any) -> dict:
108131
"""
109132
It gets the data at a given index, and encodes the input and label, leaving meta as it is.
110133
"""
111134
x = self.get_and_encode(self.input, idx)
112135
y = self.get_and_encode(self.label, idx)
113136
return x, y, self.meta
114137

115-
def __len__(self) -> int:
138+
class CsvParser(CsvHandler):
139+
"""
140+
Class for loading
141+
"""
142+
143+
def __init__(self, experiment: Any, csv_path: str) -> None:
144+
super().__init__(experiment, csv_path)
145+
146+
def save(self, path: str) -> None:
116147
"""
117-
returns the length of the first list in input, assumes that all are the same length
148+
Saves the data to a csv file.
118149
"""
119-
return len(list(self.input.values())[0])
120-
121-
def __getitem__(self, idx: Any) -> dict:
150+
pass
151+
152+
def noise(self, data):
122153
"""
123-
get a dictionary with all the keys for the data at a given index
154+
Adds noise to the data.
124155
"""
125-
data = {**self.input, **self.label, **self.meta}
126-
return { key: data[key][idx] for key in data }
156+
pass
127157

bin/src/data/encoding/encoders.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -77,8 +77,8 @@ def encode_all(self, data: Union[list, str]) -> np.array:
7777
Encodes the data, if the list is length one, call encode instead.
7878
It returns a list with all the encoded data entries.
7979
"""
80-
# check if the data is a str, in that case it should use the encode sequence method
81-
if isinstance(data, str):
80+
# check if the data is not a list, in this case it should use the encode method
81+
if not isinstance(data, list):
8282
return [self.encode(data)]
8383
else:
8484
return self.encode_multiprocess(data)
@@ -106,8 +106,8 @@ def encode_all(self, data: list) -> list:
106106
This method takes as input a list of data points, should be mappable to a single output.
107107
"""
108108

109-
# check if data is a string, in that case it should use the encode sequence method
110-
if isinstance(data, str):
109+
# check if data is not a list, in that case it should use the encode method
110+
if not isinstance(data, list):
111111
return [self.encode(data)]
112112
else:
113113
return [float(d) for d in data]

bin/src/data/handlertorch.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@
66
import numpy as np
77
from torch.utils.data import Dataset, DataLoader
88
from torch.nn.utils.rnn import pad_sequence
9-
from .csv_parser import CSVParser
9+
from .csv import CsvLoader
1010
from typing import Any, Tuple
1111

1212
class TorchDataset(Dataset):
@@ -15,7 +15,7 @@ class TorchDataset(Dataset):
1515
"""
1616
def __init__(self, csvpath : str, experiment : Any) -> None:
1717
self.csvpath = csvpath
18-
self.parser = CSVParser(experiment, csvpath)
18+
self.parser = CsvLoader(experiment, csvpath)
1919

2020
def convert_list_of_numpy_arrays_to_tensor(self, data: list) -> Tuple[torch.Tensor, torch.Tensor]:
2121
"""
@@ -38,13 +38,13 @@ def convert_list_of_numpy_arrays_to_tensor(self, data: list) -> Tuple[torch.Tens
3838
data = [torch.from_numpy(d) for d in data] # convert the np arrays to tensors
3939

4040
# pad sequences
41-
padded_data = pad_sequence(data, batch_first=True, padding_value=self.parser.padding_value)
41+
padded_data = pad_sequence(data, batch_first=True, padding_value=42)
4242

4343
# create a mask of the same shape as the padded data
4444
mask = torch.zeros_like(padded_data)
4545

4646
# mask should have ones everywhere the data is not padded (so values are not 42)
47-
mask[padded_data != self.parser.padding_value] = 1
47+
mask[padded_data != 42] = 1
4848

4949
return padded_data, mask
5050

@@ -65,7 +65,7 @@ def __len__(self) -> int:
6565
return len(self.parser)
6666

6767
def __getitem__(self, idx: int) -> Tuple[dict, dict, dict]:
68-
x, y, meta = self.parser.get_encoded_item(idx)
68+
x, y, meta = self.parser[idx]
6969
# convert the content in the x and y dictionaries to torch tensors
7070
x, x_mask = self.convert_dict_to_tensor(x)
7171
y, y_mask = self.convert_dict_to_tensor(y)

bin/tests/test_csv_parser.py renamed to bin/tests/test_csv.py

Lines changed: 24 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -1,22 +1,21 @@
1-
import numpy as np
2-
import numpy.testing as npt
31
import unittest
42
import os
5-
from bin.src.data.csv_parser import CSVParser
3+
from bin.src.data.csv import CsvLoader
64
from bin.src.data.experiments import DnaToFloatExperiment
75

8-
class TestDnaToFloatCsvParser(unittest.TestCase):
6+
class TestDnaToFloatCsvLoader(unittest.TestCase):
97

108
def setUp(self):
11-
self.csv_parser = CSVParser(DnaToFloatExperiment(), os.path.abspath("bin/tests/test_data/test.csv"))
9+
self.csv_loader = CsvLoader(DnaToFloatExperiment(), os.path.abspath("bin/tests/test_data/test.csv"))
10+
self.csv_loader_split = CsvLoader(DnaToFloatExperiment(), os.path.abspath("bin/tests/test_data/test_with_split.csv"), split=0)
1211

1312
def test_get_encoded_item_unique(self):
1413
"""
15-
It tests that the csv_parser.get_encoded_item works well when getting one item.
14+
It tests that csv_loader.__getitem__ works well when getting one item.
1615
The following test is performed on the item at idx=0.
1716
"""
1817
# get the encoded item from the csv file at idx 0
19-
encoded_item = self.csv_parser.get_encoded_item(0)
18+
encoded_item = self.csv_loader[0]
2019

2120
# test that the encoded item is a tuple of three dictionaries [input, label, meta]
2221
self.assertEqual(len(encoded_item), 3)
@@ -41,12 +40,12 @@ def test_get_encoded_item_unique(self):
4140

4241
def test_get_encoded_item_multiple(self):
4342
"""
44-
It tests that the csv_parser.get_encoded_item works well when getting multiple items using slice.
43+
It tests that csv_loader.__getitem__ works well when getting multiple items using a slice.
4544
The following test is performed on the item at idx=0 and idx=1.
4645
"""
4746

4847
# get the encoded items from the csv file at idx 0 and 1
49-
encoded_item = self.csv_parser.get_encoded_item(slice(0, 2))
48+
encoded_item = self.csv_loader[slice(0, 2)]
5049

5150
# test that the encoded item is a tuple of three dictionaries [input, label, meta]
5251
self.assertEqual(len(encoded_item), 3)
@@ -70,5 +69,20 @@ def test_get_encoded_item_multiple(self):
7069
self.assertEqual(len(encoded_item[1][key]), 2)
7170

7271
def test_len(self):
73-
self.assertEqual(len(self.csv_parser), 2)
72+
self.assertEqual(len(self.csv_loader), 2)
73+
74+
def test_load_with_split(self):
75+
# try loading with different split values, should run with 0,1 and 2 and raise an error for other values
76+
self.csv_loader_split = CsvLoader(DnaToFloatExperiment(), os.path.abspath("bin/tests/test_data/test_with_split.csv"), split=0)
77+
# self.csv_loader_split.input['hello:dna'] should have only one value
78+
self.assertEqual(len(self.csv_loader_split.input['hello:dna']), 1)
79+
# check that self.csv_loader_split.meta has only one value in the ['split:int'] column which is 0
80+
self.assertEqual(len(self.csv_loader_split.meta['split:int']), 1)
81+
self.assertEqual(self.csv_loader_split.meta['split:int'][0], 0)
82+
self.csv_loader_split = CsvLoader(DnaToFloatExperiment(), os.path.abspath("bin/tests/test_data/test_with_split.csv"), split=1)
83+
self.csv_loader_split = CsvLoader(DnaToFloatExperiment(), os.path.abspath("bin/tests/test_data/test_with_split.csv"), split=2)
84+
with self.assertRaises(ValueError): # should raise an error
85+
self.csv_loader_split = CsvLoader(DnaToFloatExperiment(), os.path.abspath("bin/tests/test_data/test_with_split.csv"), split=3)
86+
87+
7488

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,4 @@
1+
hello:input:dna,hola:label:float,split:meta:int
2+
ACTGACTGATCGATGC,12,0
3+
ACTGACTGATCGATGC,12,1
4+
ACTGACTGATCGATGC,12,2

bin/tests/test_handlertorch.py

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,9 @@
11
import numpy as np
2-
import numpy.testing as npt
32
import unittest
43
import os
54
import torch
65
from bin.src.data.handlertorch import TorchDataset
76
from bin.src.data.experiments import DnaToFloatExperiment
8-
from bin.src.data.csv_parser import CSVParser
97

108
# initialize unittest class
119
class TestDnaToFloatTorchDataset(unittest.TestCase):
@@ -32,7 +30,7 @@ def test_convert_dict_to_tensor_same_lengths(self):
3230
self.assertIsNone(mask_dict["hola"])
3331

3432

35-
input_data = self.torchdataset_same_length.parser.get_encoded_item(slice(0, 2))
33+
input_data = self.torchdataset_same_length.parser[slice(0, 2)]
3634
output_dict, mask_dict = self.torchdataset_same_length.convert_dict_to_tensor(input_data[0])
3735

3836
def test_get_item_same_lenghts(self):

0 commit comments

Comments
 (0)