diff --git a/src/fmripost_template/data/io_spec.json b/src/fmripost_template/data/io_spec.json index 3123088..db092c4 100644 --- a/src/fmripost_template/data/io_spec.json +++ b/src/fmripost_template/data/io_spec.json @@ -148,6 +148,11 @@ "desc": null, "suffix": "xfm", "extension": ".h5" + }, + "all_transforms": { + "mode": "image", + "suffix": "xfm", + "extension": [".h5", ".txt"] } } }, diff --git a/src/fmripost_template/utils/utils.py b/src/fmripost_template/utils/utils.py index a705f56..032bf15 100644 --- a/src/fmripost_template/utils/utils.py +++ b/src/fmripost_template/utils/utils.py @@ -132,37 +132,77 @@ def find_shortest_path(space_pairs, start, end): def get_transforms(source, target, local_transforms=None): - """Get the transforms required to go from source to target space.""" + """Get the transforms required to go from source to target space. + + Parameters + ---------- + source : str + The source space. + target : str + The target space. + local_transforms : list of str, optional + List of local transforms to consider. + + Returns + ------- + selected_transforms : list of str + List of selected transforms to go from source to target space. + selected_inversions : list of bool + List of booleans indicating whether the corresponding transform should be inverted. + + Raises + ------ + ValueError + If no chain of transforms can link the source and target spaces. + """ import templateflow.api as tflow from bids.layout import Entity, parse_file_entities query = [ - Entity('template', 'tpl-([a-zA-Z0-9]+)'), - Entity('from', 'from-([a-zA-Z0-9]+)'), + Entity('template', 'tpl-([a-zA-Z0-9+]+)'), + Entity('from', 'from-([a-zA-Z0-9+]+)'), + Entity('to', 'to-([a-zA-Z0-9+]+)'), ] all_transforms = local_transforms or [] templates = tflow.get_templates() + tfl_transforms = [] for template in templates: template_transforms = tflow.get(template, suffix='xfm', extension='h5') if not isinstance(template_transforms, list): template_transforms = [template_transforms] - all_transforms += template_transforms + tfl_transforms += template_transforms + all_transforms += tfl_transforms links = [] for transform in all_transforms: entities = parse_file_entities(transform, entities=query) - link = (entities['from'], entities['template']) + if 'template' in entities: + link = (entities['from'], entities['template']) + else: + link = (entities['from'], entities['to']) links.append(link) + inversions = [False] * len(all_transforms) + + # Add inverses of all templateflow transforms (local transforms might not be invertible) + for transform in tfl_transforms: + entities = parse_file_entities(transform, entities=query) + if 'template' in entities: + links.append((entities['template'], entities['from'])) + else: + links.append((entities['to'], entities['from'])) + inversions.append(True) + path = None try: path = find_shortest_path(links, source, target) print('Shortest path:', path) except ValueError as e: - print(e) + raise ValueError(f'Failed to find a path from {source} to {target}') from e selected_transforms = [all_transforms[i] for i in path] + selected_inversions = [inversions[i] for i in path] - return selected_transforms + return selected_transforms, selected_inversions