Commit 6ade3db

Add files via upload
1 parent e0138ec commit 6ade3db


75 files changed: +9352 -0 lines changed

benchmark/Vimeo90K.py

Lines changed: 55 additions & 0 deletions
@@ -0,0 +1,55 @@
import os
import sys
sys.path.append('.')
import cv2
import math
import torch
import numpy as np
from torch.nn import functional as F
from benchmark.pytorch_msssim import ssim_matlab as SSIM
from models.WaveletVFI import WaveletVFI
from thop import profile

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

def convert(param):
    return {k.replace("module.", ""): v for k, v in param.items() if "module." in k}

def PSNR(img_pred, img_gt):
    psnr = -10 * torch.log10(((img_pred - img_gt) * (img_pred - img_gt)).mean())
    return psnr

model = WaveletVFI()
model.load_state_dict(convert(torch.load('./models/waveletvfi_latest.pth', map_location='cpu')))
model.eval()
model.to(device)

th = None

path = '/youtu_action_data/NTIRE/vimeo_triplet/'
f = open(path + 'tri_testlist.txt', 'r')
psnr_list = []
ssim_list = []
flops_list = []
for i in f:
    name = str(i).strip()
    if(len(name) <= 1):
        continue
    print(path + 'sequences/' + name + '/im1.png')
    I0 = cv2.imread(path + 'sequences/' + name + '/im1.png')
    I1 = cv2.imread(path + 'sequences/' + name + '/im2.png')
    I2 = cv2.imread(path + 'sequences/' + name + '/im3.png')
    I0 = (torch.tensor(I0.transpose(2, 0, 1)).to(device) / 255.).unsqueeze(0)
    I1 = (torch.tensor(I1.transpose(2, 0, 1)).to(device) / 255.).unsqueeze(0)
    I2 = (torch.tensor(I2.transpose(2, 0, 1)).to(device) / 255.).unsqueeze(0)

    macs, params, outputs = profile(model, inputs=(I0, I2, I1, False, th), verbose=False, output=True)
    I1_pred = outputs[0]

    psnr = PSNR(I1_pred, I1).detach().cpu().numpy()
    ssim = SSIM(I1_pred, I1).detach().cpu().numpy()

    psnr_list.append(psnr)
    ssim_list.append(ssim)
    flops_list.append(macs / 1e9)
    print("Avg PSNR: {:.3f} SSIM: {:.4f} FLOPs(G): {:.3f}".format(np.mean(psnr_list), np.mean(ssim_list), np.mean(flops_list)))

benchmark/pytorch_msssim/__init__.py

Lines changed: 203 additions & 0 deletions
@@ -0,0 +1,203 @@
import torch
import torch.nn.functional as F
from math import exp
import numpy as np

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")


def gaussian(window_size, sigma):
    gauss = torch.Tensor([exp(-(x - window_size//2)**2/float(2*sigma**2)) for x in range(window_size)])
    return gauss/gauss.sum()


def create_window(window_size, channel=1):
    _1D_window = gaussian(window_size, 1.5).unsqueeze(1)
    _2D_window = _1D_window.mm(_1D_window.t()).float().unsqueeze(0).unsqueeze(0).to(device)
    window = _2D_window.expand(channel, 1, window_size, window_size).contiguous()
    return window


def create_window_3d(window_size, channel=1):
    _1D_window = gaussian(window_size, 1.5).unsqueeze(1)
    _2D_window = _1D_window.mm(_1D_window.t())
    _3D_window = _2D_window.unsqueeze(2) @ (_1D_window.t())
    window = _3D_window.expand(1, channel, window_size, window_size, window_size).contiguous().to(device)
    return window


def ssim(img1, img2, window_size=11, window=None, size_average=True, full=False, val_range=None):
    # Value range can be different from 255. Other common ranges are 1 (sigmoid) and 2 (tanh).
    if val_range is None:
        if torch.max(img1) > 128:
            max_val = 255
        else:
            max_val = 1

        if torch.min(img1) < -0.5:
            min_val = -1
        else:
            min_val = 0
        L = max_val - min_val
    else:
        L = val_range

    padd = 0
    (_, channel, height, width) = img1.size()
    if window is None:
        real_size = min(window_size, height, width)
        window = create_window(real_size, channel=channel).to(img1.device)

    # mu1 = F.conv2d(img1, window, padding=padd, groups=channel)
    # mu2 = F.conv2d(img2, window, padding=padd, groups=channel)
    mu1 = F.conv2d(F.pad(img1, (5, 5, 5, 5), mode='replicate'), window, padding=padd, groups=channel)
    mu2 = F.conv2d(F.pad(img2, (5, 5, 5, 5), mode='replicate'), window, padding=padd, groups=channel)

    mu1_sq = mu1.pow(2)
    mu2_sq = mu2.pow(2)
    mu1_mu2 = mu1 * mu2

    sigma1_sq = F.conv2d(F.pad(img1 * img1, (5, 5, 5, 5), 'replicate'), window, padding=padd, groups=channel) - mu1_sq
    sigma2_sq = F.conv2d(F.pad(img2 * img2, (5, 5, 5, 5), 'replicate'), window, padding=padd, groups=channel) - mu2_sq
    sigma12 = F.conv2d(F.pad(img1 * img2, (5, 5, 5, 5), 'replicate'), window, padding=padd, groups=channel) - mu1_mu2

    C1 = (0.01 * L) ** 2
    C2 = (0.03 * L) ** 2

    v1 = 2.0 * sigma12 + C2
    v2 = sigma1_sq + sigma2_sq + C2
    cs = torch.mean(v1 / v2)  # contrast sensitivity

    ssim_map = ((2 * mu1_mu2 + C1) * v1) / ((mu1_sq + mu2_sq + C1) * v2)

    if size_average:
        ret = ssim_map.mean()
    else:
        ret = ssim_map.mean(1).mean(1).mean(1)

    if full:
        return ret, cs
    return ret


def ssim_matlab(img1, img2, window_size=11, window=None, size_average=True, full=False, val_range=None):
    # Value range can be different from 255. Other common ranges are 1 (sigmoid) and 2 (tanh).
    if val_range is None:
        if torch.max(img1) > 128:
            max_val = 255
        else:
            max_val = 1

        if torch.min(img1) < -0.5:
            min_val = -1
        else:
            min_val = 0
        L = max_val - min_val
    else:
        L = val_range

    padd = 0
    (_, _, height, width) = img1.size()
    if window is None:
        real_size = min(window_size, height, width)
        window = create_window_3d(real_size, channel=1).to(img1.device)
        # Channel is set to 1 since we consider color images as volumetric images

    img1 = img1.unsqueeze(1)
    img2 = img2.unsqueeze(1)

    mu1 = F.conv3d(F.pad(img1, (5, 5, 5, 5, 5, 5), mode='replicate'), window, padding=padd, groups=1)
    mu2 = F.conv3d(F.pad(img2, (5, 5, 5, 5, 5, 5), mode='replicate'), window, padding=padd, groups=1)

    mu1_sq = mu1.pow(2)
    mu2_sq = mu2.pow(2)
    mu1_mu2 = mu1 * mu2

    sigma1_sq = F.conv3d(F.pad(img1 * img1, (5, 5, 5, 5, 5, 5), 'replicate'), window, padding=padd, groups=1) - mu1_sq
    sigma2_sq = F.conv3d(F.pad(img2 * img2, (5, 5, 5, 5, 5, 5), 'replicate'), window, padding=padd, groups=1) - mu2_sq
    sigma12 = F.conv3d(F.pad(img1 * img2, (5, 5, 5, 5, 5, 5), 'replicate'), window, padding=padd, groups=1) - mu1_mu2

    C1 = (0.01 * L) ** 2
    C2 = (0.03 * L) ** 2

    v1 = 2.0 * sigma12 + C2
    v2 = sigma1_sq + sigma2_sq + C2
    cs = torch.mean(v1 / v2)  # contrast sensitivity

    ssim_map = ((2 * mu1_mu2 + C1) * v1) / ((mu1_sq + mu2_sq + C1) * v2)

    if size_average:
        ret = ssim_map.mean()
    else:
        ret = ssim_map.mean(1).mean(1).mean(1)

    if full:
        return ret, cs
    return ret


def msssim(img1, img2, window_size=11, size_average=True, val_range=None, normalize=False):
    device = img1.device
    weights = torch.FloatTensor([0.0448, 0.2856, 0.3001, 0.2363, 0.1333]).to(device)
    levels = weights.size()[0]
    mssim = []
    mcs = []
    for _ in range(levels):
        sim, cs = ssim(img1, img2, window_size=window_size, size_average=size_average, full=True, val_range=val_range)
        mssim.append(sim)
        mcs.append(cs)

        img1 = F.avg_pool2d(img1, (2, 2))
        img2 = F.avg_pool2d(img2, (2, 2))

    mssim = torch.stack(mssim)
    mcs = torch.stack(mcs)

    # Normalize (to avoid NaNs during training unstable models, not compliant with original definition)
    if normalize:
        mssim = (mssim + 1) / 2
        mcs = (mcs + 1) / 2

    pow1 = mcs ** weights
    pow2 = mssim ** weights
    # From Matlab implementation https://ece.uwaterloo.ca/~z70wang/research/iwssim/
    output = torch.prod(pow1[:-1] * pow2[-1])
    return output


# Classes to re-use window
class SSIM(torch.nn.Module):
    def __init__(self, window_size=11, size_average=True, val_range=None):
        super(SSIM, self).__init__()
        self.window_size = window_size
        self.size_average = size_average
        self.val_range = val_range

        # Assume 3 channel for SSIM
        self.channel = 3
        self.window = create_window(window_size, channel=self.channel)

    def forward(self, img1, img2):
        (_, channel, _, _) = img1.size()

        if channel == self.channel and self.window.dtype == img1.dtype:
            window = self.window
        else:
            window = create_window(self.window_size, channel).to(img1.device).type(img1.dtype)
            self.window = window
            self.channel = channel

        _ssim = ssim(img1, img2, window=window, window_size=self.window_size, size_average=self.size_average)
        dssim = (1 - _ssim) / 2
        return dssim


class MSSSIM(torch.nn.Module):
    def __init__(self, window_size=11, size_average=True, channel=3):
        super(MSSSIM, self).__init__()
        self.window_size = window_size
        self.size_average = size_average
        self.channel = channel

    def forward(self, img1, img2):
        return msssim(img1, img2, window_size=self.window_size, size_average=self.size_average)
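
A minimal usage sketch for this module (illustrative, not part of the commit): ssim_matlab is what the Vimeo90K benchmark calls, on (N, 3, H, W) tensors scaled to [0, 1], and the SSIM and MSSSIM wrappers assume the same layout. Run it from the repository root so the package-style import resolves.

import torch
from benchmark.pytorch_msssim import ssim_matlab, SSIM, MSSSIM

# Keep inputs on the same device the module picked at import time,
# since the cached Gaussian windows live on that device.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
a = torch.rand(1, 3, 256, 448, device=device)
b = (a + 0.05 * torch.randn_like(a)).clamp(0, 1)

print(ssim_matlab(a, a))   # identical images -> 1.0
print(ssim_matlab(a, b))   # noisy copy -> somewhat below 1.0

print(SSIM()(a, b))        # returns DSSIM = (1 - SSIM) / 2, near 0 for similar images
print(MSSSIM()(a, b))      # multi-scale SSIM; inputs must survive 5 average-pool halvings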

models/IFNet.py

Lines changed: 114 additions & 0 deletions
@@ -0,0 +1,114 @@
import torch
import torch.nn as nn
import torch.nn.functional as F
from models.utils import bwarp


def centralize(img0, img1):
    rgb_mean = torch.cat([img0, img1], 2).mean(1, keepdim=True).mean(2, keepdim=True).mean(3, keepdim=True)
    return img0 - rgb_mean, img1 - rgb_mean, rgb_mean


def resize(x, scale_factor):
    return F.interpolate(x, scale_factor=scale_factor, mode='bilinear', align_corners=False)


def convrelu(in_channels, out_channels, kernel_size=3, stride=1, padding=1, dilation=1, groups=1, bias=True):
    return nn.Sequential(
        nn.Conv2d(in_channels, out_channels, kernel_size, stride, padding, dilation, groups, bias=bias),
        nn.LeakyReLU(negative_slope=0.1)
    )


class Decoder(nn.Module):
    def __init__(self, in_channels, mid_channels, is_bottom=False):
        super(Decoder, self).__init__()
        self.is_bottom = is_bottom
        self.conv1 = convrelu(in_channels, mid_channels, 3, 1)
        self.conv2 = convrelu(mid_channels, mid_channels, 3, 1)
        self.conv3 = convrelu(mid_channels, mid_channels, 3, 1)
        self.conv4 = convrelu(mid_channels, mid_channels, 3, 1)
        self.conv5 = convrelu(mid_channels, mid_channels, 3, 1)
        self.conv6 = nn.ConvTranspose2d(mid_channels, 5, 4, 2, 1, bias=True)
        if self.is_bottom:
            self.classifier = nn.Sequential(
                convrelu(mid_channels, mid_channels, 3, 2),
                nn.AdaptiveAvgPool2d((1, 1)),
                nn.Flatten(1),
                nn.Linear(mid_channels, mid_channels, bias=True),
                nn.LeakyReLU(negative_slope=0.1),
                nn.Linear(mid_channels, 4, bias=True)
            )

    def forward(self, x):
        out1 = self.conv1(x)
        out2 = self.conv2(out1)
        out3 = self.conv3(out2)
        out4 = self.conv4(out3)
        out5 = self.conv5(out4)
        out = self.conv6(out5)
        if self.is_bottom:
            class_prob_ = self.classifier(out5)
            return out, class_prob_
        else:
            return out


class IFNet(nn.Module):
    def __init__(self):
        super(IFNet, self).__init__()
        self.pconv1 = nn.Sequential(convrelu(3, 32, 3, 2), convrelu(32, 32, 3, 1))
        self.pconv2 = nn.Sequential(convrelu(32, 64, 3, 2), convrelu(64, 64, 3, 1))
        self.pconv3 = nn.Sequential(convrelu(64, 96, 3, 2), convrelu(96, 96, 3, 1))
        self.pconv4 = nn.Sequential(convrelu(96, 128, 3, 2), convrelu(128, 128, 3, 1))

        self.decoder4 = Decoder(256, 192, True)
        self.decoder3 = Decoder(197, 160, False)
        self.decoder2 = Decoder(133, 128, False)
        self.decoder1 = Decoder(69, 64, False)

        for m in self.modules():
            if isinstance(m, nn.Conv2d) or isinstance(m, nn.ConvTranspose2d):
                nn.init.kaiming_normal_(m.weight)
                if m.bias is not None:
                    nn.init.zeros_(m.bias)

    def forward(self, img0, img1):
        img0, img1, _ = centralize(img0, img1)

        f0_1 = self.pconv1(img0)
        f1_1 = self.pconv1(img1)
        f0_2 = self.pconv2(f0_1)
        f1_2 = self.pconv2(f1_1)
        f0_3 = self.pconv3(f0_2)
        f1_3 = self.pconv3(f1_2)
        f0_4 = self.pconv4(f0_3)
        f1_4 = self.pconv4(f1_3)

        out4, class_prob_ = self.decoder4(torch.cat([f0_4, f1_4], 1))
        up_flow_t0_4 = out4[:, 0:2]
        up_flow_t1_4 = out4[:, 2:4]
        up_occ_t_4 = out4[:, 4:5]

        f0_3_warp = bwarp(f0_3, up_flow_t0_4)
        f1_3_warp = bwarp(f1_3, up_flow_t1_4)
        out3 = self.decoder3(torch.cat([f0_3_warp, f1_3_warp, up_flow_t0_4, up_flow_t1_4, up_occ_t_4], 1))
        up_flow_t0_3 = out3[:, 0:2] + resize(up_flow_t0_4, 2.0) * 2.0
        up_flow_t1_3 = out3[:, 2:4] + resize(up_flow_t1_4, 2.0) * 2.0
        up_occ_t_3 = out3[:, 4:5] + resize(up_occ_t_4, 2.0)

        f0_2_warp = bwarp(f0_2, up_flow_t0_3)
        f1_2_warp = bwarp(f1_2, up_flow_t1_3)
        out2 = self.decoder2(torch.cat([f0_2_warp, f1_2_warp, up_flow_t0_3, up_flow_t1_3, up_occ_t_3], 1))
        up_flow_t0_2 = out2[:, 0:2] + resize(up_flow_t0_3, 2.0) * 2.0
        up_flow_t1_2 = out2[:, 2:4] + resize(up_flow_t1_3, 2.0) * 2.0
        up_occ_t_2 = out2[:, 4:5] + resize(up_occ_t_3, 2.0)

        f0_1_warp = bwarp(f0_1, up_flow_t0_2)
        f1_1_warp = bwarp(f1_1, up_flow_t1_2)
        out1 = self.decoder1(torch.cat([f0_1_warp, f1_1_warp, up_flow_t0_2, up_flow_t1_2, up_occ_t_2], 1))
        up_flow_t0_1 = out1[:, 0:2] + resize(up_flow_t0_2, 2.0) * 2.0
        up_flow_t1_1 = out1[:, 2:4] + resize(up_flow_t1_2, 2.0) * 2.0
        up_occ_t_1 = torch.sigmoid(out1[:, 4:5] + resize(up_occ_t_2, 2.0))

        return up_flow_t0_1, up_flow_t1_1, up_occ_t_1, class_prob_
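
A shape check for the network above (an illustrative sketch, assuming models/utils.py with bwarp is among the other files in this commit). The bottom decoder predicts coarse flow and the 4-way classifier logits; the upper decoders refine flow and the occlusion mask back to full resolution:

import torch
from models.IFNet import IFNet

net = IFNet().eval()
img0 = torch.rand(1, 3, 256, 448)   # H and W should be divisible by 16
img1 = torch.rand(1, 3, 256, 448)

with torch.no_grad():
    flow_t0, flow_t1, occ_t, class_prob = net(img0, img1)

print(flow_t0.shape)     # torch.Size([1, 2, 256, 448])  flow from time t to frame 0
print(flow_t1.shape)     # torch.Size([1, 2, 256, 448])  flow from time t to frame 1
print(occ_t.shape)       # torch.Size([1, 1, 256, 448])  sigmoid occlusion / fusion mask
print(class_prob.shape)  # torch.Size([1, 4])             classifier logits from decoder4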
