Skip to content

Commit

Permalink
Add inverse transforms.
Browse files Browse the repository at this point in the history
  • Loading branch information
tsalo committed Jan 28, 2025
1 parent 1e5acc6 commit 53ae3b1
Show file tree
Hide file tree
Showing 2 changed files with 52 additions and 7 deletions.
5 changes: 5 additions & 0 deletions src/fmripost_template/data/io_spec.json
Original file line number Diff line number Diff line change
Expand Up @@ -148,6 +148,11 @@
"desc": null,
"suffix": "xfm",
"extension": ".h5"
},
"all_transforms": {
"mode": "image",
"suffix": "xfm",
"extension": [".h5", ".txt"]
}
}
},
Expand Down
54 changes: 47 additions & 7 deletions src/fmripost_template/utils/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -132,37 +132,77 @@ def find_shortest_path(space_pairs, start, end):


def get_transforms(source, target, local_transforms=None):
"""Get the transforms required to go from source to target space."""
"""Get the transforms required to go from source to target space.
Parameters
----------
source : str
The source space.
target : str
The target space.
local_transforms : list of str, optional
List of local transforms to consider.
Returns
-------
selected_transforms : list of str
List of selected transforms to go from source to target space.
selected_inversions : list of bool
List of booleans indicating whether the corresponding transform should be inverted.
Raises
------
ValueError
If no chain of transforms can link the source and target spaces.
"""
import templateflow.api as tflow
from bids.layout import Entity, parse_file_entities

query = [
Entity('template', 'tpl-([a-zA-Z0-9]+)'),
Entity('from', 'from-([a-zA-Z0-9]+)'),
Entity('template', 'tpl-([a-zA-Z0-9+]+)'),
Entity('from', 'from-([a-zA-Z0-9+]+)'),
Entity('to', 'to-([a-zA-Z0-9+]+)'),
]

all_transforms = local_transforms or []

templates = tflow.get_templates()
tfl_transforms = []
for template in templates:
template_transforms = tflow.get(template, suffix='xfm', extension='h5')
if not isinstance(template_transforms, list):
template_transforms = [template_transforms]
all_transforms += template_transforms
tfl_transforms += template_transforms

all_transforms += tfl_transforms
links = []
for transform in all_transforms:
entities = parse_file_entities(transform, entities=query)
link = (entities['from'], entities['template'])
if 'template' in entities:
link = (entities['from'], entities['template'])
else:
link = (entities['from'], entities['to'])
links.append(link)

inversions = [False] * len(all_transforms)

# Add inverses of all templateflow transforms (local transforms might not be invertible)
for transform in tfl_transforms:
entities = parse_file_entities(transform, entities=query)
if 'template' in entities:
links.append((entities['template'], entities['from']))
else:
links.append((entities['to'], entities['from']))
inversions.append(True)

path = None
try:
path = find_shortest_path(links, source, target)
print('Shortest path:', path)
except ValueError as e:
print(e)
raise ValueError(f'Failed to find a path from {source} to {target}') from e

selected_transforms = [all_transforms[i] for i in path]
selected_inversions = [inversions[i] for i in path]

return selected_transforms
return selected_transforms, selected_inversions

0 comments on commit 53ae3b1

Please sign in to comment.