Skip to content

Commit e3d5e23

Browse files
Merge pull request #283 from rsagroup/milestone-v01-low-hanging-fruit
Low hanging fruit for milestone 0.1
2 parents f85bfd4 + f9d9d4b commit e3d5e23

16 files changed

+932
-564
lines changed

.gitignore

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@ src/rsatoolbox/cengine/*.html
1010
cadena_ploscb_data.pkl
1111
fmri_data
1212
demos/allendata
13+
demos/temp_rdm.png
1314

1415
# stats files
1516
stats/**/*.npz

demos/temp_rdm.png

-4.87 MB
Binary file not shown.

docs/source/temp_rdm.png

1.39 MB
Loading

src/rsatoolbox/data/base.py

Lines changed: 201 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,201 @@
1+
"""Base class for Dataset
2+
"""
3+
from __future__ import annotations
4+
from rsatoolbox.util.descriptor_utils import check_descriptor_length_error
5+
from rsatoolbox.util.descriptor_utils import format_descriptor
6+
from rsatoolbox.util.descriptor_utils import parse_input_descriptor
7+
from rsatoolbox.util.file_io import write_dict_hdf5
8+
from rsatoolbox.util.file_io import write_dict_pkl
9+
from rsatoolbox.util.file_io import remove_file
10+
11+
12+
class DatasetBase:
    """
    Abstract dataset class.
    Defines members that every class needs to have, but does not
    implement any interesting behavior. Inherit from this class
    to define specific dataset types

    Args:
        measurements (numpy.ndarray): n_obs x n_channel 2d-array,
        descriptors (dict): descriptors (metadata)
        obs_descriptors (dict): observation descriptors (all
            are array-like with shape = (n_obs,...))
        channel_descriptors (dict): channel descriptors (all are
            array-like with shape = (n_channel,...))
        check_dims (bool): if True (default), obs_descriptors and
            channel_descriptors are passed to
            check_descriptor_length_error together with n_obs and
            n_channel respectively before being stored

    Raises:
        AttributeError: if measurements is not 2-dimensional

    Returns:
        dataset object
    """

    def __init__(self, measurements, descriptors=None,
                 obs_descriptors=None, channel_descriptors=None,
                 check_dims=True):
        if measurements.ndim != 2:
            # AttributeError (rather than ValueError) is kept because
            # existing callers may already catch this exception type.
            raise AttributeError(
                "measurements must be in dimension n_obs x n_channel")
        self.measurements = measurements
        self.n_obs, self.n_channel = self.measurements.shape
        if check_dims:
            # presumably raises when a descriptor's length does not match
            # the given dimension -- confirm in descriptor_utils
            check_descriptor_length_error(obs_descriptors,
                                          "obs_descriptors",
                                          self.n_obs
                                          )
            check_descriptor_length_error(channel_descriptors,
                                          "channel_descriptors",
                                          self.n_channel
                                          )
        self.descriptors = parse_input_descriptor(descriptors)
        self.obs_descriptors = parse_input_descriptor(obs_descriptors)
        self.channel_descriptors = parse_input_descriptor(channel_descriptors)

    def __repr__(self):
        """
        defines string which is printed for the object
        """
        return (f'rsatoolbox.data.{self.__class__.__name__}(\n'
                f'measurements = \n{self.measurements}\n'
                f'descriptors = \n{self.descriptors}\n'
                f'obs_descriptors = \n{self.obs_descriptors}\n'
                f'channel_descriptors = \n{self.channel_descriptors}\n'
                )

    def __str__(self):
        """
        defines the output of print

        Only the first five observations (rows) of the measurements
        are included in the output.
        """
        string_desc = format_descriptor(self.descriptors)
        string_obs_desc = format_descriptor(self.obs_descriptors)
        string_channel_desc = format_descriptor(self.channel_descriptors)
        # Slicing is a no-op when there are five or fewer rows, so no
        # explicit length check is needed.
        measurements = self.measurements[:5, :]
        return (f'rsatoolbox.data.{self.__class__.__name__}\n'
                f'measurements = \n{measurements}\n...\n\n'
                f'descriptors: \n{string_desc}\n\n'
                f'obs_descriptors: \n{string_obs_desc}\n\n'
                f'channel_descriptors: \n{string_channel_desc}\n'
                )

    def __eq__(self, other: "DatasetBase") -> bool:
        """Equality check, to be implemented in the specific
        Dataset class

        Args:
            other (DatasetBase): The object to compare to.

        Raises:
            NotImplementedError: This is not valid if not implemented
                by the specific Dataset class

        Returns:
            bool: Never returns
        """
        raise NotImplementedError()

    def copy(self) -> "DatasetBase":
        """Copy Dataset
        To be implemented in child class

        Raises:
            NotImplementedError: raised if not implemented

        Returns:
            DatasetBase: Never returns
        """
        raise NotImplementedError

    def split_obs(self, by):
        """ Returns a list Datasets split by obs

        Args:
            by(String): the descriptor by which the splitting is made

        Raises:
            NotImplementedError: always, unless overridden in a subclass

        Returns:
            list of Datasets, split by the selected obs_descriptor

        """
        raise NotImplementedError(
            "split_obs function not implemented in used Dataset class!")

    def split_channel(self, by):
        """ Returns a list Datasets split by channels

        Args:
            by(String): the descriptor by which the splitting is made

        Raises:
            NotImplementedError: always, unless overridden in a subclass

        Returns:
            list of Datasets, split by the selected channel_descriptor

        """
        raise NotImplementedError(
            "split_channel function not implemented in used Dataset class!")

    def subset_obs(self, by, value):
        """ Returns a subsetted Dataset defined by certain obs value

        Args:
            by(String): the descriptor by which the subset selection is made
                from obs dimension
            value: the value by which the subset selection is made
                from obs dimension

        Raises:
            NotImplementedError: always, unless overridden in a subclass

        Returns:
            Dataset, with subset defined by the selected obs_descriptor

        """
        raise NotImplementedError(
            "subset_obs function not implemented in used Dataset class!")

    def subset_channel(self, by, value):
        """ Returns a subsetted Dataset defined by certain channel value

        Args:
            by(String): the descriptor by which the subset selection is made
                from channel dimension
            value: the value by which the subset selection is made
                from channel dimension

        Raises:
            NotImplementedError: always, unless overridden in a subclass

        Returns:
            Dataset, with subset defined by the selected channel_descriptor

        """
        raise NotImplementedError(
            "subset_channel function not implemented in used Dataset class!")

    def save(self, filename, file_type='hdf5', overwrite=False):
        """ Saves the dataset object to a file

        Args:
            filename(String): path to the file
                [or opened file]
            file_type(String): Type of file to create:
                hdf5: hdf5 file
                pkl: pickle file
            overwrite(Boolean): overwrites file if it already exists

        Raises:
            ValueError: if file_type is neither 'hdf5' nor 'pkl'

        """
        # Validate before touching the filesystem: otherwise an unknown
        # file_type combined with overwrite=True would delete the existing
        # file and then write nothing.
        if file_type not in ('hdf5', 'pkl'):
            raise ValueError(
                f"unknown file_type '{file_type}', "
                "must be 'hdf5' or 'pkl'")
        data_dict = self.to_dict()
        if overwrite:
            remove_file(filename)
        if file_type == 'hdf5':
            write_dict_hdf5(filename, data_dict)
        else:
            write_dict_pkl(filename, data_dict)

    def to_dict(self):
        """ Generates a dictionary which contains the information to
        recreate the dataset object. Used for saving to disk

        Returns:
            data_dict(dict): dictionary with dataset information

        """
        return {
            'measurements': self.measurements,
            'descriptors': self.descriptors,
            'obs_descriptors': self.obs_descriptors,
            'channel_descriptors': self.channel_descriptors,
            # the concrete class name, so a subclass is recorded as itself
            'type': type(self).__name__,
        }

0 commit comments

Comments
 (0)