@@ -100,7 +100,7 @@ def __init__(
             self,
             hooks: Sequence[str],
             named_modules: dict,
-            return_map: Sequence[Union[int, str]] = None,
+            out_map: Sequence[Union[int, str]] = None,
             default_hook_type: str = 'forward',
     ):
         # setup feature hooks
@@ -109,7 +109,7 @@ def __init__(
         for i, h in enumerate(hooks):
             hook_name = h['module']
             m = modules[hook_name]
-            hook_id = return_map[i] if return_map else hook_name
+            hook_id = out_map[i] if out_map else hook_name
             hook_fn = partial(self._collect_output_hook, hook_id)
             hook_type = h.get('hook_type', default_hook_type)
             if hook_type == 'forward_pre':
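
For context, a minimal usage sketch of the renamed argument on the hooks side (a sketch, not part of this diff; the `timm.models._features` import path and the `get_output()` accessor are assumptions that may differ across timm versions):

import torch
import timm
from timm.models._features import FeatureHooks  # import path is an assumption

model = timm.create_model('resnet18')
# hook specs mirror what this diff reads via h['module'] / h.get('hook_type', ...)
hooks = [dict(module='layer2'), dict(module='layer4')]
fh = FeatureHooks(hooks, model.named_modules(), out_map=('mid', 'deep'))

x = torch.randn(1, 3, 224, 224)
model(x)                         # the forward pass triggers the hooks
feats = fh.get_output(x.device)  # e.g. {'mid': <tensor>, 'deep': <tensor>}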
@@ -155,11 +155,11 @@ def _get_feature_info(net, out_indices):
         assert False, "Provided feature_info is not valid"


-def _get_return_layers(feature_info, return_map):
+def _get_return_layers(feature_info, out_map):
     module_names = feature_info.module_name()
     return_layers = {}
     for i, name in enumerate(module_names):
-        return_layers[name] = return_map[i] if return_map is not None else feature_info.out_indices[i]
+        return_layers[name] = out_map[i] if out_map is not None else feature_info.out_indices[i]
     return return_layers

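
The helper's effect, sketched with timm's FeatureInfo wrapper (calling this private helper directly, and resnet18's 'act1'/'layerN' module names, are assumptions for illustration):

import timm
from timm.models._features import FeatureInfo, _get_return_layers

model = timm.create_model('resnet18')
info = FeatureInfo(model.feature_info, out_indices=(0, 1, 2, 3, 4))

_get_return_layers(info, None)
# -> e.g. {'act1': 0, 'layer1': 1, 'layer2': 2, 'layer3': 3, 'layer4': 4}
_get_return_layers(info, ('stem', 'c2', 'c3', 'c4', 'c5'))
# -> e.g. {'act1': 'stem', 'layer1': 'c2', 'layer2': 'c3', 'layer3': 'c4', 'layer4': 'c5'}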
@@ -182,7 +182,7 @@ def __init__(
             self,
             model: nn.Module,
             out_indices: Tuple[int, ...] = (0, 1, 2, 3, 4),
-            return_map: Sequence[Union[int, str]] = None,
+            out_map: Sequence[Union[int, str]] = None,
             output_fmt: str = 'NCHW',
             feature_concat: bool = False,
             flatten_sequential: bool = False,
@@ -191,7 +191,7 @@ def __init__(
         Args:
             model: Model from which to extract features.
             out_indices: Output indices of the model features to extract.
-            return_map: Return id mapping for each output index, otherwise str(index) is used.
+            out_map: Return id mapping for each output index, otherwise str(index) is used.
             feature_concat: Concatenate intermediate features that are lists or tuples instead of selecting
                 first element e.g. `x[0]`
             flatten_sequential: Flatten first two-levels of sequential modules in model (re-writes model modules)
@@ -203,7 +203,7 @@ def __init__(
         self.grad_checkpointing = False
         self.return_layers = {}

-        return_layers = _get_return_layers(self.feature_info, return_map)
+        return_layers = _get_return_layers(self.feature_info, out_map)
         modules = _module_list(model, flatten_sequential=flatten_sequential)
         remaining = set(return_layers.keys())
         layers = OrderedDict()
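
At the class level (the class name sits outside these hunks; FeatureDictNet is inferred from the signature), out_map swaps the default str(out_index) keys of the returned dict for caller-supplied ids. A hedged sketch:

import torch
import timm
from timm.models._features import FeatureDictNet  # import path is an assumption

fdn = FeatureDictNet(
    timm.create_model('resnet18'),
    out_indices=(0, 1, 2, 3, 4),
    out_map=('stem', 'c2', 'c3', 'c4', 'c5'),  # one id per out_index
)
feats = fdn(torch.randn(1, 3, 224, 224))
list(feats.keys())  # ['stem', 'c2', 'c3', 'c4', 'c5'] rather than ['0', '1', ...]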
@@ -298,7 +298,7 @@ def __init__(
             self,
             model: nn.Module,
             out_indices: Tuple[int, ...] = (0, 1, 2, 3, 4),
-            return_map: Sequence[Union[int, str]] = None,
+            out_map: Sequence[Union[int, str]] = None,
             return_dict: bool = False,
             output_fmt: str = 'NCHW',
             no_rewrite: bool = False,
@@ -310,7 +310,7 @@ def __init__(
         Args:
             model: Model from which to extract features.
             out_indices: Output indices of the model features to extract.
-            return_map: Return id mapping for each output index, otherwise str(index) is used.
+            out_map: Return id mapping for each output index, otherwise str(index) is used.
             return_dict: Output features as a dict.
             no_rewrite: Enforce that model is not re-written if True, ie no modules are removed / changed.
                 flatten_sequential arg must also be False if this is set True.
@@ -348,7 +348,7 @@ def __init__(
                     break
         assert not remaining, f'Return layers ({remaining}) are not present in model'
         self.update(layers)
-        self.hooks = FeatureHooks(hooks, model.named_modules(), return_map=return_map)
+        self.hooks = FeatureHooks(hooks, model.named_modules(), out_map=out_map)

     def set_grad_checkpointing(self, enable: bool = True):
         self.grad_checkpointing = enable
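
The final hunk matches the hook-based wrapper (FeatureHookNet, again inferred from the signature), which now forwards out_map straight into FeatureHooks, so hook outputs come back under the caller's ids. A sketch under the same assumptions as above:

import torch
import timm
from timm.models._features import FeatureHookNet  # import path is an assumption

fhn = FeatureHookNet(
    timm.create_model('resnet18'),
    out_indices=(1, 2, 3, 4),
    out_map=('p2', 'p3', 'p4', 'p5'),
    return_dict=True,  # dict keyed by the out_map ids instead of a plain list
)
feats = fhn(torch.randn(1, 3, 224, 224))
list(feats.keys())  # ['p2', 'p3', 'p4', 'p5']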