Skip to content

Commit 5aebad3

Browse files
committed
Rename `return_map` back to `out_map` in the `_features.py` helpers
1 parent acfd85a commit 5aebad3

File tree

1 file changed: +10 −10 lines changed

timm/models/_features.py

+10 −10

Original file line number | Diff line number | Diff line change
@@ -100,7 +100,7 @@ def __init__(
100100
self,
101101
hooks: Sequence[str],
102102
named_modules: dict,
103-
return_map: Sequence[Union[int, str]] = None,
103+
out_map: Sequence[Union[int, str]] = None,
104104
default_hook_type: str = 'forward',
105105
):
106106
# setup feature hooks
@@ -109,7 +109,7 @@ def __init__(
109109
for i, h in enumerate(hooks):
110110
hook_name = h['module']
111111
m = modules[hook_name]
112-
hook_id = return_map[i] if return_map else hook_name
112+
hook_id = out_map[i] if out_map else hook_name
113113
hook_fn = partial(self._collect_output_hook, hook_id)
114114
hook_type = h.get('hook_type', default_hook_type)
115115
if hook_type == 'forward_pre':
@@ -155,11 +155,11 @@ def _get_feature_info(net, out_indices):
155155
assert False, "Provided feature_info is not valid"
156156

157157

158-
def _get_return_layers(feature_info, return_map):
158+
def _get_return_layers(feature_info, out_map):
159159
module_names = feature_info.module_name()
160160
return_layers = {}
161161
for i, name in enumerate(module_names):
162-
return_layers[name] = return_map[i] if return_map is not None else feature_info.out_indices[i]
162+
return_layers[name] = out_map[i] if out_map is not None else feature_info.out_indices[i]
163163
return return_layers
164164

165165

@@ -182,7 +182,7 @@ def __init__(
182182
self,
183183
model: nn.Module,
184184
out_indices: Tuple[int, ...] = (0, 1, 2, 3, 4),
185-
return_map: Sequence[Union[int, str]] = None,
185+
out_map: Sequence[Union[int, str]] = None,
186186
output_fmt: str = 'NCHW',
187187
feature_concat: bool = False,
188188
flatten_sequential: bool = False,
@@ -191,7 +191,7 @@ def __init__(
191191
Args:
192192
model: Model from which to extract features.
193193
out_indices: Output indices of the model features to extract.
194-
return_map: Return id mapping for each output index, otherwise str(index) is used.
194+
out_map: Return id mapping for each output index, otherwise str(index) is used.
195195
feature_concat: Concatenate intermediate features that are lists or tuples instead of selecting
196196
first element e.g. `x[0]`
197197
flatten_sequential: Flatten first two-levels of sequential modules in model (re-writes model modules)
@@ -203,7 +203,7 @@ def __init__(
203203
self.grad_checkpointing = False
204204
self.return_layers = {}
205205

206-
return_layers = _get_return_layers(self.feature_info, return_map)
206+
return_layers = _get_return_layers(self.feature_info, out_map)
207207
modules = _module_list(model, flatten_sequential=flatten_sequential)
208208
remaining = set(return_layers.keys())
209209
layers = OrderedDict()
@@ -298,7 +298,7 @@ def __init__(
298298
self,
299299
model: nn.Module,
300300
out_indices: Tuple[int, ...] = (0, 1, 2, 3, 4),
301-
return_map: Sequence[Union[int, str]] = None,
301+
out_map: Sequence[Union[int, str]] = None,
302302
return_dict: bool = False,
303303
output_fmt: str = 'NCHW',
304304
no_rewrite: bool = False,
@@ -310,7 +310,7 @@ def __init__(
310310
Args:
311311
model: Model from which to extract features.
312312
out_indices: Output indices of the model features to extract.
313-
return_map: Return id mapping for each output index, otherwise str(index) is used.
313+
out_map: Return id mapping for each output index, otherwise str(index) is used.
314314
return_dict: Output features as a dict.
315315
no_rewrite: Enforce that model is not re-written if True, ie no modules are removed / changed.
316316
flatten_sequential arg must also be False if this is set True.
@@ -348,7 +348,7 @@ def __init__(
348348
break
349349
assert not remaining, f'Return layers ({remaining}) are not present in model'
350350
self.update(layers)
351-
self.hooks = FeatureHooks(hooks, model.named_modules(), return_map=return_map)
351+
self.hooks = FeatureHooks(hooks, model.named_modules(), out_map=out_map)
352352

353353
def set_grad_checkpointing(self, enable: bool = True):
354354
self.grad_checkpointing = enable

0 commit comments

Comments (0)