Skip to content

Commit fcc320e

Browse files
Merge pull request #25 from mathysgrapotte/add-protein-type
Add protein type
2 parents 87b77cd + e20d70a commit fcc320e

File tree

6 files changed

+124
-28
lines changed

6 files changed

+124
-28
lines changed

bin/src/data/csv_parser.py

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -84,14 +84,20 @@ def get_and_encode(self, dictionary: dict, idx: Any) -> dict:
8484
name = key.split(":")[0]
8585
data_type = key.split(":")[1]
8686

87+
# get the data at the given index
88+
# if the data is not a list, it is converted to a list
89+
# otherwise it breaks Float().encode_all(data) because it expects a list
90+
data = dictionary[key][idx]
91+
if not isinstance(data, list):
92+
data = [data]
93+
8794
# check if 'data_type' is in the experiment class attributes
8895
if not hasattr(self.experiment, data_type.lower()):
8996
raise ValueError(f"The data type {data_type} is not in the experiment class attributes. the column name is {key}, the available attributes are {self.experiment.__dict__}")
9097

9198
# encode the data at given index
9299
# For that, it first retrieves the data object and then calls the encode_all method to encode the data
93-
# BUG when there is only one element in the list, then we don't get one list anymore, but only the element. And this creates error at Float.encode_all() since here [np.array(float(d)) for d in data] data is only a string and not a list of strings.
94-
output[name] = self.experiment.__getattribute__(data_type.lower()).encode_all(dictionary[key][idx])
100+
output[name] = self.experiment.__getattribute__(data_type.lower()).encode_all(data)
95101

96102
return output
97103

bin/src/data/data_types/data_types.py

Lines changed: 52 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -55,36 +55,83 @@ def encode(self, data: str, encoder: Literal['one_hot'] = 'one_hot') -> Any: #TO
5555
else:
5656
raise ValueError(f"Unknown encoder {encoder}")
5757

58-
5958
def encode_all(self, data: list, encoder: Literal['one_hot'] = 'one_hot') -> list[np.array]:
6059
if encoder == 'one_hot':
6160
return self.one_hot_encode_all(data)
6261
else:
6362
raise ValueError(f"Unknown encoder {encoder}")
6463

64+
def add_noise_uniform_text_masker(self, data: str, seed: float = None, **noise_params) -> str:
65+
"""
66+
Adds noise to the data of a single input.
67+
"""
68+
# get the probability param from noise_params, default value is set to 0.1
69+
probability = noise_params.get("probability", 0.1)
70+
return self.uniform_text_masker.add_noise(data, probability=probability, mask='N', seed=seed)
71+
72+
def add_noise_uniform_text_masker_all_inputs(self, data: list, seed: float = None, **noise_params) -> list:
73+
"""
74+
Adds noise to the data of multiple inputs.
75+
"""
76+
# get the probability param from noise_params, default value is set to 0.1
77+
probability = noise_params.get("probability", 0.1)
78+
return self.uniform_text_masker.add_noise_multiprocess(data, probability=probability, mask='N', seed=seed)
6579

80+
81+
class Prot(AbstractType):
82+
"""
83+
class for dealing with protein data
84+
"""
85+
86+
def __init__(self, **parameters) -> None:
87+
self.one_hot_encoder = encoders.TextOneHotEncoder(alphabet=parameters.get("one_hot_encoder_alphabet", "acdefghiklmnpqrstvwy"))
88+
self.uniform_text_masker = noise_generators.UniformTextMasker()
89+
90+
def one_hot_encode(self, data: str) -> np.array:
91+
"""
92+
Encodes the data of a single input.
93+
"""
94+
return self.one_hot_encoder.encode(data)
95+
96+
def one_hot_encode_all(self, data: list) -> list:
97+
"""
98+
Encodes the data of multiple inputs.
99+
"""
100+
return self.one_hot_encoder.encode_all(data)
101+
102+
def encode(self, data: str, encoder: Literal['one_hot'] = 'one_hot') -> Any: #TODO call from get attribute instead of using if else
103+
if encoder == 'one_hot':
104+
return self.one_hot_encode(data)
105+
else:
106+
raise ValueError(f"Unknown encoder {encoder}")
107+
108+
def encode_all(self, data: list, encoder: Literal['one_hot'] = 'one_hot') -> list[np.array]:
109+
if encoder == 'one_hot':
110+
return self.one_hot_encode_all(data)
111+
else:
112+
raise ValueError(f"Unknown encoder {encoder}")
113+
66114
def add_noise_uniform_text_masker(self, data: str, seed: float = None, **noise_params) -> str:
67115
"""
68116
Adds noise to the data of a single input.
69117
"""
70118
# get the probability param from noise_params, default value is set to 0.1
71119
probability = noise_params.get("probability", 0.1)
72-
return self.uniform_text_masker.add_noise(data, probability=probability, seed=seed)
120+
return self.uniform_text_masker.add_noise(data, probability=probability, mask='X', seed=seed)
73121

