Skip to content

Commit 1e5ab6d

Browse files
committed
fix a bug in last commit
1 parent 34d4637 commit 1e5ab6d

File tree

1 file changed

+7
-3
lines changed

1 file changed

+7
-3
lines changed

util.py

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -100,6 +100,10 @@ def build_downstream_solver(cfg, dataset):
100100
logger.warning(dataset)
101101
logger.warning("#train: %d, #valid: %d, #test: %d" % (len(train_set), len(valid_set), len(test_set)))
102102

103+
if cfg.task['class'] == 'MultipleBinaryClassification':
104+
cfg.task.task = [_ for _ in range(len(dataset.tasks))]
105+
else:
106+
cfg.task.task = dataset.tasks
103107
task = core.Configurable.load_config_dict(cfg.task)
104108
if not "lr_ratio" in cfg:
105109
cfg.optimizer.params = task.parameters()
@@ -112,12 +116,12 @@ def build_downstream_solver(cfg, dataset):
112116
optimizer = core.Configurable.load_config_dict(cfg.optimizer)
113117
solver = core.Engine(task, train_set, valid_set, test_set, optimizer, **cfg.engine)
114118

115-
if "checkpoint" in cfg:
119+
if cfg.get("checkpoint") is not None:
116120
solver.load(cfg.checkpoint)
117121

118-
if "model_checkpoint" in cfg:
122+
if cfg.get("model_checkpoint") is not None:
119123
cfg.model_checkpoint = os.path.expanduser(cfg.model_checkpoint)
120124
model_dict = torch.load(cfg.model_checkpoint, map_location=torch.device('cpu'))
121125
task.model.load_state_dict(model_dict)
122-
126+
123127
return solver

0 commit comments

Comments
 (0)