Skip to content

Commit 67e6d01

Browse files
authored
Merge pull request #5 from zhmeishi/train
Train
2 parents df5ec60 + 2c77fa0 commit 67e6d01

8 files changed

+385
-21
lines changed

README.md

+17-3
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11

2+
23

34
# Deep Online Fused Video Stabilization
45

@@ -18,7 +19,6 @@ pip install -r requirements.txt --ignore-installed
1819
## Data Preparation
1920
Download the sample video [here](https://drive.google.com/file/d/1nju9H8ohYZh6dGsdrQjQXFgfgkrFtkRi/view?usp=sharing).
2021
Uncompress the *video* folder under the *dvs* folder.
21-
2222
```
2323
python load_frame_sensor_data.py
2424
```
@@ -52,7 +52,21 @@ In *s_114_outdoor_running_trail_daytime.jpg*, the blue curve is the output of ou
5252
*s_114_outdoor_running_trail_daytime_stab_crop.mp4* is cropped stabilized video. Note, the cropped video is generated after running the metrics code.
5353

5454
## Training
55-
TBA
55+
Download the dataset for training and testing [here](https://storage.googleapis.com/dataset_release/all.zip).
56+
Uncompress *all.zip* and move the *dataset_release* folder under the *dvs* folder.
57+
58+
Follow the FlowNet2 Preparation section.
59+
```
60+
python warp/read_write.py --dir_path ./dataset_release # video2frames
61+
cd flownet2
62+
bash run_release.sh # generate optical flow file for dataset
63+
```
64+
65+
Run the training code.
66+
```
67+
python train.py
68+
```
69+
The model is saved in *checkpoint/stabilzation_train*.
5670

5771
## Citation
5872
If you use this code or dataset for your research, please cite our paper.
@@ -63,4 +77,4 @@ If you use this code or dataset for your research, please cite our paper.
6377
journal={arXiv preprint arXiv:2102.01279},
6478
year={2021}
6579
}
66-
```
80+
```

dvs/conf/stabilzation_train.yaml

+57
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,57 @@
1+
data:
2+
exp: 'stabilzation_train'
3+
checkpoints_dir: './checkpoint'
4+
log: './log'
5+
data_dir: './dataset_release'
6+
use_cuda: true
7+
batch_size: 16
8+
resize_ratio: 0.25
9+
number_real: 10
10+
number_virtual: 2
11+
time_train: 2000 # ms
12+
sample_freq: 40 # ms
13+
channel_size: 1
14+
num_workers: 16 # num_workers for data_loader
15+
model:
16+
load_model: null
17+
cnn:
18+
activate_function: relu # sigmoid, relu, tanh, quadratic
19+
batch_norm: true
20+
gap: false
21+
layers:
22+
rnn:
23+
layers:
24+
- - 512
25+
- true
26+
- - 512
27+
- true
28+
fc:
29+
activate_function: relu
30+
batch_norm: false # (batch_norm and drop_out) is False
31+
layers:
32+
- - 256
33+
- true
34+
- - 4 # last layer should be equal to nr_class
35+
- true
36+
drop_out: 0
37+
train:
38+
optimizer: "adam" # adam or sgd
39+
momentum: 0.9 # for sgd
40+
decay_epoch: null
41+
epoch: 400
42+
snapshot: 2
43+
init_lr: 0.0001
44+
lr_decay: 0.5
45+
lr_step: 200 # if > 0 decay_epoch should be null
46+
seed: 1
47+
weight_decay: 0.0001
48+
clip_norm: False
49+
init: "xavier_uniform" # xavier_uniform or xavier_normal
50+
loss:
51+
follow: 10
52+
angle: 1
53+
smooth: 10 #10
54+
c2_smooth: 200 #20
55+
undefine: 2.0
56+
opt: 0.1
57+
stay: 0

dvs/dataset.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -34,7 +34,7 @@ def get_data_loader(cf, no_flo = False):
3434
def get_dataset(cf, no_flo = False):
3535
resize_ratio = cf["data"]["resize_ratio"]
3636
train_transform, test_transform = _data_transforms()
37-
train_path = os.path.join(cf["data"]["data_dir"], "train")
37+
train_path = os.path.join(cf["data"]["data_dir"], "training")
3838
test_path = os.path.join(cf["data"]["data_dir"], "test")
3939
if not os.path.exists(train_path):
4040
train_path = cf["data"]["data_dir"]

dvs/flownet2/run_release.sh

+10
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,10 @@
1+
#!/bin/bash
2+
python main.py --inference --model FlowNet2 --save_flow --inference_dataset Google \
3+
--inference_dataset_root ./../dataset_release/test \
4+
--resume ./FlowNet2_checkpoint.pth.tar \
5+
--inference_visualize
6+
7+
python main.py --inference --model FlowNet2 --save_flow --inference_dataset Google \
8+
--inference_dataset_root ./../dataset_release/training \
9+
--resume ./FlowNet2_checkpoint.pth.tar \
10+
--inference_visualize

dvs/load_frame_sensor_data.py

+1-15
Original file line numberDiff line numberDiff line change
@@ -91,21 +91,7 @@ def inference(cf, data_path, USE_CUDA):
9191
rotations_real, lens_offsets_real = get_rotations(data.frame[:data.length], data.gyro, data.ois, data.length)
9292
fig_path = os.path.join(data_path, video_name+"_real.jpg")
9393
visual_rotation(rotations_real, lens_offsets_real, None, None, None, None, fig_path)
94-
95-
# print("------Start Warping Video--------")
96-
# grid = get_grid(test_loader.dataset.static_options, \
97-
# data.frame[:data.length], data.gyro, data.ois, virtual_queue[:data.length,1:], no_shutter = False)
98-
99-
# grid_rm_shutter = get_grid(test_loader.dataset.static_options, \
100-
# data.frame[:data.length], data.gyro, np.zeros(data.ois.shape), virtual_queue[:data.length,1:], no_shutter = False)
101-
102-
# video_path = os.path.join(data_path, video_name+".mp4")
103-
# data_name = data_path.split("/")[-1]
104-
# save_path = os.path.join(data_path, video_name+"_no_ois.mp4")
105-
# warp_video(grid, video_path, save_path, losses = None)
106-
107-
# save_path = os.path.join(data_path, video_name+"_no_shutter.mp4")
108-
# warp_video(grid_rm_shutter, video_path, save_path, losses = None)
94+
10995
return
11096

11197
def main(args = None):

0 commit comments

Comments
 (0)