@@ -108,17 +108,18 @@ def compute_adaptive_image_size(image_size):
108
108
109
109
return image_size
110
110
111
- def prepare_image (root_dir , viz_root_dir , fn1 , fn2 ):
111
+ def prepare_image (root_dir , viz_root_dir , fn1 , fn2 , keep_size ):
112
112
print (f"preparing image..." )
113
113
print (f"root dir = { root_dir } , fn = { fn1 } " )
114
114
115
115
image1 = frame_utils .read_gen (osp .join (root_dir , fn1 ))
116
116
image2 = frame_utils .read_gen (osp .join (root_dir , fn2 ))
117
117
image1 = np .array (image1 ).astype (np .uint8 )[..., :3 ]
118
118
image2 = np .array (image2 ).astype (np .uint8 )[..., :3 ]
119
- dsize = compute_adaptive_image_size (image1 .shape [0 :2 ])
120
- image1 = cv2 .resize (image1 , dsize = dsize , interpolation = cv2 .INTER_CUBIC )
121
- image2 = cv2 .resize (image2 , dsize = dsize , interpolation = cv2 .INTER_CUBIC )
119
+ if not keep_size :
120
+ dsize = compute_adaptive_image_size (image1 .shape [0 :2 ])
121
+ image1 = cv2 .resize (image1 , dsize = dsize , interpolation = cv2 .INTER_CUBIC )
122
+ image2 = cv2 .resize (image2 , dsize = dsize , interpolation = cv2 .INTER_CUBIC )
122
123
image1 = torch .from_numpy (image1 ).permute (2 , 0 , 1 ).float ()
123
124
image2 = torch .from_numpy (image2 ).permute (2 , 0 , 1 ).float ()
124
125
@@ -145,13 +146,13 @@ def build_model():
145
146
146
147
return model
147
148
148
- def visualize_flow (root_dir , viz_root_dir , model , img_pairs ):
149
+ def visualize_flow (root_dir , viz_root_dir , model , img_pairs , keep_size ):
149
150
weights = None
150
151
for img_pair in img_pairs :
151
152
fn1 , fn2 = img_pair
152
153
print (f"processing { fn1 } , { fn2 } ..." )
153
154
154
- image1 , image2 , viz_fn = prepare_image (root_dir , viz_root_dir , fn1 , fn2 )
155
+ image1 , image2 , viz_fn = prepare_image (root_dir , viz_root_dir , fn1 , fn2 , keep_size )
155
156
flow , weights = compute_flow (model , image1 , image2 , weights )
156
157
flow_img = flow_viz .flow_to_image (flow )
157
158
cv2 .imwrite (viz_fn , flow_img [:, :, [2 ,1 ,0 ]])
@@ -180,12 +181,13 @@ def generate_pairs(dirname, start_idx, end_idx):
180
181
if __name__ == '__main__' :
181
182
parser = argparse .ArgumentParser ()
182
183
parser .add_argument ('--eval_type' , default = 'sintel' )
183
- parser .add_argument ('--start_idx' , type = int , default = 1 )
184
- parser .add_argument ('--end_idx' , type = int , default = 1200 )
185
184
parser .add_argument ('--root_dir' , default = '.' )
186
185
parser .add_argument ('--sintel_dir' , default = 'datasets/Sintel/test/clean' )
187
186
parser .add_argument ('--seq_dir' , default = 'demo_data/mihoyo' )
187
+ parser .add_argument ('--start_idx' , type = int , default = 1 ) # starting index of the image sequence
188
+ parser .add_argument ('--end_idx' , type = int , default = 1200 ) # ending index of the image sequence
188
189
parser .add_argument ('--viz_root_dir' , default = 'viz_results' )
190
+ parser .add_argument ('--keep_size' , action = 'store_true' ) # keep the image size, or the image will be adaptively resized.
189
191
190
192
args = parser .parse_args ()
191
193
@@ -199,4 +201,4 @@ def generate_pairs(dirname, start_idx, end_idx):
199
201
elif args .eval_type == 'seq' :
200
202
img_pairs = generate_pairs (args .seq_dir , args .start_idx , args .end_idx )
201
203
with torch .no_grad ():
202
- visualize_flow (root_dir , viz_root_dir , model , img_pairs )
204
+ visualize_flow (root_dir , viz_root_dir , model , img_pairs , args . keep_size )
0 commit comments