diff --git a/README.md b/README.md index 5c9d1f6..955fb0e 100644 --- a/README.md +++ b/README.md @@ -1,10 +1,11 @@ [![Documentation Stable](https://readthedocs.org/projects/asltk/badge/?version=main)](https://asltk.readthedocs.io/en/main/?badge=main) +![Website](https://img.shields.io/website?url=https%3A%2F%2Fasltk.readthedocs.io%2Fen%2Fmain%2F&up_message=asltk%20documentation&link=https%3A%2F%2Fasltk.readthedocs.io%2Fen%2Fmain%2F) [![codecov](https://codecov.io/gh/LOAMRI/asltk/graph/badge.svg?token=1W8GQ7SLU9)](https://codecov.io/gh/LOAMRI/asltk) [![CI_main](https://github.com/LOAMRI/asltk/actions/workflows/ci_main.yaml/badge.svg)](https://github.com/LOAMRI/asltk/actions/workflows/ci_main.yaml) [![CI_develop](https://github.com/LOAMRI/asltk/actions/workflows/ci_develop.yaml/badge.svg)](https://github.com/LOAMRI/asltk/actions/workflows/ci_develop.yaml) -![Python Versions](https://img.shields.io/badge/python-3.9%20|+-blue) +![Python Versions](https://img.shields.io/badge/python-3.10%20|+-blue) [![PyPI downloads](https://img.shields.io/pypi/dm/asltk?label=PyPI%20downloads)](https://pypi.org/project/asltk/) ![Contributors](https://img.shields.io/github/contributors/LOAMRI/asltk) [![GitHub issues](https://img.shields.io/github/issues-raw/LOAMRI/asltk.svg?maxAge=2592000)]() diff --git a/asltk/data/brain_atlas/__init__.py b/asltk/data/brain_atlas/__init__.py index 58c0b19..8bf3bb1 100644 --- a/asltk/data/brain_atlas/__init__.py +++ b/asltk/data/brain_atlas/__init__.py @@ -13,7 +13,7 @@ class BrainAtlas: ATLAS_JSON_PATH = os.path.join(os.path.dirname(__file__)) - def __init__(self, atlas_name: str = 'MNI2009'): + def __init__(self, atlas_name: str = 'MNI2009', resolution: str = '1mm'): """ Initializes the BrainAtlas class with a specified atlas name. If no atlas name is provided, it defaults to 'MNI2009'. @@ -21,7 +21,11 @@ def __init__(self, atlas_name: str = 'MNI2009'): Args: atlas_name (str, optional): The name of the atlas to be used. Defaults to 'MNI2009'. """ + self._check_resolution_input(resolution) + self._chosen_atlas = None + self._resolution = resolution + self.set_atlas(atlas_name) def set_atlas(self, atlas_name: str): @@ -61,6 +65,7 @@ def set_atlas(self, atlas_name: str): # Assuming the atlas_data is a dictionary, we can add the path to it atlas_data['atlas_file_location'] = path # Assuming the atlas data contains a key for T1-weighted and Label image data + atlas_data['resolution'] = self._resolution atlas_data['t1_data'] = os.path.join(path, self._collect_t1(path)) atlas_data['label_data'] = os.path.join( path, self._collect_label(path) @@ -77,6 +82,13 @@ def get_atlas(self): """ return self._chosen_atlas + def set_resolution(self, resolution: str): + self._check_resolution_input(resolution) + self._resolution = resolution + + def get_resolution(self): + return self._resolution + def get_atlas_url(self, atlas_name: str): """ Get the brain atlas URL of the chosen format in the ASLtk database. @@ -145,10 +157,13 @@ def _collect_t1(self, path: str): # pragma: no cover Returns: str: The filename of the T1-weighted image data. """ - t1_file = next((f for f in os.listdir(path) if '_t1' in f), None) + t1_file = next( + (f for f in os.listdir(path) if self._resolution + '_t1' in f), + None, + ) if t1_file is None: raise ValueError( - f"No file with '_t1' found in the atlas directory: {path}" + f"No file with '_t1_' and resolution {self._resolution} found in the atlas directory: {path}" ) return t1_file @@ -161,10 +176,20 @@ def _collect_label(self, path: str): # pragma: no cover Returns: str: The filename of the label file. """ - label_file = next((f for f in os.listdir(path) if '_label' in f), None) + label_file = next( + (f for f in os.listdir(path) if self._resolution + '_label' in f), + None, + ) if label_file is None: raise ValueError( - f"No file with '_label' found in the atlas directory: {path}" + f"No file with '_label' and resolution {self._resolution} found in the atlas directory: {path}" ) return label_file + + def _check_resolution_input(self, resolution): + valid_resolutions = ['1mm', '2mm'] + if resolution not in valid_resolutions: + raise ValueError( + f"Invalid resolution '{resolution}'. Valid options are: {valid_resolutions}" + ) diff --git a/tests/data/brain_atlas/test_brain_atlas.py b/tests/data/brain_atlas/test_brain_atlas.py index 35a1241..8b2aa0a 100644 --- a/tests/data/brain_atlas/test_brain_atlas.py +++ b/tests/data/brain_atlas/test_brain_atlas.py @@ -122,6 +122,56 @@ def test_brain_atlas_creation_with_various_names(atlas_name): assert isinstance(atlas.get_atlas(), dict) +@pytest.mark.parametrize( + 'atlas_name', + [ + 'MNI2009', + 'AAL32024', + 'HOCSA2006', + 'AAT2022', + 'AICHA2021', + 'DKA2006', + 'FCA7N2011', + 'HA2003', + 'JHA2005', + 'LGPHCC2022', + 'AAT2022', + ], +) +def test_brain_atlas_creation_with_various_names_2mm_resolution(atlas_name): + """ + Test creating BrainAtlas objects with different valid atlas names. + """ + atlas = BrainAtlas(atlas_name=atlas_name, resolution='2mm') + assert isinstance(atlas.get_atlas(), dict) + + +@pytest.mark.parametrize( + 'wrong_resolution', + [ + ('1'), + ('2'), + ('3mm'), + ('1.5mm'), + ('4mm'), + ('1x1x1'), + ('2x2x2'), + (1), + (2), + ], +) +def test_brain_atlas_constructor_raise_error_wrong_resolution( + wrong_resolution, +): + """ + Test that the BrainAtlas constructor raises an error for invalid resolution. + """ + with pytest.raises(ValueError) as e: + BrainAtlas(resolution=wrong_resolution) + + assert 'Invalid resolution' in str(e.value) + + def test_atlas_download_failure(mocker): """ Test that appropriate error is raised when atlas download fails. @@ -159,3 +209,21 @@ def test_atlas_url_raises_error_when_atlas_not_set(): # Verify the error message assert 'is not set or does not have a dataset URL' in str(e.value) + + +def test_brain_atlas_get_resolution(): + """ + Test the get_resolution method of the BrainAtlas class. + """ + atlas = BrainAtlas() + atlas.set_resolution('2mm') + assert atlas.get_resolution() == '2mm' + + +def test_brain_atlas_set_resolution(): + """ + Test the set_resolution method of the BrainAtlas class. + """ + atlas = BrainAtlas() + atlas.set_resolution('2mm') + assert atlas.get_resolution() == '2mm'