-
Notifications
You must be signed in to change notification settings - Fork 2
/
Copy pathSoftSplatModel.py
143 lines (123 loc) · 5.45 KB
/
SoftSplatModel.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
import torch
import torch.nn as nn
from OpticalFlow.PWCNet import PWCNet
from softsplat import Softsplat
from GridNet import GridNet
from UNet import SmallUNet
from torch.nn.functional import interpolate, grid_sample
from einops import repeat
class BackWarp(nn.Module):
    """Backward-warp an image with a dense optical-flow field.

    Each output pixel (x, y) is sampled bilinearly from ``img`` at
    ``(x + u, y + v)``, where ``u = flow[:, 0]`` is the displacement along
    the width axis and ``v = flow[:, 1]`` along the height axis, in pixels.

    Args:
        clip: if True, sample locations outside the image read the nearest
            border pixel (``padding_mode='border'``); otherwise they read
            zeros (``padding_mode='zeros'``).
    """

    def __init__(self, clip=True):
        super(BackWarp, self).__init__()
        # Kept for backward compatibility with code that reads ``.device``.
        # forward() follows the device of its inputs instead of this value.
        self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        self.clip = clip

    def forward(self, img, flow):
        """Warp ``img`` (B, C, H, W) by ``flow`` (B, 2, H, W).

        Returns:
            Tensor of shape (B, C, H, W).
        """
        b, c, h, w = img.shape
        # Base pixel grid built on the flow's device/dtype, so CPU and GPU
        # inputs both work regardless of what was available at construction.
        xs = torch.arange(w, device=flow.device, dtype=flow.dtype).view(1, 1, w)
        ys = torch.arange(h, device=flow.device, dtype=flow.dtype).view(1, h, 1)
        x = xs + flow[:, 0]  # (B, H, W) via broadcasting
        y = ys + flow[:, 1]
        # With align_corners=True, grid_sample maps -1 / +1 to the centres of
        # the first / last pixels, so pixel index p normalises to
        # 2 * p / (size - 1) - 1. Dividing by `size` instead would make a
        # zero flow resample (shrink) the image. max(.., 1) guards size 1.
        x = x * (2.0 / max(w - 1, 1)) - 1
        y = y * (2.0 / max(h - 1, 1)) - 1
        grid = torch.stack((x, y), dim=-1)
        padding_mode = 'border' if self.clip else 'zeros'
        return grid_sample(img, grid, mode='bilinear', align_corners=True,
                           padding_mode=padding_mode)
