Skip to content

Commit d061b60

Browse files
authored
[Enhancement] Simplify Config and Builder (PaddlePaddle#2897)
* Simplify Config and Builder
1 parent 6973990 commit d061b60

File tree

9 files changed

+415
-290
lines changed

9 files changed

+415
-290
lines changed

paddleseg/cvlibs/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,3 +15,4 @@
1515
from . import manager
1616
from . import param_init
1717
from .config import Config
18+
from .builder import Builder, SegBuilder

paddleseg/cvlibs/builder.py

Lines changed: 231 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -13,22 +13,31 @@
1313
# limitations under the License.
1414

1515
import copy
16+
from typing import Any, Optional
1617

18+
import yaml
19+
import paddle
1720

18-
class ComponentBuilder(object):
21+
from paddleseg.cvlibs import manager, Config
22+
from paddleseg.utils import utils, logger
23+
from paddleseg.utils.utils import CachedProperty as cached_property
24+
25+
26+
class Builder(object):
1927
"""
20-
This class is responsible for building components. All component classes must be available
21-
in the list of maintained components.
28+
The base class for building components.
2229
2330
Args:
24-
com_list (list): A list of component classes.
31+
config (Config): A Config class.
32+
comp_list (list, optional): A list of component classes. Default: None
2533
"""
2634

def __init__(self, config: Config, comp_list: Optional[list]=None):
    """
    Keep a reference to the parsed config and the component registries
    that `load_component_class` will search.

    Args:
        config (Config): A Config class.
        comp_list (list, optional): A list of component classes. Default: None
    """
    super().__init__()
    # Plain attribute storage; subclasses may supply a default comp_list.
    self.config, self.comp_list = config, comp_list
3039

31-
def create_object(self, cfg):
40+
def build_component(self, cfg):
3241
"""
3342
Create Python object, such as model, loss, dataset, etc.
3443
"""
@@ -44,17 +53,17 @@ def create_object(self, cfg):
4453
params = {}
4554
for key, val in cfg.items():
4655
if self.is_meta_type(val):
47-
params[key] = self.create_object(val)
56+
params[key] = self.build_component(val)
4857
elif isinstance(val, list):
4958
params[key] = [
50-
self.create_object(item)
59+
self.build_component(item)
5160
if self.is_meta_type(item) else item for item in val
5261
]
5362
else:
5463
params[key] = val
5564

5665
try:
57-
obj = self.create_object_impl(com_class, **params)
66+
obj = self.build_component_impl(com_class, **params)
5867
except Exception as e:
5968
if hasattr(com_class, '__name__'):
6069
com_name = com_class.__name__
@@ -64,28 +73,16 @@ def create_object(self, cfg):
6473
f"Tried to create a {com_name} object, but the operation has failed. "
6574
"Please double check the arguments used to create the object.\n"
6675
f"The error message is: \n{str(e)}")
67-
return obj
68-
69-
def create_object_impl(self, component_class, *args, **kwargs):
70-
raise NotImplementedError
71-
72-
def load_component_class(self, cfg):
73-
raise NotImplementedError
74-
75-
@classmethod
76-
def is_meta_type(cls, obj):
77-
raise NotImplementedError
7876

77+
return obj
7978

80-
class DefaultComponentBuilder(ComponentBuilder):
81-
def create_object_impl(self, component_class, *args, **kwargs):
79+
def build_component_impl(self, component_class, *args, **kwargs):
8280
return component_class(*args, **kwargs)
8381

def load_component_class(self, class_type):
    """
    Return the component class registered under `class_type`.

    Searches every registry in `self.comp_list`; raises RuntimeError when
    no registry knows the requested type.
    """
    for registry in self.comp_list:
        if class_type in registry.components_dict:
            return registry[class_type]
    raise RuntimeError("The specified component ({}) was not found.".format(
        class_type))

