Skip to content

Commit 609106f

Browse files
author
sfwang
committed
First commit.
0 parents  commit 609106f

File tree

17 files changed

+1693
-0
lines changed

17 files changed

+1693
-0
lines changed

README.md

Lines changed: 40 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,40 @@
1+
# Source Detection for Neuron-tree Reconstruction
2+
3+
## Introduction
4+
Pytorch code for training and testing source detection models for neuron-tree reconstruction.
5+
6+
7+
## Overview
8+
- `source_detection/`: includes training and validation scripts.
9+
- `lib/`: contains data preparation, model definition, and some utility functions.
10+
- `experiments/`: contains `*.yaml` configuration files to run experiments.
11+
12+
13+
## Requirements
14+
The code is developed using python 3.7.1 on Ubuntu 16.04. NVIDIA GPUs ared needed to train and test.
15+
See [`requirements.txt`](requirements.txt) for other dependencies.
16+
17+
## Quick start
18+
### Installation
19+
1. Install pytorch >= v1.0.0 following [official instructions](https://pytorch.org/).
20+
2. Clone this repo, and we will call the directory that you cloned as `${ROOT}`
21+
3. Install dependencies.
22+
```
23+
pip install -r requirements.txt
24+
```
25+
4. Download pretrained ResNet-18 [model](https://download.pytorch.org/models/resnet18-5c106cde.pth)
26+
and put it under `${ROOT}/models/pytorch/imagenet/`
27+
28+
### Training with simulated data
29+
To train with simulated data, run:
30+
```
31+
CUDA_VISIBLE_DEVICES=$GPU_ID python source_detection/train.py --cfg experiments/simulated/128x128_d256x3_adam_lr1e-3.yaml
32+
```
33+
Model checkpoints and logs will be saved into outpu folder while tensorboard logs will be saved into log folder.
34+
35+
### Testing with simulated data
36+
To test with simulated data after training, run:
37+
```
38+
CUDA_VISIBLE_DEVICES=$GPU_ID python source_detection/validate.py --cfg experiments/simulated/128x128_d256x3_adam_lr1e-3.yaml
39+
```
40+
Tensorboard logs will be saved into log folder.
Lines changed: 73 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,73 @@
1+
GPUS: '0'
2+
DATA_DIR: ''
3+
OUTPUT_DIR: 'output'
4+
LOG_DIR: 'log'
5+
WORKERS: 4
6+
PRINT_FREQ: 100
7+
CUDNN:
8+
BENCHMARK: False
9+
DETERMINISTIC: True
10+
ENABLED: True
11+
DATASET:
12+
DATASET: simulation
13+
ROOT: None
14+
TEST_SET: valid
15+
TRAIN_SET: train
16+
# FLIP: true
17+
# ROT_FACTOR: 30
18+
# SCALE_FACTOR: 0.25
19+
MODEL:
20+
NAME: neuron_resnet
21+
PRETRAINED: 'models/pytorch/imagenet/resnet18-5c106cde.pth'
22+
IMAGE_SIZE:
23+
- 128
24+
- 128
25+
EXTRA:
26+
TARGET_TYPE: gaussian
27+
SIGMA: 2
28+
HEATMAP_SIZE:
29+
- 64
30+
- 64
31+
FINAL_CONV_KERNEL: 1
32+
DECONV_WITH_BIAS: false
33+
NUM_DECONV_LAYERS: 4
34+
NUM_DECONV_FILTERS:
35+
- 256
36+
- 256
37+
- 256
38+
- 256
39+
NUM_DECONV_KERNELS:
40+
- 4
41+
- 4
42+
- 4
43+
- 4
44+
NUM_LAYERS: 18
45+
LOSS:
46+
USE_TARGET_WEIGHT: False
47+
TRAIN:
48+
BATCH_SIZE: 32
49+
SHUFFLE: true
50+
BEGIN_EPOCH: 0
51+
END_EPOCH: 140
52+
RESUME: false
53+
OPTIMIZER: adam
54+
LR: 0.001
55+
LR_FACTOR: 0.1
56+
LR_STEP:
57+
- 90
58+
- 120
59+
WD: 0.0001
60+
GAMMA1: 0.99
61+
GAMMA2: 0.0
62+
MOMENTUM: 0.9
63+
NESTEROV: false
64+
TEST:
65+
BATCH_SIZE: 32
66+
FLIP_TEST: false
67+
MODEL_FILE: 'output/simulation/neuron_resnet_18/128x128_d256x3_adam_lr1e-3/model_best.pth.tar'
68+
DEBUG:
69+
DEBUG: false
70+
SAVE_BATCH_IMAGES_GT: true
71+
SAVE_BATCH_IMAGES_PRED: true
72+
SAVE_HEATMAPS_GT: true
73+
SAVE_HEATMAPS_PRED: true

lib/core/config.py

Lines changed: 240 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,240 @@
1+
# ------------------------------------------------------------------------------
2+
# Copyright (c) Microsoft
3+
# Licensed under the MIT License.
4+
# Written by Bin Xiao ([email protected])
5+
# ------------------------------------------------------------------------------
6+
7+
from __future__ import absolute_import
8+
from __future__ import division
9+
from __future__ import print_function
10+
11+
import os
12+
import yaml
13+
14+
import numpy as np
15+
from easydict import EasyDict as edict
16+
17+
18+
config = edict()
19+
20+
config.OUTPUT_DIR = ''
21+
config.LOG_DIR = ''
22+
config.DATA_DIR = ''
23+
config.GPUS = '0'
24+
config.WORKERS = 4
25+
config.PRINT_FREQ = 20
26+
27+
# Cudnn related params
28+
config.CUDNN = edict()
29+
config.CUDNN.BENCHMARK = True
30+
config.CUDNN.DETERMINISTIC = False
31+
config.CUDNN.ENABLED = True
32+
33+
# neuron_resnet related params
34+
NEURON_RESNET = edict()
35+
NEURON_RESNET.NUM_LAYERS = 50
36+
NEURON_RESNET.DECONV_WITH_BIAS = False
37+
NEURON_RESNET.NUM_DECONV_LAYERS = 3
38+
NEURON_RESNET.NUM_DECONV_FILTERS = [256, 256, 256]
39+
NEURON_RESNET.NUM_DECONV_KERNELS = [4, 4, 4]
40+
NEURON_RESNET.FINAL_CONV_KERNEL = 1
41+
NEURON_RESNET.TARGET_TYPE = 'gaussian'
42+
NEURON_RESNET.HEATMAP_SIZE = [64, 64] # width * height, ex: 24 * 32
43+
NEURON_RESNET.SIGMA = 2
44+
45+
MODEL_EXTRAS = {
46+
'neuron_resnet': NEURON_RESNET,
47+
}
48+
49+
# common params for NETWORK
50+
config.MODEL = edict()
51+
config.MODEL.NAME = 'neuron_resnet'
52+
config.MODEL.INIT_WEIGHTS = True
53+
config.MODEL.INIT_DECONVS = False
54+
config.MODEL.INTEGRAL_REG = False
55+
config.MODEL.INTEGRAL_LOSS_TYPE = 'L1'
56+
config.MODEL.PRETRAINED = ''
57+
config.MODEL.IMAGE_SIZE = [128, 128] # width * height, ex: 192 * 256
58+
config.MODEL.INPUT_SIZE = [128, 128]
59+
config.MODEL.OUTPUT_SIZE = [64, 64]
60+
config.MODEL.DEPTH_DIM = 1
61+
config.MODEL.EXTRA = MODEL_EXTRAS[config.MODEL.NAME]
62+
63+
config.MODEL.STYLE = 'pytorch'
64+
65+
config.MODEL.DEPTH = 0.05
66+
config.MODEL.VAR_NOISE = 1.0
67+
68+
config.LOSS = edict()
69+
config.LOSS.USE_TARGET_WEIGHT = True
70+
71+
# DATASET related params
72+
config.DATASET = edict()
73+
config.DATASET.ROOT = ''
74+
config.DATASET.DATASET = 'simulation'
75+
config.DATASET.TRAIN_SET = 'train'
76+
config.DATASET.TEST_SET = 'valid'
77+
config.DATASET.DATA_FORMAT = 'jpg'
78+
79+
# training data augmentation
80+
config.DATASET.FLIP = True
81+
config.DATASET.SCALE_FACTOR = 0.25
82+
config.DATASET.ROT_FACTOR = 30
83+
config.DATASET.PAD_BORDER = True
84+
85+
# train
86+
config.TRAIN = edict()
87+
88+
config.TRAIN.LR_FACTOR = 0.1
89+
config.TRAIN.LR_STEP = [90, 110]
90+
config.TRAIN.LR = 0.001
91+
92+
config.TRAIN.OPTIMIZER = 'adam'
93+
config.TRAIN.MOMENTUM = 0.9
94+
config.TRAIN.WD = 0.0001
95+
config.TRAIN.NESTEROV = False
96+
config.TRAIN.GAMMA1 = 0.99
97+
config.TRAIN.GAMMA2 = 0.0
98+
99+
config.TRAIN.BEGIN_EPOCH = 0
100+
config.TRAIN.END_EPOCH = 140
101+
102+
config.TRAIN.RESUME = False
103+
config.TRAIN.CHECKPOINT = ''
104+
105+
config.TRAIN.BATCH_SIZE = 32
106+
config.TRAIN.SHUFFLE = True
107+
108+
config.TRAIN.NUM_SAMPLES = 1e5
109+
110+
# testing
111+
config.TEST = edict()
112+
113+
# size of images for each device
114+
config.TEST.BATCH_SIZE = 32
115+
# Test Model Epoch
116+
config.TEST.FLIP_TEST = False
117+
config.TEST.POST_PROCESS = True
118+
config.TEST.SHIFT_HEATMAP = True
119+
120+
config.TEST.USE_GT_BBOX = False
121+
# nms
122+
config.TEST.OKS_THRE = 0.5
123+
config.TEST.IN_VIS_THRE = 0.0
124+
config.TEST.COCO_BBOX_FILE = ''
125+
config.TEST.BBOX_THRE = 1.0
126+
config.TEST.MODEL_FILE = ''
127+
config.TEST.IMAGE_THRE = 0.0
128+
config.TEST.NMS_THRE = 1.0
129+
130+
config.TEST.NUM_SAMPLES = 5e3
131+
132+
# debug
133+
config.DEBUG = edict()
134+
config.DEBUG.DEBUG = False
135+
config.DEBUG.SAVE_BATCH_IMAGES_GT = False
136+
config.DEBUG.SAVE_BATCH_IMAGES_PRED = False
137+
config.DEBUG.SAVE_HEATMAPS_GT = False
138+
config.DEBUG.SAVE_HEATMAPS_PRED = False
139+
140+
141+
def _update_dict(k, v):
142+
if k == 'DATASET':
143+
if 'MEAN' in v and v['MEAN']:
144+
v['MEAN'] = np.array([eval(x) if isinstance(x, str) else x
145+
for x in v['MEAN']])
146+
if 'STD' in v and v['STD']:
147+
v['STD'] = np.array([eval(x) if isinstance(x, str) else x
148+
for x in v['STD']])
149+
if k == 'MODEL':
150+
if 'EXTRA' in v and 'HEATMAP_SIZE' in v['EXTRA']:
151+
if isinstance(v['EXTRA']['HEATMAP_SIZE'], int):
152+
v['EXTRA']['HEATMAP_SIZE'] = np.array(
153+
[v['EXTRA']['HEATMAP_SIZE'], v['EXTRA']['HEATMAP_SIZE']])
154+
else:
155+
v['EXTRA']['HEATMAP_SIZE'] = np.array(
156+
v['EXTRA']['HEATMAP_SIZE'])
157+
if 'IMAGE_SIZE' in v:
158+
if isinstance(v['IMAGE_SIZE'], int):
159+
v['IMAGE_SIZE'] = np.array([v['IMAGE_SIZE'], v['IMAGE_SIZE']])
160+
else:
161+
v['IMAGE_SIZE'] = np.array(v['IMAGE_SIZE'])
162+
for vk, vv in v.items():
163+
if vk in config[k]:
164+
config[k][vk] = vv
165+
else:
166+
raise ValueError("{}.{} not exist in config.py".format(k, vk))
167+
168+
169+
def update_config(config_file):
170+
exp_config = None
171+
with open(config_file) as f:
172+
exp_config = edict(yaml.load(f))
173+
for k, v in exp_config.items():
174+
if k in config:
175+
if isinstance(v, dict):
176+
_update_dict(k, v)
177+
else:
178+
if k == 'SCALES':
179+
config[k][0] = (tuple(v))
180+
else:
181+
config[k] = v
182+
else:
183+
raise ValueError("{} not exist in config.py".format(k))
184+
185+
186+
def gen_config(config_file):
187+
cfg = dict(config)
188+
for k, v in cfg.items():
189+
if isinstance(v, edict):
190+
cfg[k] = dict(v)
191+
192+
with open(config_file, 'w') as f:
193+
yaml.dump(dict(cfg), f, default_flow_style=False)
194+
195+
196+
def update_dir(model_dir, log_dir, data_dir):
197+
if model_dir:
198+
config.OUTPUT_DIR = model_dir
199+
200+
if log_dir:
201+
config.LOG_DIR = log_dir
202+
203+
if data_dir:
204+
config.DATA_DIR = data_dir
205+
206+
config.DATASET.ROOT = os.path.join(
207+
config.DATA_DIR, config.DATASET.ROOT)
208+
209+
config.TEST.COCO_BBOX_FILE = os.path.join(
210+
config.DATA_DIR, config.TEST.COCO_BBOX_FILE)
211+
212+
config.MODEL.PRETRAINED = os.path.join(
213+
config.DATA_DIR, config.MODEL.PRETRAINED)
214+
215+
216+
def get_model_name(cfg):
217+
name = cfg.MODEL.NAME
218+
full_name = cfg.MODEL.NAME
219+
extra = cfg.MODEL.EXTRA
220+
if name in ['neuron_resnet']:
221+
name = '{model}_{num_layers}'.format(
222+
model=name,
223+
num_layers=extra.NUM_LAYERS)
224+
deconv_suffix = ''.join(
225+
'd{}'.format(num_filters)
226+
for num_filters in extra.NUM_DECONV_FILTERS)
227+
full_name = '{height}x{width}_{name}_{deconv_suffix}'.format(
228+
height=cfg.MODEL.IMAGE_SIZE[1],
229+
width=cfg.MODEL.IMAGE_SIZE[0],
230+
name=name,
231+
deconv_suffix=deconv_suffix)
232+
elif name not in ['keypoint_mlp']:
233+
raise ValueError('Unkown model: {}'.format(cfg.MODEL))
234+
235+
return name, full_name
236+
237+
238+
if __name__ == '__main__':
239+
import sys
240+
gen_config(sys.argv[1])

0 commit comments

Comments
 (0)