1
+ import sys
2
+ sys .path .append ('core' )
3
+
4
+ from PIL import Image
5
+ from glob import glob
6
+ import argparse
7
+ import os
8
+ import time
9
+ import numpy as np
10
+ import torch
11
+ import torch .nn .functional as F
12
+ import matplotlib .pyplot as plt
13
+ from configs .submission import get_cfg
14
+ from core .utils .misc import process_cfg
15
+ import datasets
16
+ from utils import flow_viz
17
+ from utils import frame_utils
18
+ import cv2
19
+ import math
20
+ import os .path as osp
21
+
22
+ from core .FlowFormer import build_flowformer
23
+
24
+ from utils .utils import InputPadder , forward_interpolate
25
+ import itertools
26
+
27
+ TRAIN_SIZE = [432 , 960 ]
28
+
29
+
30
+ def compute_grid_indices (image_shape , patch_size = TRAIN_SIZE , min_overlap = 20 ):
31
+ if min_overlap >= TRAIN_SIZE [0 ] or min_overlap >= TRAIN_SIZE [1 ]:
32
+ raise ValueError (
33
+ f"Overlap should be less than size of patch (got { min_overlap } "
34
+ f"for patch size { patch_size } )." )
35
+ if image_shape [0 ] == TRAIN_SIZE [0 ]:
36
+ hs = list (range (0 , image_shape [0 ], TRAIN_SIZE [0 ]))
37
+ else :
38
+ hs = list (range (0 , image_shape [0 ], TRAIN_SIZE [0 ] - min_overlap ))
39
+ if image_shape [1 ] == TRAIN_SIZE [1 ]:
40
+ ws = list (range (0 , image_shape [1 ], TRAIN_SIZE [1 ]))
41
+ else :
42
+ ws = list (range (0 , image_shape [1 ], TRAIN_SIZE [1 ] - min_overlap ))
43
+
44
+ # Make sure the final patch is flush with the image boundary
45
+ hs [- 1 ] = image_shape [0 ] - patch_size [0 ]
46
+ ws [- 1 ] = image_shape [1 ] - patch_size [1 ]
47
+ return [(h , w ) for h in hs for w in ws ]
48
+
49
+ def compute_weight (hws , image_shape , patch_size = TRAIN_SIZE , sigma = 1.0 , wtype = 'gaussian' ):
50
+ patch_num = len (hws )
51
+ h , w = torch .meshgrid (torch .arange (patch_size [0 ]), torch .arange (patch_size [1 ]))
52
+ h , w = h / float (patch_size [0 ]), w / float (patch_size [1 ])
53
+ c_h , c_w = 0.5 , 0.5
54
+ h , w = h - c_h , w - c_w
55
+ weights_hw = (h ** 2 + w ** 2 ) ** 0.5 / sigma
56
+ denorm = 1 / (sigma * math .sqrt (2 * math .pi ))
57
+ weights_hw = denorm * torch .exp (- 0.5 * (weights_hw ) ** 2 )
58
+
59
+ weights = torch .zeros (1 , patch_num , * image_shape )
60
+ for idx , (h , w ) in enumerate (hws ):
61
+ weights [:, idx , h :h + patch_size [0 ], w :w + patch_size [1 ]] = weights_hw
62
+ weights = weights .cuda ()
63
+ patch_weights = []
64
+ for idx , (h , w ) in enumerate (hws ):
65
+ patch_weights .append (weights [:, idx :idx + 1 , h :h + patch_size [0 ], w :w + patch_size [1 ]])
66
+
67
+ return patch_weights
68
+
69
+ def compute_flow (model , image1 , image2 , weights = None ):
70
+ print (f"computing flow..." )
71
+
72
+ image_size = image1 .shape [1 :]
73
+
74
+ hws = compute_grid_indices (image_size )
75
+ if weights is None :
76
+ weights = compute_weight (hws , image_size , sigma = 0.05 )
77
+
78
+ image1 , image2 = image1 [None ].cuda (), image2 [None ].cuda ()
79
+
80
+ flows = 0
81
+ flow_count = 0
82
+
83
+ for idx , (h , w ) in enumerate (hws ):
84
+ image1_tile = image1 [:, :, h :h + TRAIN_SIZE [0 ], w :w + TRAIN_SIZE [1 ]]
85
+ image2_tile = image2 [:, :, h :h + TRAIN_SIZE [0 ], w :w + TRAIN_SIZE [1 ]]
86
+ flow_pre , _ = model (image1_tile , image2_tile )
87
+ padding = (w , image_size [1 ]- w - TRAIN_SIZE [1 ], h , image_size [0 ]- h - TRAIN_SIZE [0 ], 0 , 0 )
88
+ flows += F .pad (flow_pre * weights [idx ], padding )
89
+ # flow_count += F.pad(weights, padding)
90
+ flow_count += F .pad (weights [idx ], padding )
91
+
92
+ flow_pre = flows / flow_count
93
+ flow = flow_pre [0 ].permute (1 , 2 , 0 ).cpu ().numpy ()
94
+
95
+ return flow , weights
96
+
97
+ def compute_adaptive_image_size (image_size ):
98
+ target_size = TRAIN_SIZE
99
+ scale0 = target_size [0 ] / image_size [0 ]
100
+ scale1 = target_size [1 ] / image_size [1 ]
101
+
102
+ if scale0 > scale1 :
103
+ scale = scale0
104
+ else :
105
+ scale = scale1
106
+
107
+ image_size = (int (image_size [1 ] * scale ), int (image_size [0 ] * scale ))
108
+
109
+ return image_size
110
+
111
+ def prepare_image (root_dir , viz_root_dir , fn1 , fn2 ):
112
+ print (f"preparing image..." )
113
+ print (f"root dir = { root_dir } , fn = { fn1 } " )
114
+
115
+ image1 = frame_utils .read_gen (osp .join (root_dir , fn1 ))
116
+ image2 = frame_utils .read_gen (osp .join (root_dir , fn2 ))
117
+ image1 = np .array (image1 ).astype (np .uint8 )[..., :3 ]
118
+ image2 = np .array (image2 ).astype (np .uint8 )[..., :3 ]
119
+ dsize = compute_adaptive_image_size (image1 .shape [0 :2 ])
120
+ image1 = cv2 .resize (image1 , dsize = dsize , interpolation = cv2 .INTER_CUBIC )
121
+ image2 = cv2 .resize (image2 , dsize = dsize , interpolation = cv2 .INTER_CUBIC )
122
+ image1 = torch .from_numpy (image1 ).permute (2 , 0 , 1 ).float ()
123
+ image2 = torch .from_numpy (image2 ).permute (2 , 0 , 1 ).float ()
124
+
125
+
126
+ dirname = osp .dirname (fn1 )
127
+ filename = osp .splitext (osp .basename (fn1 ))[0 ]
128
+
129
+ viz_dir = osp .join (viz_root_dir , dirname )
130
+ if not osp .exists (viz_dir ):
131
+ os .makedirs (viz_dir )
132
+
133
+ viz_fn = osp .join (viz_dir , filename + '.png' )
134
+
135
+ return image1 , image2 , viz_fn
136
+
137
+ def build_model ():
138
+ print (f"building model..." )
139
+ cfg = get_cfg ()
140
+ model = torch .nn .DataParallel (build_flowformer (cfg ))
141
+ model .load_state_dict (torch .load (cfg .model ))
142
+
143
+ model .cuda ()
144
+ model .eval ()
145
+
146
+ return model
147
+
148
+ def visualize_flow (root_dir , viz_root_dir , model , img_pairs ):
149
+ weights = None
150
+ for img_pair in img_pairs :
151
+ fn1 , fn2 = img_pair
152
+ print (f"processing { fn1 } , { fn2 } ..." )
153
+
154
+ image1 , image2 , viz_fn = prepare_image (root_dir , viz_root_dir , fn1 , fn2 )
155
+ flow , weights = compute_flow (model , image1 , image2 , weights )
156
+ flow_img = flow_viz .flow_to_image (flow )
157
+ cv2 .imwrite (viz_fn , flow_img [:, :, [2 ,1 ,0 ]])
158
+
159
+ def process_sintel (sintel_dir ):
160
+ img_pairs = []
161
+ for scene in os .listdir (sintel_dir ):
162
+ dirname = osp .join (sintel_dir , scene )
163
+ image_list = sorted (glob (osp .join (dirname , '*.png' )))
164
+ for i in range (len (image_list )- 1 ):
165
+ img_pairs .append ((image_list [i ], image_list [i + 1 ]))
166
+
167
+ return img_pairs
168
+
169
+ def generate_pairs (dirname , start_idx , end_idx ):
170
+ img_pairs = []
171
+ for idx in range (start_idx , end_idx ):
172
+ img1 = osp .join (dirname , f'{ idx :06} .png' )
173
+ img2 = osp .join (dirname , f'{ idx + 1 :06} .png' )
174
+ # img1 = f'{idx:06}.png'
175
+ # img2 = f'{idx+1:06}.png'
176
+ img_pairs .append ((img1 , img2 ))
177
+
178
+ return img_pairs
179
+
180
+ if __name__ == '__main__' :
181
+ parser = argparse .ArgumentParser ()
182
+ parser .add_argument ('--eval_type' , default = 'sintel' )
183
+ parser .add_argument ('--start_idx' , type = int , default = 1 )
184
+ parser .add_argument ('--end_idx' , type = int , default = 1200 )
185
+ parser .add_argument ('--root_dir' , default = '.' )
186
+ parser .add_argument ('--sintel_dir' , default = 'datasets/Sintel/test/clean' )
187
+ parser .add_argument ('--seq_dir' , default = 'demo_data/mihoyo' )
188
+ parser .add_argument ('--viz_root_dir' , default = 'viz_results' )
189
+
190
+ args = parser .parse_args ()
191
+
192
+ root_dir = args .root_dir
193
+ viz_root_dir = args .viz_root_dir
194
+
195
+ model = build_model ()
196
+
197
+ if args .eval_type == 'sintel' :
198
+ img_pairs = process_sintel (args .sintel_dir )
199
+ elif args .eval_type == 'seq' :
200
+ img_pairs = generate_pairs (args .seq_dir , args .start_idx , args .end_idx )
201
+ with torch .no_grad ():
202
+ visualize_flow (root_dir , viz_root_dir , model , img_pairs )
0 commit comments