Skip to content

Commit 429943b

Browse files
Add files via upload
1 parent f7958b4 commit 429943b

File tree

1 file changed

+260
-0
lines changed

1 file changed

+260
-0
lines changed

Generating_figure.ipynb

+260
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,260 @@
1+
{
2+
"cells": [
3+
{
4+
"cell_type": "code",
5+
"execution_count": 36,
6+
"metadata": {},
7+
"outputs": [],
8+
"source": [
9+
"from __future__ import absolute_import\n",
10+
"from __future__ import division\n",
11+
"from __future__ import print_function\n",
12+
"\n",
13+
"import argparse\n",
14+
"import os\n",
15+
"import pprint\n",
16+
"import shutil\n",
17+
"\n",
18+
"import torch\n",
19+
"# import torch.nn.parallel\n",
20+
"import torch.backends.cudnn as cudnn\n",
21+
"import torch.optim\n",
22+
"import torchvision.transforms as transforms\n",
23+
"from tensorboardX import SummaryWriter\n",
24+
"\n",
25+
"# import _init_paths\n",
26+
"from lib.core.config import config\n",
27+
"from lib.core.config import update_config\n",
28+
"from lib.core.config import update_dir\n",
29+
"from lib.core.config import get_model_name\n",
30+
"from lib.core.loss import JointsMSELoss\n",
31+
"from lib.core.function import train\n",
32+
"from lib.core.function import validate\n",
33+
"from lib.utils.utils import get_optimizer\n",
34+
"from lib.utils.utils import save_checkpoint\n",
35+
"from lib.utils.utils import create_logger\n",
36+
"\n",
37+
"import lib.dataset as dataset\n",
38+
"import lib.models as models\n",
39+
"\n",
40+
"\n",
41+
"def parse_args():\n",
42+
" parser = argparse.ArgumentParser(description='Train keypoints network')\n",
43+
" # general\n",
44+
" parser.add_argument('--cfg',\n",
45+
" help='experiment configure file name',\n",
46+
" required=True,\n",
47+
" type=str)\n",
48+
"\n",
49+
" args, rest = parser.parse_known_args()\n",
50+
" # update config\n",
51+
" update_config(args.cfg)\n",
52+
"\n",
53+
" # training\n",
54+
" parser.add_argument('--frequent',\n",
55+
" help='frequency of logging',\n",
56+
" default=config.PRINT_FREQ,\n",
57+
" type=int)\n",
58+
" parser.add_argument('--gpus',\n",
59+
" help='gpus',\n",
60+
" type=str)\n",
61+
" parser.add_argument('--workers',\n",
62+
" help='num of dataloader workers',\n",
63+
" type=int)\n",
64+
"\n",
65+
" args = parser.parse_args()\n",
66+
"\n",
67+
" return args\n",
68+
"\n",
69+
"\n",
70+
"def reset_config(config, args):\n",
71+
" if args.gpus:\n",
72+
" config.GPUS = args.gpus\n",
73+
" if args.workers:\n",
74+
" config.WORKERS = args.workers\n",
75+
"\n",
76+
"\n",
77+
"def main():\n",
78+
" #args = parse_args()\n",
79+
" with open('experiments/simulated/128x128_d256x3_adam_lr1e-3.yaml') as file:\n",
80+
" args = yaml.load(file, Loader=yaml.FullLoader)\n",
81+
" config = dotdict(args)\n",
82+
" \n",
83+
" #reset_config(config, args)\n",
84+
"\n",
85+
" logger, final_output_dir, tb_log_dir = create_logger(\n",
86+
" config, args.cfg, 'train')\n",
87+
"\n",
88+
" logger.info(pprint.pformat(args))\n",
89+
" logger.info(pprint.pformat(config))\n",
90+
"\n",
91+
" # cudnn related setting\n",
92+
" cudnn.benchmark = config.CUDNN.BENCHMARK\n",
93+
" torch.backends.cudnn.deterministic = config.CUDNN.DETERMINISTIC\n",
94+
" torch.backends.cudnn.enabled = config.CUDNN.ENABLED\n",
95+
"\n",
96+
" model = eval('models.'+config.MODEL.NAME+'.get_neuron_net')(\n",
97+
" config, is_train=True\n",
98+
" )\n",
99+
"\n",
100+
" # copy model file\n",
101+
" this_dir = os.path.dirname(__file__)\n",
102+
" shutil.copy2(\n",
103+
" os.path.join(this_dir, '../lib/models', config.MODEL.NAME + '.py'),\n",
104+
" final_output_dir)\n",
105+
"\n",
106+
" writer_dict = {\n",
107+
" 'writer': SummaryWriter(log_dir=tb_log_dir),\n",
108+
" 'train_global_steps': 0,\n",
109+
" 'valid_global_steps': 0,\n",
110+
" 'vis_global_steps': 0,\n",
111+
" }\n",
112+
"\n",
113+
" dump_input = torch.rand((config.TRAIN.BATCH_SIZE,\n",
114+
" 3,\n",
115+
" config.MODEL.IMAGE_SIZE[1],\n",
116+
" config.MODEL.IMAGE_SIZE[0]))\n",
117+
" writer_dict['writer'].add_graph(model, (dump_input, ), verbose=False)\n",
118+
"\n",
119+
" gpus = [int(i) for i in config.GPUS.split(',')]\n",
120+
" model = torch.nn.DataParallel(model, device_ids=gpus).cuda()\n",
121+
"\n",
122+
" # define loss function (criterion) and optimizer\n",
123+
" criterion = JointsMSELoss(\n",
124+
" use_target_weight=config.LOSS.USE_TARGET_WEIGHT\n",
125+
" ).cuda()\n",
126+
"\n",
127+
" optimizer = get_optimizer(config, model)\n",
128+
"\n",
129+
" lr_scheduler = torch.optim.lr_scheduler.MultiStepLR(\n",
130+
" optimizer, config.TRAIN.LR_STEP, config.TRAIN.LR_FACTOR\n",
131+
" )\n",
132+
"\n",
133+
" # Data loading code\n",
134+
" normalize = transforms.Normalize(mean=[0.1440, 0.1440, 0.1440],\n",
135+
" std=[19.0070, 19.0070, 19.0070])\n",
136+
" train_dataset = eval('dataset.'+config.DATASET.DATASET)(\n",
137+
" config,\n",
138+
" config.DATASET.ROOT,\n",
139+
" config.DATASET.TRAIN_SET,\n",
140+
" True,\n",
141+
" transforms.Compose([\n",
142+
" transforms.ToTensor(),\n",
143+
" normalize,\n",
144+
" ])\n",
145+
" )\n",
146+
" valid_dataset = eval('dataset.'+config.DATASET.DATASET)(\n",
147+
" config,\n",
148+
" config.DATASET.ROOT,\n",
149+
" config.DATASET.TEST_SET,\n",
150+
" False,\n",
151+
" transforms.Compose([\n",
152+
" transforms.ToTensor(),\n",
153+
" normalize,\n",
154+
" ])\n",
155+
" )\n",
156+
"\n",
157+
" train_loader = torch.utils.data.DataLoader(\n",
158+
" train_dataset,\n",
159+
" batch_size=config.TRAIN.BATCH_SIZE*len(gpus),\n",
160+
" shuffle=config.TRAIN.SHUFFLE,\n",
161+
" num_workers=config.WORKERS,\n",
162+
" pin_memory=True\n",
163+
" )\n",
164+
" valid_loader = torch.utils.data.DataLoader(\n",
165+
" valid_dataset,\n",
166+
" batch_size=config.TEST.BATCH_SIZE*len(gpus),\n",
167+
" shuffle=False,\n",
168+
" num_workers=config.WORKERS,\n",
169+
" pin_memory=True\n",
170+
" )\n",
171+
"\n",
172+
" best_perf = 0.0\n",
173+
" best_model = False\n",
174+
" for epoch in range(config.TRAIN.BEGIN_EPOCH, config.TRAIN.END_EPOCH):\n",
175+
" lr_scheduler.step()\n",
176+
"\n",
177+
" # train for one epoch\n",
178+
" train(config, train_loader, model, criterion, optimizer, epoch,\n",
179+
" final_output_dir, tb_log_dir, writer_dict)\n",
180+
"\n",
181+
"\n",
182+
" # evaluate on validation set\n",
183+
" perf_indicator = validate(config, valid_loader, valid_dataset, model,\n",
184+
" criterion, final_output_dir, tb_log_dir,\n",
185+
" writer_dict)\n",
186+
"\n",
187+
" if perf_indicator > best_perf:\n",
188+
" best_perf = perf_indicator\n",
189+
" best_model = True\n",
190+
" else:\n",
191+
" best_model = False\n",
192+
"\n",
193+
" logger.info('=> saving checkpoint to {}'.format(final_output_dir))\n",
194+
" save_checkpoint({\n",
195+
" 'epoch': epoch + 1,\n",
196+
" 'model': get_model_name(config),\n",
197+
" 'state_dict': model.state_dict(),\n",
198+
" 'perf': perf_indicator,\n",
199+
" 'optimizer': optimizer.state_dict(),\n",
200+
" }, best_model, final_output_dir)\n",
201+
"\n",
202+
" final_model_state_file = os.path.join(final_output_dir,\n",
203+
" 'final_state.pth.tar')\n",
204+
" logger.info('saving final model state to {}'.format(\n",
205+
" final_model_state_file))\n",
206+
" torch.save(model.module.state_dict(), final_model_state_file)\n",
207+
" writer_dict['writer'].close()\n"
208+
]
209+
},
210+
{
211+
"cell_type": "code",
212+
"execution_count": 33,
213+
"metadata": {},
214+
"outputs": [],
215+
"source": [
216+
"import yaml\n",
217+
" \n",
218+
"with open('experiments/simulated/128x128_d256x3_adam_lr1e-3.yaml') as file:\n",
219+
" args = yaml.load(file, Loader=yaml.FullLoader)\n",
220+
"\n",
221+
"class dotdict(dict):\n",
222+
" \"\"\"dot.notation access to dictionary attributes\"\"\"\n",
223+
" __getattr__ = dict.get\n",
224+
" __setattr__ = dict.__setitem__\n",
225+
" __delattr__ = dict.__delitem__\n",
226+
"\n",
227+
"args = dotdict(args)\n",
228+
"args"
229+
]
230+
},
231+
{
232+
"cell_type": "code",
233+
"execution_count": null,
234+
"metadata": {},
235+
"outputs": [],
236+
"source": []
237+
}
238+
],
239+
"metadata": {
240+
"kernelspec": {
241+
"display_name": "pyrooz",
242+
"language": "python",
243+
"name": "pyrooz"
244+
},
245+
"language_info": {
246+
"codemirror_mode": {
247+
"name": "ipython",
248+
"version": 3
249+
},
250+
"file_extension": ".py",
251+
"mimetype": "text/x-python",
252+
"name": "python",
253+
"nbconvert_exporter": "python",
254+
"pygments_lexer": "ipython3",
255+
"version": "3.6.9"
256+
}
257+
},
258+
"nbformat": 4,
259+
"nbformat_minor": 2
260+
}

0 commit comments

Comments
 (0)