|
| 1 | +{ |
| 2 | + "cells": [ |
| 3 | + { |
| 4 | + "cell_type": "code", |
| 5 | + "execution_count": 36, |
| 6 | + "metadata": {}, |
| 7 | + "outputs": [], |
| 8 | + "source": [ |
| 9 | + "from __future__ import absolute_import\n", |
| 10 | + "from __future__ import division\n", |
| 11 | + "from __future__ import print_function\n", |
| 12 | + "\n", |
| 13 | + "import argparse\n", |
| 14 | + "import os\n", |
| 15 | + "import pprint\n", |
| 16 | + "import shutil\n", |
| 17 | + "import yaml\n", |
| 18 | + "import torch\n", |
| 19 | + "# import torch.nn.parallel\n", |
| 20 | + "import torch.backends.cudnn as cudnn\n", |
| 21 | + "import torch.optim\n", |
| 22 | + "import torchvision.transforms as transforms\n", |
| 23 | + "from tensorboardX import SummaryWriter\n", |
| 24 | + "\n", |
| 25 | + "# import _init_paths\n", |
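| 26 | + "# NOTE: this notebook is assumed to run from the repository root so that the\n", |
| 27 | + "# 'lib' package and the relative 'experiments/...' config path both resolve.\n", |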
| 26 | + "from lib.core.config import config\n", |
| 27 | + "from lib.core.config import update_config\n", |
| 28 | + "from lib.core.config import update_dir\n", |
| 29 | + "from lib.core.config import get_model_name\n", |
| 30 | + "from lib.core.loss import JointsMSELoss\n", |
| 31 | + "from lib.core.function import train\n", |
| 32 | + "from lib.core.function import validate\n", |
| 33 | + "from lib.utils.utils import get_optimizer\n", |
| 34 | + "from lib.utils.utils import save_checkpoint\n", |
| 35 | + "from lib.utils.utils import create_logger\n", |
| 36 | + "\n", |
| 37 | + "import lib.dataset as dataset\n", |
| 38 | + "import lib.models as models\n", |
| 39 | + "\n", |
| 40 | + "\n", |
| 41 | + "def parse_args():\n", |
| 42 | + " parser = argparse.ArgumentParser(description='Train keypoints network')\n", |
| 43 | + " # general\n", |
| 44 | + " parser.add_argument('--cfg',\n", |
| 45 | + " help='experiment configure file name',\n", |
| 46 | + " required=True,\n", |
| 47 | + " type=str)\n", |
| 48 | + "\n", |
| 49 | + " args, rest = parser.parse_known_args()\n", |
| 50 | + " # update config\n", |
| 51 | + " update_config(args.cfg)\n", |
| 52 | + "\n", |
| 53 | + " # training\n", |
| 54 | + " parser.add_argument('--frequent',\n", |
| 55 | + " help='frequency of logging',\n", |
| 56 | + " default=config.PRINT_FREQ,\n", |
| 57 | + " type=int)\n", |
| 58 | + " parser.add_argument('--gpus',\n", |
| 59 | + " help='gpus',\n", |
| 60 | + " type=str)\n", |
| 61 | + " parser.add_argument('--workers',\n", |
| 62 | + " help='num of dataloader workers',\n", |
| 63 | + " type=int)\n", |
| 64 | + "\n", |
| 65 | + " args = parser.parse_args()\n", |
| 66 | + "\n", |
| 67 | + " return args\n", |
| 68 | + "\n", |
| 69 | + "\n", |
| 70 | + "def reset_config(config, args):\n", |
| 71 | + " if args.gpus:\n", |
| 72 | + " config.GPUS = args.gpus\n", |
| 73 | + " if args.workers:\n", |
| 74 | + " config.WORKERS = args.workers\n", |
| 75 | + "\n", |
| 76 | + "\n", |
| 77 | + "def main():\n", |
| 78 | + "    # args = parse_args()  # argparse is awkward in a notebook; load the YAML config directly\n", |
| 79 | + "    cfg_file = 'experiments/simulated/128x128_d256x3_adam_lr1e-3.yaml'\n", |
| 80 | + "    with open(cfg_file) as file:\n", |
| 81 | + "        args = yaml.load(file, Loader=yaml.FullLoader)\n", |
| 82 | + "    config = dotdict(args)  # dotdict is defined in the next cell\n", |
| 83 | + "\n", |
| 84 | + "    # reset_config(config, args)\n", |
| 85 | + "\n", |
| 86 | + "    logger, final_output_dir, tb_log_dir = create_logger(\n", |
| 87 | + "        config, cfg_file, 'train')\n", |
| 87 | + "\n", |
| 88 | + " logger.info(pprint.pformat(args))\n", |
| 89 | + " logger.info(pprint.pformat(config))\n", |
| 90 | + "\n", |
| 91 | + " # cudnn related setting\n", |
| 92 | + " cudnn.benchmark = config.CUDNN.BENCHMARK\n", |
| 93 | + " torch.backends.cudnn.deterministic = config.CUDNN.DETERMINISTIC\n", |
| 94 | + " torch.backends.cudnn.enabled = config.CUDNN.ENABLED\n", |
| 95 | + "\n", |
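| 96 | + "    # build the network named in the config, i.e. models.<MODEL.NAME>.get_neuron_net\n", |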
| 96 | + " model = eval('models.'+config.MODEL.NAME+'.get_neuron_net')(\n", |
| 97 | + " config, is_train=True\n", |
| 98 | + " )\n", |
| 99 | + "\n", |
| 100 | + "    # copy the model definition file next to the outputs for reference;\n", |
| 101 | + "    # __file__ is undefined in a notebook, so use the CWD (assumed to be the repo root)\n", |
| 102 | + "    this_dir = os.getcwd()\n", |
| 103 | + "    shutil.copy2(\n", |
| 104 | + "        os.path.join(this_dir, 'lib/models', config.MODEL.NAME + '.py'),\n", |
| 105 | + "        final_output_dir)\n", |
| 105 | + "\n", |
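| 106 | + "    # TensorBoard writer plus global-step counters shared with train()/validate()\n", |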
| 106 | + " writer_dict = {\n", |
| 107 | + " 'writer': SummaryWriter(log_dir=tb_log_dir),\n", |
| 108 | + " 'train_global_steps': 0,\n", |
| 109 | + " 'valid_global_steps': 0,\n", |
| 110 | + " 'vis_global_steps': 0,\n", |
| 111 | + " }\n", |
| 112 | + "\n", |
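| 113 | + "    # dummy NCHW batch, used only to trace the model graph for TensorBoard\n", |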
| 113 | + " dump_input = torch.rand((config.TRAIN.BATCH_SIZE,\n", |
| 114 | + " 3,\n", |
| 115 | + " config.MODEL.IMAGE_SIZE[1],\n", |
| 116 | + " config.MODEL.IMAGE_SIZE[0]))\n", |
| 117 | + " writer_dict['writer'].add_graph(model, (dump_input, ), verbose=False)\n", |
| 118 | + "\n", |
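| 119 | + "    # GPUS is expected to be a comma-separated string in the config, e.g. '0' or '0,1'\n", |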
| 119 | + " gpus = [int(i) for i in config.GPUS.split(',')]\n", |
| 120 | + " model = torch.nn.DataParallel(model, device_ids=gpus).cuda()\n", |
| 121 | + "\n", |
| 122 | + " # define loss function (criterion) and optimizer\n", |
| 123 | + " criterion = JointsMSELoss(\n", |
| 124 | + " use_target_weight=config.LOSS.USE_TARGET_WEIGHT\n", |
| 125 | + " ).cuda()\n", |
| 126 | + "\n", |
| 127 | + " optimizer = get_optimizer(config, model)\n", |
| 128 | + "\n", |
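| 129 | + "    # decay the learning rate by LR_FACTOR at each epoch listed in LR_STEP\n", |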
| 129 | + " lr_scheduler = torch.optim.lr_scheduler.MultiStepLR(\n", |
| 130 | + " optimizer, config.TRAIN.LR_STEP, config.TRAIN.LR_FACTOR\n", |
| 131 | + " )\n", |
| 132 | + "\n", |
| 133 | + " # Data loading code\n", |
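| 134 | + "    # mean/std are presumably statistics of this dataset (identical across channels)\n", |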
| 134 | + " normalize = transforms.Normalize(mean=[0.1440, 0.1440, 0.1440],\n", |
| 135 | + " std=[19.0070, 19.0070, 19.0070])\n", |
| 136 | + " train_dataset = eval('dataset.'+config.DATASET.DATASET)(\n", |
| 137 | + " config,\n", |
| 138 | + " config.DATASET.ROOT,\n", |
| 139 | + " config.DATASET.TRAIN_SET,\n", |
| 140 | + " True,\n", |
| 141 | + " transforms.Compose([\n", |
| 142 | + " transforms.ToTensor(),\n", |
| 143 | + " normalize,\n", |
| 144 | + " ])\n", |
| 145 | + " )\n", |
| 146 | + " valid_dataset = eval('dataset.'+config.DATASET.DATASET)(\n", |
| 147 | + " config,\n", |
| 148 | + " config.DATASET.ROOT,\n", |
| 149 | + " config.DATASET.TEST_SET,\n", |
| 150 | + " False,\n", |
| 151 | + " transforms.Compose([\n", |
| 152 | + " transforms.ToTensor(),\n", |
| 153 | + " normalize,\n", |
| 154 | + " ])\n", |
| 155 | + " )\n", |
| 156 | + "\n", |
| 157 | + " train_loader = torch.utils.data.DataLoader(\n", |
| 158 | + " train_dataset,\n", |
| 159 | + " batch_size=config.TRAIN.BATCH_SIZE*len(gpus),\n", |
| 160 | + " shuffle=config.TRAIN.SHUFFLE,\n", |
| 161 | + " num_workers=config.WORKERS,\n", |
| 162 | + " pin_memory=True\n", |
| 163 | + " )\n", |
| 164 | + " valid_loader = torch.utils.data.DataLoader(\n", |
| 165 | + " valid_dataset,\n", |
| 166 | + " batch_size=config.TEST.BATCH_SIZE*len(gpus),\n", |
| 167 | + " shuffle=False,\n", |
| 168 | + " num_workers=config.WORKERS,\n", |
| 169 | + " pin_memory=True\n", |
| 170 | + " )\n", |
| 171 | + "\n", |
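| 172 | + "    # track the best validation score so the best checkpoint can be flagged\n", |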
| 172 | + " best_perf = 0.0\n", |
| 173 | + " best_model = False\n", |
| 174 | + "    for epoch in range(config.TRAIN.BEGIN_EPOCH, config.TRAIN.END_EPOCH):\n", |
| 175 | + "        # train for one epoch\n", |
| 176 | + "        train(config, train_loader, model, criterion, optimizer, epoch,\n", |
| 177 | + "              final_output_dir, tb_log_dir, writer_dict)\n", |
| 178 | + "\n", |
| 179 | + "        # step the LR schedule once per epoch, after the optimizer updates\n", |
| 180 | + "        lr_scheduler.step()\n", |
| 181 | + "\n", |
| 182 | + " # evaluate on validation set\n", |
| 183 | + " perf_indicator = validate(config, valid_loader, valid_dataset, model,\n", |
| 184 | + " criterion, final_output_dir, tb_log_dir,\n", |
| 185 | + " writer_dict)\n", |
| 186 | + "\n", |
| 187 | + " if perf_indicator > best_perf:\n", |
| 188 | + " best_perf = perf_indicator\n", |
| 189 | + " best_model = True\n", |
| 190 | + " else:\n", |
| 191 | + " best_model = False\n", |
| 192 | + "\n", |
| 193 | + " logger.info('=> saving checkpoint to {}'.format(final_output_dir))\n", |
| 194 | + " save_checkpoint({\n", |
| 195 | + " 'epoch': epoch + 1,\n", |
| 196 | + " 'model': get_model_name(config),\n", |
| 197 | + " 'state_dict': model.state_dict(),\n", |
| 198 | + " 'perf': perf_indicator,\n", |
| 199 | + " 'optimizer': optimizer.state_dict(),\n", |
| 200 | + " }, best_model, final_output_dir)\n", |
| 201 | + "\n", |
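| 202 | + "    # save the final weights (unwrapped from DataParallel) after the last epoch\n", |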
| 202 | + " final_model_state_file = os.path.join(final_output_dir,\n", |
| 203 | + " 'final_state.pth.tar')\n", |
| 204 | + " logger.info('saving final model state to {}'.format(\n", |
| 205 | + " final_model_state_file))\n", |
| 206 | + " torch.save(model.module.state_dict(), final_model_state_file)\n", |
| 207 | + " writer_dict['writer'].close()\n" |
| 208 | + ] |
| 209 | + }, |
| 210 | + { |
| 211 | + "cell_type": "code", |
| 212 | + "execution_count": 33, |
| 213 | + "metadata": {}, |
| 214 | + "outputs": [], |
| 215 | + "source": [ |
| 216 | + "import yaml\n", |
| 217 | + "\n", |
| 218 | + "with open('experiments/simulated/128x128_d256x3_adam_lr1e-3.yaml') as file:\n", |
| 219 | + " args = yaml.load(file, Loader=yaml.FullLoader)\n", |
| 220 | + "\n", |
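| 221 | + "# small helper so the loaded YAML config can be read with dot notation,\n", |
| 222 | + "# e.g. args.TRAIN.BATCH_SIZE instead of args['TRAIN']['BATCH_SIZE']\n", |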
| 221 | + "class dotdict(dict):\n", |
| 222 | + "    \"\"\"dot.notation access to dictionary attributes (nested dicts are wrapped too)\"\"\"\n", |
| 223 | + "    def __getattr__(self, name):\n", |
| 224 | + "        value = self.get(name)\n", |
| 225 | + "        # wrap nested dicts so accesses like config.CUDNN.BENCHMARK also work\n", |
| 226 | + "        return dotdict(value) if isinstance(value, dict) else value\n", |
| 227 | + "    __setattr__ = dict.__setitem__\n", |
| 228 | + "    __delattr__ = dict.__delitem__\n", |
| 226 | + "\n", |
| 227 | + "args = dotdict(args)\n", |
| 228 | + "args" |
| 229 | + ] |
| 230 | + }, |
| 231 | + { |
| 232 | + "cell_type": "code", |
| 233 | + "execution_count": null, |
| 234 | + "metadata": {}, |
| 235 | + "outputs": [], |
| 236 | + "source": [] |
| 237 | + } |
| 238 | + ], |
| 239 | + "metadata": { |
| 240 | + "kernelspec": { |
| 241 | + "display_name": "pyrooz", |
| 242 | + "language": "python", |
| 243 | + "name": "pyrooz" |
| 244 | + }, |
| 245 | + "language_info": { |
| 246 | + "codemirror_mode": { |
| 247 | + "name": "ipython", |
| 248 | + "version": 3 |
| 249 | + }, |
| 250 | + "file_extension": ".py", |
| 251 | + "mimetype": "text/x-python", |
| 252 | + "name": "python", |
| 253 | + "nbconvert_exporter": "python", |
| 254 | + "pygments_lexer": "ipython3", |
| 255 | + "version": "3.6.9" |
| 256 | + } |
| 257 | + }, |
| 258 | + "nbformat": 4, |
| 259 | + "nbformat_minor": 2 |
| 260 | +} |