74122
def add_noise_uniform_text_masker_all_inputs(self, data: list, seed: float = None, **noise_params) -> list:
75123
"""
76124
Adds noise to the data of multiple inputs.
77125
"""
78126
# get the probability param from noise_params, default value is set to 0.1
79127
probability = noise_params.get("probability", 0.1)
80-
return self.uniform_text_masker.add_noise_multiprocess(data, probability=probability, seed=seed)
128+
return self.uniform_text_masker.add_noise_multiprocess(data, probability=probability, mask='X', seed=seed)
81129

82130

83131
class Float():
84132
"""
85133
class for dealing with float data
86134
"""
87-
88135
def __init__(self) -> None:
89136
self.gaussian_noise = noise_generators.GaussianNoise()
90137

@@ -105,4 +152,4 @@ def encode(self, data: Any) -> float:
105152

106153
def encode_all(self, data: list) -> list[np.array]:
107154
return [np.array(float(d)) for d in data]
108-
155+

bin/src/data/data_types/encoding/encoders.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -48,6 +48,9 @@ def encode_multiprocess(self, data: list) -> list:
4848
class TextOneHotEncoder(AbstractEncoder):
4949
"""
5050
One hot encoder for text data.
51+
52+
NOTE that it will onehot encode based on the alphabet.
53+
If there is any character not included in the alphabet, that character will be represented by a vector of zeros.
5154
"""
5255

5356
def __init__(self, alphabet: str = "acgt") -> None:
@@ -57,6 +60,7 @@ def __init__(self, alphabet: str = "acgt") -> None:
5760
def _sequence_to_array(self, sequence: str) -> np.array:
5861
"""
5962
This function transforms the given sequence to an array.
63+
eg. 'abcd' -> array(['a'],['b'],['c'],['d'])
6064
"""
6165
sequence_lower_case = sequence.lower()
6266
sequence_array = np.array(list(sequence_lower_case))
@@ -71,6 +75,7 @@ def encode(self, data: str) -> np.array:
7175
def encode_all(self, data: Union[list, str]) -> np.array:
7276
"""
7377
Encodes the data, if the list is length one, call encode instead.
78+
It returns a list with all the encoded data entries.
7479
"""
7580
# check if the data is a str, in that case it should use the encode sequence method
7681
if isinstance(data, str):

bin/src/data/data_types/noise/noise_generators.py

Lines changed: 5 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -37,38 +37,33 @@ def add_noise_multiprocess(self, data: list, seed: float = None) -> list:
3737

3838
class UniformTextMasker(AbstractNoiseGenerator):
3939
"""
40-
This noise generators replace characters with 'N' with a given probability.
40+
This noise generator replaces characters with a masking character with a given probability.
4141
"""
4242

43-
44-
def add_noise(self, data: str, probability: float = 0.1, seed: float = None) -> str:
43+
def add_noise(self, data: str, probability: float = 0.1, mask='N', seed: float = None) -> str:
4544
"""
4645
Adds noise to the data.
4746
"""
48-
4947
np.random.seed(seed)
50-
return ''.join([c if np.random.rand() > probability else 'N' for c in data])
48+
return ''.join([c if np.random.rand() > probability else mask for c in data])
5149

52-
def add_noise_multiprocess(self, data: list, probability: float = 0.1, seed: float = None) -> list:
50+
def add_noise_multiprocess(self, data: list, probability: float = 0.1, mask='N', seed: float = None) -> list:
5351
"""
5452
Adds noise to the data using multiprocessing.
5553
"""
56-
5754
with mp.Pool(mp.cpu_count()) as pool:
58-
function_specific_input = [(item, probability, seed) for item in data]
55+
function_specific_input = [(item, probability, mask, seed) for item in data]
5956
return pool.starmap(self.add_noise, function_specific_input)
6057

