Skip to content

Commit

Permalink
Merge pull request #283 from rsagroup/milestone-v01-low-hanging-fruit
Browse files Browse the repository at this point in the history
Low hanging fruit for milestone 0.1
  • Loading branch information
JasperVanDenBosch authored Dec 16, 2022
2 parents f85bfd4 + f9d9d4b commit e3d5e23
Show file tree
Hide file tree
Showing 16 changed files with 932 additions and 564 deletions.
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ src/rsatoolbox/cengine/*.html
cadena_ploscb_data.pkl
fmri_data
demos/allendata
demos/temp_rdm.png

# stats files
stats/**/*.npz
Expand Down
Binary file removed demos/temp_rdm.png
Binary file not shown.
Binary file modified docs/source/temp_rdm.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
201 changes: 201 additions & 0 deletions src/rsatoolbox/data/base.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,201 @@
"""Base class for Dataset
"""
from __future__ import annotations
from rsatoolbox.util.descriptor_utils import check_descriptor_length_error
from rsatoolbox.util.descriptor_utils import format_descriptor
from rsatoolbox.util.descriptor_utils import parse_input_descriptor
from rsatoolbox.util.file_io import write_dict_hdf5
from rsatoolbox.util.file_io import write_dict_pkl
from rsatoolbox.util.file_io import remove_file


class DatasetBase:
"""
Abstract dataset class.
Defines members that every class needs to have, but does not
implement any interesting behavior. Inherit from this class
to define specific dataset types
Args:
measurements (numpy.ndarray): n_obs x n_channel 2d-array,
descriptors (dict): descriptors (metadata)
obs_descriptors (dict): observation descriptors (all
are array-like with shape = (n_obs,...))
channel_descriptors (dict): channel descriptors (all are
array-like with shape = (n_channel,...))
Returns:
dataset object
"""

def __init__(self, measurements, descriptors=None,
obs_descriptors=None, channel_descriptors=None,
check_dims=True):
if measurements.ndim != 2:
raise AttributeError(
"measurements must be in dimension n_obs x n_channel")
self.measurements = measurements
self.n_obs, self.n_channel = self.measurements.shape
if check_dims:
check_descriptor_length_error(obs_descriptors,
"obs_descriptors",
self.n_obs
)
check_descriptor_length_error(channel_descriptors,
"channel_descriptors",
self.n_channel
)
self.descriptors = parse_input_descriptor(descriptors)
self.obs_descriptors = parse_input_descriptor(obs_descriptors)
self.channel_descriptors = parse_input_descriptor(channel_descriptors)

def __repr__(self):
"""
defines string which is printed for the object
"""
return (f'rsatoolbox.data.{self.__class__.__name__}(\n'
f'measurements = \n{self.measurements}\n'
f'descriptors = \n{self.descriptors}\n'
f'obs_descriptors = \n{self.obs_descriptors}\n'
f'channel_descriptors = \n{self.channel_descriptors}\n'
)

def __str__(self):
"""
defines the output of print
"""
string_desc = format_descriptor(self.descriptors)
string_obs_desc = format_descriptor(self.obs_descriptors)
string_channel_desc = format_descriptor(self.channel_descriptors)
if self.measurements.shape[0] > 5:
measurements = self.measurements[:5, :]
else:
measurements = self.measurements
return (f'rsatoolbox.data.{self.__class__.__name__}\n'
f'measurements = \n{measurements}\n...\n\n'
f'descriptors: \n{string_desc}\n\n'
f'obs_descriptors: \n{string_obs_desc}\n\n'
f'channel_descriptors: \n{string_channel_desc}\n'
)

def __eq__(self, other: DatasetBase) -> bool:
"""Equality check, to be implemented in the specific
Dataset class
Args:
other (DatasetBase): The object to compare to.
Raises:
NotImplementedError: This is not valid if not implemented
by the specific Dataset class
Returns:
bool: Never returns
"""
raise NotImplementedError()

def copy(self) -> DatasetBase:
"""Copy Dataset
To be implemented in child class
Raises:
NotImplementedError: raised if not implemented
Returns:
DatasetBase: Never returns
"""
raise NotImplementedError

def split_obs(self, by):
""" Returns a list Datasets split by obs
Args:
by(String): the descriptor by which the splitting is made
Returns:
list of Datasets, splitted by the selected obs_descriptor
"""
raise NotImplementedError(
"split_obs function not implemented in used Dataset class!")

def split_channel(self, by):
""" Returns a list Datasets split by channels
Args:
by(String): the descriptor by which the splitting is made
Returns:
list of Datasets, splitted by the selected channel_descriptor
"""
raise NotImplementedError(
"split_channel function not implemented in used Dataset class!")

def subset_obs(self, by, value):
""" Returns a subsetted Dataset defined by certain obs value
Args:
by(String): the descriptor by which the subset selection is made
from obs dimension
value: the value by which the subset selection is made
from obs dimension
Returns:
Dataset, with subset defined by the selected obs_descriptor
"""
raise NotImplementedError(
"subset_obs function not implemented in used Dataset class!")

def subset_channel(self, by, value):
""" Returns a subsetted Dataset defined by certain channel value
Args:
by(String): the descriptor by which the subset selection is made
from channel dimension
value: the value by which the subset selection is made
from channel dimension
Returns:
Dataset, with subset defined by the selected channel_descriptor
"""
raise NotImplementedError(
"subset_channel function not implemented in used Dataset class!")

def save(self, filename, file_type='hdf5', overwrite=False):
""" Saves the dataset object to a file
Args:
filename(String): path to the file
[or opened file]
file_type(String): Type of file to create:
hdf5: hdf5 file
pkl: pickle file
overwrite(Boolean): overwrites file if it already exists
"""
data_dict = self.to_dict()
if overwrite:
remove_file(filename)
if file_type == 'hdf5':
write_dict_hdf5(filename, data_dict)
elif file_type == 'pkl':
write_dict_pkl(filename, data_dict)

def to_dict(self):
""" Generates a dictionary which contains the information to
recreate the dataset object. Used for saving to disc
Returns:
data_dict(dict): dictionary with dataset information
"""
data_dict = {}
data_dict['measurements'] = self.measurements
data_dict['descriptors'] = self.descriptors
data_dict['obs_descriptors'] = self.obs_descriptors
data_dict['channel_descriptors'] = self.channel_descriptors
data_dict['type'] = type(self).__name__
return data_dict
Loading

0 comments on commit e3d5e23

Please sign in to comment.