class SoftSplatBaseline(nn.Module):
    """Softmax-splatting frame-interpolation baseline.

    forward() computes bidirectional flow with a PWCNet, builds a feature
    pyramid of both frames, derives an importance metric Z from the
    backward-warping brightness error, forward-splats every pyramid level to
    time t with ``Softsplat`` and synthesises the output with a ``GridNet``.

    Args:
        predefined_z: if True, Z is ``alpha * sum_c |brightness_diff|`` with a
            single learnable ``alpha``; otherwise Z is predicted by a
            ``SmallUNet``.
        act: activation class used in the feature pyramid and the GridNet.
    """

    def __init__(self, predefined_z=False, act=nn.PReLU):
        super(SoftSplatBaseline, self).__init__()
        self.flow_predictor = PWCNet()
        # Pretrained flow weights, loaded from a path relative to the CWD.
        self.flow_predictor.load_state_dict(torch.load('./OpticalFlow/pwc-checkpoint.pt'))
        self.fwarp = Softsplat()
        self.bwarp = BackWarp(clip=False)
        # Three levels: full resolution (32 ch), 1/2 (64 ch), 1/4 (96 ch).
        self.feature_pyramid = nn.ModuleList([
            nn.Sequential(
                nn.Conv2d(3, 32, 3, 1, 1),
                act(),
                nn.Conv2d(32, 32, 3, 1, 1),
                act()
            ),
            nn.Sequential(
                nn.Conv2d(32, 64, 3, 2, 1),
                act(),
                nn.Conv2d(64, 64, 3, 1, 1),
                act()
            ),
            nn.Sequential(
                nn.Conv2d(64, 96, 3, 2, 1),
                act(),
                nn.Conv2d(96, 96, 3, 1, 1),
                act()
            ),
        ])
        self.synth_net = GridNet(dim=32, act=act)
        self.predefined_z = predefined_z
        if predefined_z:
            self.alpha = nn.Parameter(-torch.ones(1))
        else:
            self.v_net = SmallUNet()

    @staticmethod
    def _joint_stats(a, b):
        """Return per-sample, per-channel (mean, std + 1e-8) of ``a`` and ``b`` pooled together.

        ``a`` and ``b`` are (N, C, H, W); both results have shape (N, C, 1, 1).
        ``reshape`` is used so non-contiguous inputs are accepted.
        """
        n, c = a.shape[0], a.shape[1]
        joint = torch.stack([a, b], dim=2).reshape(n, c, -1)
        mean = joint.mean(dim=-1).view(n, c, 1, 1)
        std = joint.std(dim=-1).view(n, c, 1, 1) + 1e-8
        return mean, std

    @staticmethod
    def _resize_flow(flow, size):
        """Resize a (N, 2, H, W) flow field to ``size`` = (h, w), rescaling its vectors.

        The u (channel 0) and v (channel 1) components are scaled by the width
        and height ratios respectively. Resizing by explicit size (rather than a
        scale factor) keeps odd resolutions aligned with the stride-2 pyramid.
        """
        h, w = int(size[0]), int(size[1])
        src_h, src_w = flow.shape[-2:]
        if (h, w) == (src_h, src_w):
            return flow
        resized = interpolate(flow, size=(h, w), mode='bilinear', align_corners=False)
        scale = flow.new_tensor([w / src_w, h / src_h]).view(1, 2, 1, 1)
        return resized * scale

    def instance_norm(self, x):
        """Normalise the two frame halves of ``x`` with shared statistics.

        ``x`` is (2B, C, H, W), with the frame-0 samples in the first half and
        the frame-1 samples in the second half. Each sample and channel is
        normalised with a mean/std computed jointly over both frames.
        """
        x0, x1 = x.chunk(2, dim=0)
        mean, std = self._joint_stats(x0, x1)
        x0, x1 = (x0 - mean) / std, (x1 - mean) / std
        return torch.cat([x0, x1], dim=0)

    def forward(self, x, target_t):
        """Synthesise the frame at time ``target_t`` between two input frames.

        Args:
            x: frame pair shaped (B, 3, 2, H, W); ``x[:, :, 0]`` is frame 0 and
                ``x[:, :, 1]`` is frame 1.
            target_t: interpolation time per sample, e.g. shape (B,).
                Presumably in [0, 1] (0 -> frame 0); TODO confirm.

        Returns:
            The synthesised frame, de-normalised with the input statistics.
            It is clamped to [0, 1] when the module is not in training mode.
        """
        # (B,) -> (B, 1, 1, 1) so it broadcasts over flow maps.
        target_t = target_t.reshape(-1, 1, 1, 1)
        fr0, fr1 = x[:, :, 0], x[:, :, 1]
        # One batched call for both directions. The first half is used as
        # flow 0->1 and the second half as flow 1->0 (see the chunk below).
        flow = self.flow_predictor(torch.cat([fr0, fr1], dim=0), torch.cat([fr1, fr0], dim=0))

        # Preprocess via instance normalisation with statistics shared by both
        # frames. mean/std are reused at the end to undo the normalisation.
        with torch.no_grad():
            mean, std = self._joint_stats(fr0, fr1)
            fr0, fr1 = (fr0 - mean) / std, (fr1 - mean) / std

        f_lv = torch.cat([fr0, fr1], dim=0)
        pyramid = [f_lv]
        for feat_extractor_lv in self.feature_pyramid:
            f_lv = feat_extractor_lv(f_lv)
            # The next level consumes the raw features; only the stored copy
            # is normalised.
            pyramid.append(self.instance_norm(f_lv))

        # Z importance metric from the backward-warping brightness error.
        brightness_diff = torch.abs(
            self.bwarp(torch.cat([fr1, fr0], dim=0), flow) - torch.cat([fr0, fr1], dim=0))
        if self.predefined_z:
            z = self.alpha * torch.sum(brightness_diff, dim=1, keepdim=True)
        else:
            z = self.v_net(torch.cat([torch.cat([fr0, fr1]), -brightness_diff], dim=1))

        # Forward-splat every pyramid level to time t. Flow and Z are resized
        # to the exact level size so odd H/W stay aligned.
        warped_feat_pyramid = []
        for f_lv in pyramid:
            size = f_lv.shape[-2:]
            flow01, flow10 = self._resize_flow(flow, size).chunk(2, dim=0)
            flowt = torch.cat([flow01 * target_t, flow10 * (1 - target_t)], dim=0)
            z_lv = interpolate(z, size=tuple(size), mode='bilinear', align_corners=False)
            warped_feat_pyramid.append(self.fwarp(f_lv, flowt, z_lv))

        # Concatenate the two warped halves along channels for each level.
        concat_warped_feat_pyramid = [
            torch.cat(feat_lv.chunk(2, dim=0), dim=1) for feat_lv in warped_feat_pyramid
        ]
        output = self.synth_net(concat_warped_feat_pyramid)
        output = output * std + mean  # roll back the input normalisation
        if not self.training:
            output = torch.clamp(output, 0, 1)
        return output
if __name__ == '__main__':
    # Example usage: interpolate the midpoint (t = 0.5) between two frames.
    # Runs on CUDA. NOTE(review): the Softsplat op presumably requires a GPU;
    # confirm before adding a CPU fallback.
    # batch size 1, 3 RGB channels, 2 frame input, H x W of 448 x 256
    frame0frame1 = torch.randn([1, 3, 2, 448, 256]).cuda()
    target_t = torch.tensor([0.5]).cuda()
    model = SoftSplatBaseline().cuda()
    model.load_state_dict(torch.load('./ckpt/SoftSplatBaseline_Vimeo.pth'))
    # A freshly built module is in training mode. forward() only clamps its
    # output to [0, 1] when not training, so switch to eval mode for inference.
    model.eval()
    with torch.no_grad():
        output = model(frame0frame1, target_t)