6158
class GaussianNoise(AbstractNoiseGenerator):
6259
"""
6360
This noise generator adds gaussian noise to float values
6461
"""
6562

66-
6763
def add_noise(self, data: float, mean: float = 0, std: float= 0, seed: float = None) -> float:
6864
"""
6965
Adds noise to a single point of data.
7066
"""
71-
7267
np.random.seed(seed)
7368
return data + np.random.normal(mean, std)
7469

@@ -77,6 +72,5 @@ def add_noise_multiprocess(self, data: list, mean: float = 0, std: float = 0, se
7772
Adds noise to the data using np arrays
7873
# TODO return a np array to gain performance.
7974
"""
80-
8175
np.random.seed(seed)
8276
return list(np.array(data) + np.random.normal(mean, std, len(data)))

bin/tests/test_data/test.csv

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,3 @@
11
hello:input:dna,hola:label:float
2-
ACTGACTGATCGATGC,5
3-
ACTGACTGATCGATGC,5
2+
ACTGACTGATCGATGC,12
3+
ACTGACTGATCGATGC,12

bin/tests/test_data_types.py

Lines changed: 52 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1,25 +1,69 @@
11
import numpy as np
22
import numpy.testing as npt
33
import unittest
4-
from bin.src.data.data_types.data_types import Dna
4+
from bin.src.data.data_types.data_types import Dna, Prot
55

66
class TestDna(unittest.TestCase):
77

88
def setUp(self):
99
self.dna = Dna()
1010

11-
# test if the encode_all method runs with default arguments
1211
def test_encode_all(self):
13-
# Test encoding a valid list of sequences
12+
"""
13+
Test if the encode_all method runs with default arguments
14+
"""
15+
# encode a list of sequences
1416
encoded_data_list = self.dna.encode_all(["ACGT", "AAA", "tt", "Bubba"])
17+
18+
# check that the encoding returns a list
1519
self.assertIsInstance(encoded_data_list, list)
16-
# check if the length of the list is 4
17-
self.assertEqual(len(encoded_data_list), 4)
20+
21+
# check if the arrays have the correct shape
22+
self.assertEqual(encoded_data_list[0].shape, (4, 4))
23+
24+
# check we get the correct encoded arrays - first sequence
1825
correct_output = np.array([[1, 0, 0, 0], [0, 1, 0, 0], [0, 0, 1, 0], [0, 0, 0, 1]])
1926
npt.assert_array_equal(encoded_data_list[0], correct_output)
2027

21-
# test if the encode_all method returns an error when the specified encoder is not within the list of possible encoders
2228
def test_encode_all_error(self):
23-
# Test encoding a valid list of sequences
29+
"""
30+
Test if the encode_all method returns an error when the specified encoder is not within the list of possible encoders
31+
"""
32+
with self.assertRaises(ValueError):
33+
self.dna.encode_all(["ACGT", "AAA", "tt", "Bubba"], encoder="not_a_valid_encoder")
34+
35+
36+
class TestProt(unittest.TestCase):
37+
38+
def setUp(self):
39+
self.prot = Prot()
40+
41+
def test_encode_all(self):
42+
"""
43+
Test if the encode_all method runs with default arguments
44+
acdefghiklmnpqrstvwy
45+
"""
46+
# encode a list of sequences
47+
encoded_data_list = self.prot.encode_all(["ACDE", "FFF", "gg", "uuu"])
48+
49+
# check that the encoding returns a list
50+
self.assertIsInstance(encoded_data_list, list)
51+
52+
# check if the arrays have the correct shape
53+
self.assertEqual(encoded_data_list[0].shape, (4, 20))
54+
55+
# check we get the correct encoded array - first sequence
56+
correct_output = np.array([
57+
[1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
58+
[0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
59+
[0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
60+
[0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]
61+
])
62+
npt.assert_array_equal(encoded_data_list[0], correct_output)
63+
64+
def test_encode_all_error(self):
65+
"""
66+
Test if the encode_all method returns an error when the specified encoder is not within the list of possible encoders
67+
"""
2468
with self.assertRaises(ValueError):
25-
self.dna.encode_all(["ACGT", "AAA", "tt", "Bubba"], encoder="not_a_valid_encoder")
69+
self.prot.encode_all(["ACDE", "FFF", "gg", "uuu"], encoder="not_a_valid_encoder")

0 commit comments

Comments
 (0)