
Commit 7ae7052

evaluation and visualization code
1 parent 3b1edc7 commit 7ae7052

3 files changed: +238 −1 lines changed


README.md (+36)
@@ -66,6 +66,42 @@ We provide [models](https://drive.google.com/drive/folders/1K2dcWxaqOLiQ3PoqRdok
    ├── sintel.pth
    ├── kitti.pth
```

## Evaluation

The model to be evaluated is assigned by `_CN.model` in the config file.
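For reference, a minimal sketch of that assignment, assuming the yacs-style `CfgNode` pattern the `_CN` name suggests; the exact checkpoint path and file contents are assumptions, not shown in this commit:

```Python
# Hypothetical sketch of the relevant part of configs/things_eval.py;
# the checkpoint path below is an assumed example.
from yacs.config import CfgNode as CN

_CN = CN()
_CN.model = 'checkpoints/things.pth'  # checkpoint that evaluate_FlowFormer_tile.py loads

def get_cfg():
    return _CN.clone()
```
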
Evaluating the model on the Sintel and KITTI training sets. The corresponding config file is `configs/things_eval.py`.
```Shell
python evaluate_FlowFormer_tile.py --eval sintel_validation
python evaluate_FlowFormer_tile.py --eval kitti_validation
```
Generating the submissions for the Sintel and KITTI benchmarks. The corresponding config file is `configs/submission.py`.
```Shell
python evaluate_FlowFormer_tile.py --eval sintel_submission
python evaluate_FlowFormer_tile.py --eval kitti_submission
```
Visualizing the Sintel dataset:
```Shell
python visualize_flow.py --eval_type sintel
```
Visualizing an image sequence extracted from a video:
```Shell
python visualize_flow.py --eval_type seq
```
The default image sequence format is:
```Shell
├── demo_data
    ├── mihoyo
        ├── 000001.png
        ├── 000002.png
        ├── 000003.png
        .
        .
        .
        ├── 001000.png
```
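
Any frame dump that follows this `000001.png`-style numbering works. Below is a minimal sketch for producing one from a video with OpenCV; the `video_to_frames` helper and the `demo.mp4` / `demo_data/mihoyo` paths are hypothetical examples, not part of this commit:

```Python
# Hypothetical helper (not part of this commit): dump a video into the
# demo_data/<name>/000001.png layout that --eval_type seq expects.
import os
import cv2

def video_to_frames(video_path, out_dir):
    os.makedirs(out_dir, exist_ok=True)
    cap = cv2.VideoCapture(video_path)
    idx = 1  # visualize_flow.py's generate_pairs() starts at 000001.png
    while True:
        ok, frame = cap.read()
        if not ok:
            break
        cv2.imwrite(os.path.join(out_dir, f'{idx:06}.png'), frame)
        idx += 1
    cap.release()

video_to_frames('demo.mp4', 'demo_data/mihoyo')
```
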
## License
FlowFormer is released under the Apache License

evaluate_FlowFormer_tile.py (-1)
@@ -300,7 +300,6 @@ def validate_sintel(model, sigma=0.05):
 if __name__ == '__main__':
     parser = argparse.ArgumentParser()
-    parser.add_argument('--dataset', help="dataset for evaluation")
     parser.add_argument('--eval', help='eval benchmark')
     args = parser.parse_args()

visualize_flow.py (+202)
@@ -0,0 +1,202 @@
import sys
sys.path.append('core')

from PIL import Image
from glob import glob
import argparse
import os
import time
import numpy as np
import torch
import torch.nn.functional as F
import matplotlib.pyplot as plt
from configs.submission import get_cfg
from core.utils.misc import process_cfg
import datasets
from utils import flow_viz
from utils import frame_utils
import cv2
import math
import os.path as osp

from core.FlowFormer import build_flowformer

from utils.utils import InputPadder, forward_interpolate
import itertools

TRAIN_SIZE = [432, 960]


def compute_grid_indices(image_shape, patch_size=TRAIN_SIZE, min_overlap=20):
    # Top-left corners of overlapping TRAIN_SIZE tiles covering the image.
    if min_overlap >= TRAIN_SIZE[0] or min_overlap >= TRAIN_SIZE[1]:
        raise ValueError(
            f"Overlap should be less than size of patch (got {min_overlap} "
            f"for patch size {patch_size}).")
    if image_shape[0] == TRAIN_SIZE[0]:
        hs = list(range(0, image_shape[0], TRAIN_SIZE[0]))
    else:
        hs = list(range(0, image_shape[0], TRAIN_SIZE[0] - min_overlap))
    if image_shape[1] == TRAIN_SIZE[1]:
        ws = list(range(0, image_shape[1], TRAIN_SIZE[1]))
    else:
        ws = list(range(0, image_shape[1], TRAIN_SIZE[1] - min_overlap))

    # Make sure the final patch is flush with the image boundary
    hs[-1] = image_shape[0] - patch_size[0]
    ws[-1] = image_shape[1] - patch_size[1]
    return [(h, w) for h in hs for w in ws]


def compute_weight(hws, image_shape, patch_size=TRAIN_SIZE, sigma=1.0, wtype='gaussian'):
    # Per-tile Gaussian weight maps, peaked at the tile center, used to
    # blend the predictions of overlapping tiles.
    patch_num = len(hws)
    h, w = torch.meshgrid(torch.arange(patch_size[0]), torch.arange(patch_size[1]))
    h, w = h / float(patch_size[0]), w / float(patch_size[1])
    c_h, c_w = 0.5, 0.5
    h, w = h - c_h, w - c_w
    weights_hw = (h ** 2 + w ** 2) ** 0.5 / sigma
    denorm = 1 / (sigma * math.sqrt(2 * math.pi))
    weights_hw = denorm * torch.exp(-0.5 * (weights_hw) ** 2)

    weights = torch.zeros(1, patch_num, *image_shape)
    for idx, (h, w) in enumerate(hws):
        weights[:, idx, h:h+patch_size[0], w:w+patch_size[1]] = weights_hw
    weights = weights.cuda()
    patch_weights = []
    for idx, (h, w) in enumerate(hws):
        patch_weights.append(weights[:, idx:idx+1, h:h+patch_size[0], w:w+patch_size[1]])

    return patch_weights


def compute_flow(model, image1, image2, weights=None):
    print(f"computing flow...")

    image_size = image1.shape[1:]

    hws = compute_grid_indices(image_size)
    if weights is None:
        weights = compute_weight(hws, image_size, sigma=0.05)

    image1, image2 = image1[None].cuda(), image2[None].cuda()

    flows = 0
    flow_count = 0

    # Accumulate weighted per-tile flows, then normalize by the total weight.
    for idx, (h, w) in enumerate(hws):
        image1_tile = image1[:, :, h:h+TRAIN_SIZE[0], w:w+TRAIN_SIZE[1]]
        image2_tile = image2[:, :, h:h+TRAIN_SIZE[0], w:w+TRAIN_SIZE[1]]
        flow_pre, _ = model(image1_tile, image2_tile)
        padding = (w, image_size[1]-w-TRAIN_SIZE[1], h, image_size[0]-h-TRAIN_SIZE[0], 0, 0)
        flows += F.pad(flow_pre * weights[idx], padding)
        flow_count += F.pad(weights[idx], padding)

    flow_pre = flows / flow_count
    flow = flow_pre[0].permute(1, 2, 0).cpu().numpy()

    return flow, weights


def compute_adaptive_image_size(image_size):
    # Scale the frame so both sides are at least TRAIN_SIZE; returns
    # (width, height) as expected by cv2.resize.
    target_size = TRAIN_SIZE
    scale0 = target_size[0] / image_size[0]
    scale1 = target_size[1] / image_size[1]

    if scale0 > scale1:
        scale = scale0
    else:
        scale = scale1

    image_size = (int(image_size[1] * scale), int(image_size[0] * scale))

    return image_size


def prepare_image(root_dir, viz_root_dir, fn1, fn2):
    print(f"preparing image...")
    print(f"root dir = {root_dir}, fn = {fn1}")

    image1 = frame_utils.read_gen(osp.join(root_dir, fn1))
    image2 = frame_utils.read_gen(osp.join(root_dir, fn2))
    image1 = np.array(image1).astype(np.uint8)[..., :3]
    image2 = np.array(image2).astype(np.uint8)[..., :3]
    dsize = compute_adaptive_image_size(image1.shape[0:2])
    image1 = cv2.resize(image1, dsize=dsize, interpolation=cv2.INTER_CUBIC)
    image2 = cv2.resize(image2, dsize=dsize, interpolation=cv2.INTER_CUBIC)
    image1 = torch.from_numpy(image1).permute(2, 0, 1).float()
    image2 = torch.from_numpy(image2).permute(2, 0, 1).float()

    dirname = osp.dirname(fn1)
    filename = osp.splitext(osp.basename(fn1))[0]

    viz_dir = osp.join(viz_root_dir, dirname)
    if not osp.exists(viz_dir):
        os.makedirs(viz_dir)

    viz_fn = osp.join(viz_dir, filename + '.png')

    return image1, image2, viz_fn


def build_model():
    print(f"building model...")
    cfg = get_cfg()
    model = torch.nn.DataParallel(build_flowformer(cfg))
    model.load_state_dict(torch.load(cfg.model))

    model.cuda()
    model.eval()

    return model


def visualize_flow(root_dir, viz_root_dir, model, img_pairs):
    weights = None  # Gaussian tile weights are computed once and reused across frames.
    for img_pair in img_pairs:
        fn1, fn2 = img_pair
        print(f"processing {fn1}, {fn2}...")

        image1, image2, viz_fn = prepare_image(root_dir, viz_root_dir, fn1, fn2)
        flow, weights = compute_flow(model, image1, image2, weights)
        flow_img = flow_viz.flow_to_image(flow)
        cv2.imwrite(viz_fn, flow_img[:, :, [2, 1, 0]])  # RGB -> BGR for OpenCV


def process_sintel(sintel_dir):
    img_pairs = []
    for scene in os.listdir(sintel_dir):
        dirname = osp.join(sintel_dir, scene)
        image_list = sorted(glob(osp.join(dirname, '*.png')))
        for i in range(len(image_list)-1):
            img_pairs.append((image_list[i], image_list[i+1]))

    return img_pairs


def generate_pairs(dirname, start_idx, end_idx):
    # Consecutive frame pairs: 000001.png/000002.png, 000002.png/000003.png, ...
    img_pairs = []
    for idx in range(start_idx, end_idx):
        img1 = osp.join(dirname, f'{idx:06}.png')
        img2 = osp.join(dirname, f'{idx+1:06}.png')
        img_pairs.append((img1, img2))

    return img_pairs


if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument('--eval_type', default='sintel')
    parser.add_argument('--start_idx', type=int, default=1)
    parser.add_argument('--end_idx', type=int, default=1200)
    parser.add_argument('--root_dir', default='.')
    parser.add_argument('--sintel_dir', default='datasets/Sintel/test/clean')
    parser.add_argument('--seq_dir', default='demo_data/mihoyo')
    parser.add_argument('--viz_root_dir', default='viz_results')

    args = parser.parse_args()

    root_dir = args.root_dir
    viz_root_dir = args.viz_root_dir

    model = build_model()

    if args.eval_type == 'sintel':
        img_pairs = process_sintel(args.sintel_dir)
    elif args.eval_type == 'seq':
        img_pairs = generate_pairs(args.seq_dir, args.start_idx, args.end_idx)
    with torch.no_grad():
        visualize_flow(root_dir, viz_root_dir, model, img_pairs)
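
visualize_flow.py runs the model on overlapping 432×960 tiles and blends the per-tile flows with Gaussian weights rather than feeding the full frame to the network at once. A quick sanity check of the tiling is sketched below; it is a hypothetical snippet, not part of the commit, and assumes it is run from the repository root with the project's dependencies installed:

```Python
# Hypothetical sanity check (not part of the commit): the tile origins
# compute_grid_indices() yields for a 480x1024 frame with the default
# 20-pixel minimum overlap and the 432x960 TRAIN_SIZE.
from visualize_flow import compute_grid_indices

print(compute_grid_indices((480, 1024)))
# [(0, 0), (0, 64), (48, 0), (48, 64)] -- four overlapping 432x960 tiles
```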
