13
13
# limitations under the License.
14
14
15
15
import copy
16
+ from typing import Any , Optional
16
17
18
+ import yaml
19
+ import paddle
17
20
18
- class ComponentBuilder (object ):
21
+ from paddleseg .cvlibs import manager , Config
22
+ from paddleseg .utils import utils , logger
23
+ from paddleseg .utils .utils import CachedProperty as cached_property
24
+
25
+
26
+ class Builder (object ):
19
27
"""
20
- This class is responsible for building components. All component classes must be available
21
- in the list of maintained components.
28
+ The base class for building components.
22
29
23
30
Args:
24
- com_list (list): A list of component classes.
31
+ config (Config): A Config class.
32
+ comp_list (list, optional): A list of component classes. Default: None
25
33
"""
26
34
27
- def __init__ (self , com_list ):
35
+ def __init__ (self , config : Config , comp_list : Optional [ list ] = None ):
28
36
super ().__init__ ()
29
- self .com_list = com_list
37
+ self .config = config
38
+ self .comp_list = comp_list
30
39
31
- def create_object (self , cfg ):
40
+ def build_component (self , cfg ):
32
41
"""
33
42
Create Python object, such as model, loss, dataset, etc.
34
43
"""
@@ -44,17 +53,17 @@ def create_object(self, cfg):
44
53
params = {}
45
54
for key , val in cfg .items ():
46
55
if self .is_meta_type (val ):
47
- params [key ] = self .create_object (val )
56
+ params [key ] = self .build_component (val )
48
57
elif isinstance (val , list ):
49
58
params [key ] = [
50
- self .create_object (item )
59
+ self .build_component (item )
51
60
if self .is_meta_type (item ) else item for item in val
52
61
]
53
62
else :
54
63
params [key ] = val
55
64
56
65
try :
57
- obj = self .create_object_impl (com_class , ** params )
66
+ obj = self .build_component_impl (com_class , ** params )
58
67
except Exception as e :
59
68
if hasattr (com_class , '__name__' ):
60
69
com_name = com_class .__name__
@@ -64,28 +73,16 @@ def create_object(self, cfg):
64
73
f"Tried to create a { com_name } object, but the operation has failed. "
65
74
"Please double check the arguments used to create the object.\n "
66
75
f"The error message is: \n { str (e )} " )
67
- return obj
68
-
69
- def create_object_impl (self , component_class , * args , ** kwargs ):
70
- raise NotImplementedError
71
-
72
- def load_component_class (self , cfg ):
73
- raise NotImplementedError
74
-
75
- @classmethod
76
- def is_meta_type (cls , obj ):
77
- raise NotImplementedError
78
76
77
+ return obj
79
78
80
- class DefaultComponentBuilder (ComponentBuilder ):
81
- def create_object_impl (self , component_class , * args , ** kwargs ):
79
+ def build_component_impl (self , component_class , * args , ** kwargs ):
82
80
return component_class (* args , ** kwargs )
83
81
84
82
def load_component_class (self , class_type ):
85
- for com in self .com_list :
83
+ for com in self .comp_list :
86
84
if class_type in com .components_dict :
87
85
return com [class_type ]
88
-
89
86
raise RuntimeError ("The specified component ({}) was not found." .format (
90
87
class_type ))
91
88
@@ -94,3 +91,212 @@ def is_meta_type(cls, obj):
94
91
# TODO: should we define a protocol (see https://peps.python.org/pep-0544/#defining-a-protocol)
95
92
# to make it more pythonic?
96
93
return isinstance (obj , dict ) and 'type' in obj
94
+
95
+ @classmethod
96
+ def show_msg (cls , name , cfg ):
97
+ msg = 'Use the following config to build {}\n ' .format (name )
98
+ msg += str (yaml .dump ({name : cfg }, Dumper = utils .NoAliasDumper ))
99
+ logger .info (msg [0 :- 1 ])
100
+
101
+
102
class SegBuilder(Builder):
    """
    This class is responsible for building components for semantic segmentation.

    Each component is exposed as a cached property, so it is constructed at
    most once per builder instance and reused afterwards.
    """

    def __init__(self, config, comp_list=None):
        # Default to the full set of segmentation registries when the caller
        # does not restrict the search space.
        if comp_list is None:
            comp_list = [
                manager.MODELS, manager.BACKBONES, manager.DATASETS,
                manager.TRANSFORMS, manager.LOSSES, manager.OPTIMIZERS
            ]
        super().__init__(config, comp_list)

    @cached_property
    def model(self) -> paddle.nn.Layer:
        """
        Build the segmentation model from the `model` section of the config.

        When the train dataset is a concrete dataset class (not the generic
        `Dataset`), `num_classes` and `in_channels` are cross-checked against
        the dataset class and filled into the model config if absent.
        """
        model_cfg = self.config.model_cfg
        assert model_cfg != {}, \
            'No model specified in the configuration file.'

        if self.config.train_dataset_cfg['type'] != 'Dataset':
            # check and synchronize the num_classes in model config and dataset class
            assert hasattr(self.train_dataset_class, 'NUM_CLASSES'), \
                'If train_dataset class is not `Dataset`, it must have `NUM_CLASSES` attr.'
            num_classes = getattr(self.train_dataset_class, 'NUM_CLASSES')
            if 'num_classes' in model_cfg:
                assert model_cfg['num_classes'] == num_classes, \
                    'The num_classes is not consistent for model config ({}) ' \
                    'and train_dataset class ({}) '.format(model_cfg['num_classes'], num_classes)
            else:
                logger.warning(
                    'Add the `num_classes` in train_dataset class to '
                    'model config. We suggest you manually set `num_classes` in model config.'
                )
                # NOTE: mutates the config dict in place so later readers see it.
                model_cfg['num_classes'] = num_classes
            # check and synchronize the in_channels in model config and dataset class
            assert hasattr(self.train_dataset_class, 'IMG_CHANNELS'), \
                'If train_dataset class is not `Dataset`, it must have `IMG_CHANNELS` attr.'
            in_channels = getattr(self.train_dataset_class, 'IMG_CHANNELS')
            x = utils.get_in_channels(model_cfg)
            if x is not None:
                assert x == in_channels, \
                    'The in_channels in model config ({}) and the img_channels in train_dataset ' \
                    'class ({}) is not consistent'.format(x, in_channels)
            else:
                model_cfg = utils.set_in_channels(model_cfg, in_channels)
                logger.warning(
                    'Add the `in_channels` in train_dataset class to '
                    'model config. We suggest you manually set `in_channels` in model config.'
                )

        self.show_msg('model', model_cfg)
        return self.build_component(model_cfg)

    @cached_property
    def optimizer(self) -> paddle.optimizer.Optimizer:
        """
        Build the optimizer, normalizing legacy type names first.

        The component built from the config is a factory that is then called
        with the model and the lr_scheduler to produce the actual optimizer.
        """
        opt_cfg = self.config.optimizer_cfg
        assert opt_cfg != {}, \
            'No optimizer specified in the configuration file.'
        # For compatibility
        if opt_cfg['type'] == 'adam':
            opt_cfg['type'] = 'Adam'
        if opt_cfg['type'] == 'sgd':
            opt_cfg['type'] = 'SGD'
        if opt_cfg['type'] == 'SGD' and 'momentum' in opt_cfg:
            opt_cfg['type'] = 'Momentum'
            logger.info('If the type is SGD and momentum in optimizer config, '
                        'the type is changed to Momentum.')
        self.show_msg('optimizer', opt_cfg)
        opt = self.build_component(opt_cfg)
        opt = opt(self.model, self.lr_scheduler)
        return opt

    @cached_property
    def lr_scheduler(self) -> paddle.optimizer.lr.LRScheduler:
        """
        Build the learning-rate scheduler from the `lr_scheduler` config.

        Supports an optional linear warmup phase (`warmup_iters` +
        `warmup_start_lr`), and defaults `decay_steps` for PolynomialDecay
        to the remaining (post-warmup) training iterations.
        """
        lr_cfg = self.config.lr_scheduler_cfg
        assert lr_cfg != {}, \
            'No lr_scheduler specified in the configuration file.'

        use_warmup = False
        if 'warmup_iters' in lr_cfg:
            use_warmup = True
            # Pop the warmup keys so only scheduler kwargs remain in lr_cfg.
            warmup_iters = lr_cfg.pop('warmup_iters')
            assert 'warmup_start_lr' in lr_cfg, \
                "When use warmup, please set warmup_start_lr and warmup_iters in lr_scheduler"
            warmup_start_lr = lr_cfg.pop('warmup_start_lr')
            end_lr = lr_cfg['learning_rate']

        lr_type = lr_cfg.pop('type')
        if lr_type == 'PolynomialDecay':
            iters = self.config.iters - warmup_iters if use_warmup else self.config.iters
            iters = max(iters, 1)
            lr_cfg.setdefault('decay_steps', iters)

        try:
            # The scheduler class is resolved by name from paddle.optimizer.lr.
            lr_sche = getattr(paddle.optimizer.lr, lr_type)(**lr_cfg)
        except Exception as e:
            raise RuntimeError(
                "Create {} has failed. Please check lr_scheduler in config. "
                "The error message: {}".format(lr_type, e))

        if use_warmup:
            lr_sche = paddle.optimizer.lr.LinearWarmup(
                learning_rate=lr_sche,
                warmup_steps=warmup_iters,
                start_lr=warmup_start_lr,
                end_lr=end_lr)

        return lr_sche

    @cached_property
    def loss(self) -> dict:
        """Build the training losses as a dict with 'types' and 'coef' keys."""
        loss_cfg = self.config.loss_cfg
        assert loss_cfg != {}, \
            'No loss specified in the configuration file.'
        return self._build_loss('loss', loss_cfg)

    @cached_property
    def distill_loss(self) -> dict:
        """Build the distillation losses, using the same layout as `loss`."""
        loss_cfg = self.config.distill_loss_cfg
        assert loss_cfg != {}, \
            'No distill_loss specified in the configuration file.'
        return self._build_loss('distill_loss', loss_cfg)

    def _build_loss(self, loss_name, loss_cfg: dict):
        """
        Build the loss components listed in ``loss_cfg['types']``.

        When the train dataset is a concrete dataset class, every loss entry's
        `ignore_index` is checked against (or defaulted from) the dataset
        class's IGNORE_INDEX, including losses nested inside MixedLoss.
        Returns a dict: {'coef': [...], 'types': [loss objects]}.
        """
        def _check_helper(loss_cfg, ignore_index):
            # Default a missing ignore_index from the dataset; otherwise
            # require that the two agree.
            if 'ignore_index' not in loss_cfg:
                loss_cfg['ignore_index'] = ignore_index
                logger.warning('Add the `ignore_index` in train_dataset ' \
                    'class to {} config. We suggest you manually set ' \
                    '`ignore_index` in {} config.'.format(loss_name, loss_name)
                )
            else:
                assert loss_cfg['ignore_index'] == ignore_index, \
                    'the ignore_index in loss and train_dataset must be the same. Currently, loss ignore_index = {}, ' \
                    'train_dataset ignore_index = {}'.format(loss_cfg['ignore_index'], ignore_index)

        # check and synchronize the ignore_index in model config and dataset class
        if self.config.train_dataset_cfg['type'] != 'Dataset':
            assert hasattr(self.train_dataset_class, 'IGNORE_INDEX'), \
                'If train_dataset class is not `Dataset`, it must have `IGNORE_INDEX` attr.'
            ignore_index = getattr(self.train_dataset_class, 'IGNORE_INDEX')
            for loss_cfg_i in loss_cfg['types']:
                if loss_cfg_i['type'] == 'MixedLoss':
                    for loss_cfg_j in loss_cfg_i['losses']:
                        _check_helper(loss_cfg_j, ignore_index)
                else:
                    _check_helper(loss_cfg_i, ignore_index)

        self.show_msg(loss_name, loss_cfg)
        loss_dict = {'coef': loss_cfg['coef'], "types": []}
        for item in loss_cfg['types']:
            loss_dict['types'].append(self.build_component(item))
        return loss_dict

    @cached_property
    def train_dataset(self) -> paddle.io.Dataset:
        """Build the training dataset; it must contain at least one sample."""
        dataset_cfg = self.config.train_dataset_cfg
        assert dataset_cfg != {}, \
            'No train_dataset specified in the configuration file.'
        self.show_msg('train_dataset', dataset_cfg)
        dataset = self.build_component(dataset_cfg)
        assert len(dataset) != 0, \
            'The number of samples in train_dataset is 0. Please check whether the dataset is valid.'
        return dataset

    @cached_property
    def val_dataset(self) -> paddle.io.Dataset:
        """Build the validation dataset; it must contain at least one sample."""
        dataset_cfg = self.config.val_dataset_cfg
        assert dataset_cfg != {}, \
            'No val_dataset specified in the configuration file.'
        self.show_msg('val_dataset', dataset_cfg)
        dataset = self.build_component(dataset_cfg)
        assert len(dataset) != 0, \
            'The number of samples in val_dataset is 0. Please check whether the dataset is valid.'
        return dataset

    @cached_property
    def train_dataset_class(self) -> Any:
        """Return the class (not an instance) of the configured train dataset."""
        dataset_cfg = self.config.train_dataset_cfg
        assert dataset_cfg != {}, \
            'No train_dataset specified in the configuration file.'
        dataset_type = dataset_cfg.get('type')
        return self.load_component_class(dataset_type)

    @cached_property
    def val_dataset_class(self) -> Any:
        """Return the class (not an instance) of the configured val dataset."""
        dataset_cfg = self.config.val_dataset_cfg
        assert dataset_cfg != {}, \
            'No val_dataset specified in the configuration file.'
        dataset_type = dataset_cfg.get('type')
        return self.load_component_class(dataset_type)

    @cached_property
    def val_transforms(self) -> list:
        """Build the list of validation transforms (empty if none configured)."""
        dataset_cfg = self.config.val_dataset_cfg
        assert dataset_cfg != {}, \
            'No val_dataset specified in the configuration file.'
        transforms = []
        for item in dataset_cfg.get('transforms', []):
            transforms.append(self.build_component(item))
        return transforms
0 commit comments