import os
import sys
import math
import pprint

import torch

from torchdrug import core, models, tasks, datasets, utils
from torchdrug.utils import comm

# Make the repository root importable so that `util` can be found.
sys.path.append(os.path.dirname(os.path.dirname(__file__)))
import util


def save(solver, path):
    """Save the state dict of the backbone encoder wrapped inside the task model."""
    # Unsupervised pretraining tasks wrap the encoder one level deeper
    # than the other tasks, so unwrap accordingly.
    if isinstance(solver.model, tasks.Unsupervised):
        model = solver.model.model.model
    else:
        model = solver.model.model

    if comm.get_rank() == 0:
        logger.warning("Save checkpoint to %s" % path)
    path = os.path.expanduser(path)
    # Only rank 0 writes the checkpoint; all ranks synchronize afterwards.
    if comm.get_rank() == 0:
        torch.save(model.state_dict(), path)
    comm.synchronize()


if __name__ == "__main__":
    args, vars = util.parse_args()
    cfg = util.load_config(args.config, context=vars)
    working_dir = util.create_working_directory(cfg)

    # Use a different random seed on each distributed rank.
    torch.manual_seed(args.seed + comm.get_rank())

    logger = util.get_root_logger()
    if comm.get_rank() == 0:
        logger.warning("Config file: %s" % args.config)
        logger.warning(pprint.pformat(cfg))

    # Optionally iterate over a range of per-species dataset splits.
    # When [species_start, species_end) is given, the dataset is reloaded
    # split by split inside the training loop below.
    species_start = cfg.dataset.get("species_start", 0)
    species_end = cfg.dataset.get("species_end", 0)
    assert species_end >= species_start
    if species_end > species_start:
        cfg.dataset.species_id = species_start
        cfg.dataset.split_id = 0
        cfg.dataset.pop("species_start")
        cfg.dataset.pop("species_end")
    dataset = core.Configurable.load_config_dict(cfg.dataset)
    solver = util.build_pretrain_solver(cfg, dataset)

    # Train in chunks of `save_interval` epochs and checkpoint after each chunk.
    step = cfg.get("save_interval", 1)
    for i in range(0, cfg.train.num_epoch, step):
        kwargs = cfg.train.copy()
        kwargs["num_epoch"] = min(step, cfg.train.num_epoch - i)

        if species_end == species_start:
            solver.train(**kwargs)
        else:
            # Cycle over every split of every requested species, swapping the
            # training set in place before each round of training.
            for species_id in range(species_start, species_end):
                for split_id in range(dataset.species_nsplit[species_id]):
                    cfg.dataset.species_id = species_id
                    cfg.dataset.split_id = split_id
                    dataset = core.Configurable.load_config_dict(cfg.dataset)
                    logger.warning('Epoch: {}\tSpecies id: {}\tSplit id: {}\tSplit length: {}'.format(
                        i, species_id, split_id, len(dataset)))
                    solver.train_set = dataset
                    solver.train(**kwargs)

        save(solver, "model_epoch_%d.pth" % (i + kwargs["num_epoch"]))
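
Usage note: `save()` above writes only the backbone encoder's `state_dict`, so a pretrained checkpoint can later be restored into a freshly built encoder. A minimal sketch, assuming the same YAML config is at hand and that the encoder is configured under `cfg.task.model` (that key path, the config filename, and the checkpoint filename are illustrative assumptions, not taken from this script):

import torch
from torchdrug import core
import util

# Rebuild the encoder from the pretraining config and load the saved weights.
cfg = util.load_config("pretrain.yaml", context={})           # hypothetical config path
encoder = core.Configurable.load_config_dict(cfg.task.model)  # assumed config layout
state = torch.load("model_epoch_50.pth", map_location="cpu")  # checkpoint written by save()
encoder.load_state_dict(state)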