17
17
from .cnn import TwinsSelfAttentionLayer , TwinsCrossAttentionLayer , BasicEncoder
18
18
from .mlpmixer import MLPMixerLayer
19
19
from .convnext import ConvNextLayer
20
+ import time
20
21
21
22
from timm .models .layers import Mlp , DropPath , activations , to_2tuple , trunc_normal_
22
23
@@ -334,11 +335,19 @@ def corr(self, fmap1, fmap2):
334
335
return corr
335
336
336
337
def forward (self , img1 , img2 , data , context = None ):
337
- feat_s = self .feat_encoder (img1 )
338
- feat_t = self .feat_encoder (img2 )
338
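+ # Original implementation, kept for reference: each image was passed through the encoder and channel convertor separately (now done in one batched pass below).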
339
+ # feat_s = self.feat_encoder(img1)
340
+ # feat_t = self.feat_encoder(img2)
341
+ # feat_s = self.channel_convertor(feat_s)
342
+ # feat_t = self.channel_convertor(feat_t)
339
343
340
- feat_s = self .channel_convertor (feat_s )
341
- feat_t = self .channel_convertor (feat_t )
344
+ imgs = torch .cat ([img1 , img2 ], dim = 0 )
345
+ feats = self .feat_encoder (imgs )
346
+ feats = self .channel_convertor (feats )
347
+ B = feats .shape [0 ] // 2
348
+
349
+ feat_s = feats [:B ]
350
+ feat_t = feats [B :]
342
351
343
352
B , C , H , W = feat_s .shape
344
353
size = (H , W )
@@ -356,4 +365,4 @@ def forward(self, img1, img2, data, context=None):
356
365
cost_volume = self .corr (feat_s , feat_t )
357
366
x = self .cost_perceiver_encoder (cost_volume , data , context )
358
367
359
- return x
368
+ return x
0 commit comments