@@ -94,3 +91,212 @@ def is_meta_type(cls, obj):
9491
# TODO: should we define a protocol (see https://peps.python.org/pep-0544/#defining-a-protocol)
9592
# to make it more pythonic?
9693
return isinstance(obj, dict) and 'type' in obj
94+
95+
@classmethod
96+
def show_msg(cls, name, cfg):
97+
msg = 'Use the following config to build {}\n'.format(name)
98+
msg += str(yaml.dump({name: cfg}, Dumper=utils.NoAliasDumper))
99+
logger.info(msg[0:-1])
100+
101+
class SegBuilder(Builder):
    """
    This class is responsible for building components for semantic segmentation.

    Each component (model, optimizer, lr_scheduler, losses, datasets,
    transforms) is built lazily from `self.config` and cached on first access.
    """

    def __init__(self, config, comp_list=None):
        # Default to the full set of segmentation component registries.
        if comp_list is None:
            comp_list = [
                manager.MODELS, manager.BACKBONES, manager.DATASETS,
                manager.TRANSFORMS, manager.LOSSES, manager.OPTIMIZERS
            ]
        super().__init__(config, comp_list)

    @cached_property
    def model(self) -> paddle.nn.Layer:
        """Build the model, syncing num_classes/in_channels with the dataset class."""
        cfg = self.config.model_cfg
        assert cfg != {}, \
            'No model specified in the configuration file.'

        if self.config.train_dataset_cfg['type'] != 'Dataset':
            # Custom dataset classes carry metadata; keep the model config
            # consistent with it before building.
            assert hasattr(self.train_dataset_class, 'NUM_CLASSES'), \
                'If train_dataset class is not `Dataset`, it must have `NUM_CLASSES` attr.'
            num_classes = getattr(self.train_dataset_class, 'NUM_CLASSES')
            if 'num_classes' not in cfg:
                logger.warning(
                    'Add the `num_classes` in train_dataset class to '
                    'model config. We suggest you manually set `num_classes` in model config.'
                )
                cfg['num_classes'] = num_classes
            else:
                assert cfg['num_classes'] == num_classes, \
                    'The num_classes is not consistent for model config ({}) ' \
                    'and train_dataset class ({}) '.format(cfg['num_classes'], num_classes)

            assert hasattr(self.train_dataset_class, 'IMG_CHANNELS'), \
                'If train_dataset class is not `Dataset`, it must have `IMG_CHANNELS` attr.'
            img_channels = getattr(self.train_dataset_class, 'IMG_CHANNELS')
            declared = utils.get_in_channels(cfg)
            if declared is None:
                cfg = utils.set_in_channels(cfg, img_channels)
                logger.warning(
                    'Add the `in_channels` in train_dataset class to '
                    'model config. We suggest you manually set `in_channels` in model config.'
                )
            else:
                assert declared == img_channels, \
                    'The in_channels in model config ({}) and the img_channels in train_dataset ' \
                    'class ({}) is not consistent'.format(declared, img_channels)

        self.show_msg('model', cfg)
        return self.build_component(cfg)

    @cached_property
    def optimizer(self) -> paddle.optimizer.Optimizer:
        """Build the optimizer, normalizing legacy type names for compatibility."""
        cfg = self.config.optimizer_cfg
        assert cfg != {}, \
            'No optimizer specified in the configuration file.'
        # For compatibility: old configs used lowercase type names.
        legacy_names = {'adam': 'Adam', 'sgd': 'SGD'}
        if cfg['type'] in legacy_names:
            cfg['type'] = legacy_names[cfg['type']]
        if cfg['type'] == 'SGD' and 'momentum' in cfg:
            cfg['type'] = 'Momentum'
            logger.info('If the type is SGD and momentum in optimizer config, '
                        'the type is changed to Momentum.')
        self.show_msg('optimizer', cfg)
        # The built component is a factory that is then called with the model
        # and lr_scheduler to produce the actual optimizer.
        factory = self.build_component(cfg)
        return factory(self.model, self.lr_scheduler)

    @cached_property
    def lr_scheduler(self) -> paddle.optimizer.lr.LRScheduler:
        """Build the lr scheduler, optionally wrapping it in linear warmup."""
        cfg = self.config.lr_scheduler_cfg
        assert cfg != {}, \
            'No lr_scheduler specified in the configuration file.'

        use_warmup = 'warmup_iters' in cfg
        if use_warmup:
            warmup_iters = cfg.pop('warmup_iters')
            assert 'warmup_start_lr' in cfg, \
                "When use warmup, please set warmup_start_lr and warmup_iters in lr_scheduler"
            warmup_start_lr = cfg.pop('warmup_start_lr')
            end_lr = cfg['learning_rate']

        lr_type = cfg.pop('type')
        if lr_type == 'PolynomialDecay':
            # Decay over the iterations remaining after warmup (at least one).
            iters = self.config.iters - warmup_iters if use_warmup else self.config.iters
            cfg.setdefault('decay_steps', max(iters, 1))

        try:
            lr_sche = getattr(paddle.optimizer.lr, lr_type)(**cfg)
        except Exception as e:
            raise RuntimeError(
                "Create {} has failed. Please check lr_scheduler in config. "
                "The error message: {}".format(lr_type, e))

        if use_warmup:
            lr_sche = paddle.optimizer.lr.LinearWarmup(
                learning_rate=lr_sche,
                warmup_steps=warmup_iters,
                start_lr=warmup_start_lr,
                end_lr=end_lr)

        return lr_sche

    @cached_property
    def loss(self) -> dict:
        """Build the training loss dict from the `loss` config section."""
        cfg = self.config.loss_cfg
        assert cfg != {}, \
            'No loss specified in the configuration file.'
        return self._build_loss('loss', cfg)

    @cached_property
    def distill_loss(self) -> dict:
        """Build the distillation loss dict from the `distill_loss` config section."""
        cfg = self.config.distill_loss_cfg
        assert cfg != {}, \
            'No distill_loss specified in the configuration file.'
        return self._build_loss('distill_loss', cfg)

    def _build_loss(self, loss_name, loss_cfg: dict):
        """Build `{'coef': ..., 'types': ...}`, checking ignore_index consistency."""

        def _sync_ignore_index(cfg, ignore_index):
            # Inject or validate `ignore_index` on a single loss config entry.
            if 'ignore_index' in cfg:
                assert cfg['ignore_index'] == ignore_index, \
                    'the ignore_index in loss and train_dataset must be the same. Currently, loss ignore_index = {}, '\
                    'train_dataset ignore_index = {}'.format(cfg['ignore_index'], ignore_index)
            else:
                cfg['ignore_index'] = ignore_index
                logger.warning('Add the `ignore_index` in train_dataset ' \
                    'class to {} config. We suggest you manually set ' \
                    '`ignore_index` in {} config.'.format(loss_name, loss_name)
                )

        # Check and synchronize the ignore_index between the loss config and
        # the train_dataset class (MixedLoss nests its sub-losses one level).
        if self.config.train_dataset_cfg['type'] != 'Dataset':
            assert hasattr(self.train_dataset_class, 'IGNORE_INDEX'), \
                'If train_dataset class is not `Dataset`, it must have `IGNORE_INDEX` attr.'
            ignore_index = getattr(self.train_dataset_class, 'IGNORE_INDEX')
            for item in loss_cfg['types']:
                if item['type'] == 'MixedLoss':
                    for sub_item in item['losses']:
                        _sync_ignore_index(sub_item, ignore_index)
                else:
                    _sync_ignore_index(item, ignore_index)

        self.show_msg(loss_name, loss_cfg)
        return {
            'coef': loss_cfg['coef'],
            "types": [self.build_component(item) for item in loss_cfg['types']],
        }

    @cached_property
    def train_dataset(self) -> paddle.io.Dataset:
        """Build the training dataset; it must contain at least one sample."""
        cfg = self.config.train_dataset_cfg
        assert cfg != {}, \
            'No train_dataset specified in the configuration file.'
        self.show_msg('train_dataset', cfg)
        dataset = self.build_component(cfg)
        assert len(dataset) != 0, \
            'The number of samples in train_dataset is 0. Please check whether the dataset is valid.'
        return dataset

    @cached_property
    def val_dataset(self) -> paddle.io.Dataset:
        """Build the validation dataset; it must contain at least one sample."""
        cfg = self.config.val_dataset_cfg
        assert cfg != {}, \
            'No val_dataset specified in the configuration file.'
        self.show_msg('val_dataset', cfg)
        dataset = self.build_component(cfg)
        assert len(dataset) != 0, \
            'The number of samples in val_dataset is 0. Please check whether the dataset is valid.'
        return dataset

    @cached_property
    def train_dataset_class(self) -> Any:
        """Look up (without instantiating) the train_dataset component class."""
        cfg = self.config.train_dataset_cfg
        assert cfg != {}, \
            'No train_dataset specified in the configuration file.'
        return self.load_component_class(cfg.get('type'))

    @cached_property
    def val_dataset_class(self) -> Any:
        """Look up (without instantiating) the val_dataset component class."""
        cfg = self.config.val_dataset_cfg
        assert cfg != {}, \
            'No val_dataset specified in the configuration file.'
        return self.load_component_class(cfg.get('type'))

    @cached_property
    def val_transforms(self) -> list:
        """Build the list of validation transforms (may be empty)."""
        cfg = self.config.val_dataset_cfg
        assert cfg != {}, \
            'No val_dataset specified in the configuration file.'
        return [self.build_component(item) for item in cfg.get('transforms', [])]

0 commit comments

Comments
 (0)