Skip to content

Commit 5f67f15

Browse files
committed
feature inference with concatenated images
1 parent 7ae7052 commit 5f67f15

File tree

1 file changed

+14
-5
lines changed

1 file changed

+14
-5
lines changed

core/FlowFormer/LatentCostFormer/encoder.py

+14-5
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717
from .cnn import TwinsSelfAttentionLayer, TwinsCrossAttentionLayer, BasicEncoder
1818
from .mlpmixer import MLPMixerLayer
1919
from .convnext import ConvNextLayer
20+
import time
2021

2122
from timm.models.layers import Mlp, DropPath, activations, to_2tuple, trunc_normal_
2223

@@ -334,11 +335,19 @@ def corr(self, fmap1, fmap2):
334335
return corr
335336

336337
def forward(self, img1, img2, data, context=None):
337-
feat_s = self.feat_encoder(img1)
338-
feat_t = self.feat_encoder(img2)
338+
# The original implementation
339+
# feat_s = self.feat_encoder(img1)
340+
# feat_t = self.feat_encoder(img2)
341+
# feat_s = self.channel_convertor(feat_s)
342+
# feat_t = self.channel_convertor(feat_t)
339343

340-
feat_s = self.channel_convertor(feat_s)
341-
feat_t = self.channel_convertor(feat_t)
344+
imgs = torch.cat([img1, img2], dim=0)
345+
feats = self.feat_encoder(imgs)
346+
feats = self.channel_convertor(feats)
347+
B = feats.shape[0] // 2
348+
349+
feat_s = feats[:B]
350+
feat_t = feats[B:]
342351

343352
B, C, H, W = feat_s.shape
344353
size = (H, W)
@@ -356,4 +365,4 @@ def forward(self, img1, img2, data, context=None):
356365
cost_volume = self.corr(feat_s, feat_t)
357366
x = self.cost_perceiver_encoder(cost_volume, data, context)
358367

359-
return x
368+
return x

0 commit comments

Comments
 (0)