@@ -132,37 +132,77 @@ def find_shortest_path(space_pairs, start, end):
132
132
133
133
134
134
def get_transforms (source , target , local_transforms = None ):
135
- """Get the transforms required to go from source to target space."""
135
+ """Get the transforms required to go from source to target space.
136
+
137
+ Parameters
138
+ ----------
139
+ source : str
140
+ The source space.
141
+ target : str
142
+ The target space.
143
+ local_transforms : list of str, optional
144
+ List of local transforms to consider.
145
+
146
+ Returns
147
+ -------
148
+ selected_transforms : list of str
149
+ List of selected transforms to go from source to target space.
150
+ selected_inversions : list of bool
151
+ List of booleans indicating whether the corresponding transform should be inverted.
152
+
153
+ Raises
154
+ ------
155
+ ValueError
156
+ If no chain of transforms can link the source and target spaces.
157
+ """
136
158
import templateflow .api as tflow
137
159
from bids .layout import Entity , parse_file_entities
138
160
139
161
query = [
140
- Entity ('template' , 'tpl-([a-zA-Z0-9]+)' ),
141
- Entity ('from' , 'from-([a-zA-Z0-9]+)' ),
162
+ Entity ('template' , 'tpl-([a-zA-Z0-9+]+)' ),
163
+ Entity ('from' , 'from-([a-zA-Z0-9+]+)' ),
164
+ Entity ('to' , 'to-([a-zA-Z0-9+]+)' ),
142
165
]
143
166
144
167
all_transforms = local_transforms or []
145
168
146
169
templates = tflow .get_templates ()
170
+ tfl_transforms = []
147
171
for template in templates :
148
172
template_transforms = tflow .get (template , suffix = 'xfm' , extension = 'h5' )
149
173
if not isinstance (template_transforms , list ):
150
174
template_transforms = [template_transforms ]
151
- all_transforms += template_transforms
175
+ tfl_transforms += template_transforms
152
176
177
+ all_transforms += tfl_transforms
153
178
links = []
154
179
for transform in all_transforms :
155
180
entities = parse_file_entities (transform , entities = query )
156
- link = (entities ['from' ], entities ['template' ])
181
+ if 'template' in entities :
182
+ link = (entities ['from' ], entities ['template' ])
183
+ else :
184
+ link = (entities ['from' ], entities ['to' ])
157
185
links .append (link )
158
186
187
+ inversions = [False ] * len (all_transforms )
188
+
189
+ # Add inverses of all templateflow transforms (local transforms might not be invertible)
190
+ for transform in tfl_transforms :
191
+ entities = parse_file_entities (transform , entities = query )
192
+ if 'template' in entities :
193
+ links .append ((entities ['template' ], entities ['from' ]))
194
+ else :
195
+ links .append ((entities ['to' ], entities ['from' ]))
196
+ inversions .append (True )
197
+
159
198
path = None
160
199
try :
161
200
path = find_shortest_path (links , source , target )
162
201
print ('Shortest path:' , path )
163
202
except ValueError as e :
164
- print ( e )
203
+ raise ValueError ( f'Failed to find a path from { source } to { target } ' ) from e
165
204
166
205
selected_transforms = [all_transforms [i ] for i in path ]
206
+ selected_inversions = [inversions [i ] for i in path ]
167
207
168
- return selected_transforms
208
+ return selected_transforms , selected_inversions
0 commit comments