3
3
import re
4
4
from collections import defaultdict
5
5
from itertools import chain
6
- from typing import Callable , Union , Dict
6
+ from typing import Any , Callable , Dict , Iterator , Tuple , Type , Union
7
7
8
8
import torch
9
9
from torch import nn as nn
13
13
'group_with_matcher' , 'group_modules' , 'group_parameters' , 'flatten_modules' , 'checkpoint_seq' ]
14
14
15
15
16
- def model_parameters (model , exclude_head = False ):
16
+ def model_parameters (model : nn . Module , exclude_head : bool = False ):
17
17
if exclude_head :
18
18
# FIXME this a bit of a quick and dirty hack to skip classifier head params based on ordering
19
19
return [p for p in model .parameters ()][:- 2 ]
20
20
else :
21
21
return model .parameters ()
22
22
23
23
24
- def named_apply (fn : Callable , module : nn .Module , name = '' , depth_first = True , include_root = False ) -> nn .Module :
24
+ def named_apply (
25
+ fn : Callable ,
26
+ module : nn .Module , name = '' ,
27
+ depth_first : bool = True ,
28
+ include_root : bool = False ,
29
+ ) -> nn .Module :
25
30
if not depth_first and include_root :
26
31
fn (module = module , name = name )
27
32
for child_name , child_module in module .named_children ():
@@ -32,7 +37,12 @@ def named_apply(fn: Callable, module: nn.Module, name='', depth_first=True, incl
32
37
return module
33
38
34
39
35
- def named_modules (module : nn .Module , name = '' , depth_first = True , include_root = False ):
40
+ def named_modules (
41
+ module : nn .Module ,
42
+ name : str = '' ,
43
+ depth_first : bool = True ,
44
+ include_root : bool = False ,
45
+ ):
36
46
if not depth_first and include_root :
37
47
yield name , module
38
48
for child_name , child_module in module .named_children ():
@@ -43,7 +53,12 @@ def named_modules(module: nn.Module, name='', depth_first=True, include_root=Fal
43
53
yield name , module
44
54
45
55
46
- def named_modules_with_params (module : nn .Module , name = '' , depth_first = True , include_root = False ):
56
+ def named_modules_with_params (
57
+ module : nn .Module ,
58
+ name : str = '' ,
59
+ depth_first : bool = True ,
60
+ include_root : bool = False ,
61
+ ):
47
62
if module ._parameters and not depth_first and include_root :
48
63
yield name , module
49
64
for child_name , child_module in module .named_children ():
@@ -58,9 +73,9 @@ def named_modules_with_params(module: nn.Module, name='', depth_first=True, incl
58
73
59
74
60
75
def group_with_matcher (
61
- named_objects ,
76
+ named_objects : Iterator [ Tuple [ str , Any ]] ,
62
77
group_matcher : Union [Dict , Callable ],
63
- output_values : bool = False ,
78
+ return_values : bool = False ,
64
79
reverse : bool = False
65
80
):
66
81
if isinstance (group_matcher , dict ):
@@ -96,7 +111,7 @@ def _get_grouping(name):
96
111
# map layers into groups via ordinals (ints or tuples of ints) from matcher
97
112
grouping = defaultdict (list )
98
113
for k , v in named_objects :
99
- grouping [_get_grouping (k )].append (v if output_values else k )
114
+ grouping [_get_grouping (k )].append (v if return_values else k )
100
115
101
116
# remap to integers
102
117
layer_id_to_param = defaultdict (list )
@@ -107,7 +122,7 @@ def _get_grouping(name):
107
122
layer_id_to_param [lid ].extend (grouping [k ])
108
123
109
124
if reverse :
110
- assert not output_values , "reverse mapping only sensible for name output"
125
+ assert not return_values , "reverse mapping only sensible for name output"
111
126
# output reverse mapping
112
127
param_to_layer_id = {}
113
128
for lid , lm in layer_id_to_param .items ():
@@ -121,24 +136,29 @@ def _get_grouping(name):
121
136
def group_parameters (
122
137
module : nn .Module ,
123
138
group_matcher ,
124
- output_values = False ,
125
- reverse = False ,
139
+ return_values : bool = False ,
140
+ reverse : bool = False ,
126
141
):
127
142
return group_with_matcher (
128
- module .named_parameters (), group_matcher , output_values = output_values , reverse = reverse )
143
+ module .named_parameters (), group_matcher , return_values = return_values , reverse = reverse )
129
144
130
145
131
146
def group_modules (
132
147
module : nn .Module ,
133
148
group_matcher ,
134
- output_values = False ,
135
- reverse = False ,
149
+ return_values : bool = False ,
150
+ reverse : bool = False ,
136
151
):
137
152
return group_with_matcher (
138
- named_modules_with_params (module ), group_matcher , output_values = output_values , reverse = reverse )
153
+ named_modules_with_params (module ), group_matcher , return_values = return_values , reverse = reverse )
139
154
140
155
141
- def flatten_modules (named_modules , depth = 1 , prefix = '' , module_types = 'sequential' ):
156
+ def flatten_modules (
157
+ named_modules : Iterator [Tuple [str , nn .Module ]],
158
+ depth : int = 1 ,
159
+ prefix : Union [str , Tuple [str , ...]] = '' ,
160
+ module_types : Union [str , Tuple [Type [nn .Module ]]] = 'sequential' ,
161
+ ):
142
162
prefix_is_tuple = isinstance (prefix , tuple )
143
163
if isinstance (module_types , str ):
144
164
if module_types == 'container' :
0 commit comments