Support positional encoding interpolation in Zoo models #601

@korotaS

Description

Hi! I am using this example to train a ConcatSiamese model. With the extractor from the example (vits16_dino) it runs fine, but with a different model, for example vitb32_unicom, I get an error. Here is a minimal reproduction:

import torch

from oml.models import ConcatSiamese, ViTUnicomExtractor
from oml.registry.transforms import get_transforms_for_pretrained

device = 'cpu'
extractor = ViTUnicomExtractor.from_pretrained("vitb32_unicom").to(device)
transforms, _ = get_transforms_for_pretrained("vitb32_unicom")
pairwise_model = ConcatSiamese(extractor=extractor, mlp_hidden_dims=[100], device=device)

out = pairwise_model(x1=torch.rand(2, 3, 224, 224), x2=torch.rand(2, 3, 224, 224))

And here is the last frame of the traceback:

File /home/korotas/projects/open-metric-learning/oml/models/vit_unicom/external/vision_transformer.py:181, in VisionTransformer.forward_features(self, x)
    179 B = x.shape[0]
    180 x = self.patch_embed(x)
--> 181 x = x + self.pos_embed
    182 for func in self.blocks:
    183     x = func(x)

RuntimeError: The size of tensor a (98) must match the size of tensor b (49) at non-singleton dimension 1
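
To make sense of the numbers in the error: as far as I understand, ConcatSiamese stacks the two images along a spatial axis before calling the extractor, so with a 32-pixel patch the token count doubles:

patch = 32
single = (224 // patch) * (224 // patch)   # 49 tokens for one 224x224 image
stacked = (448 // patch) * (224 // patch)  # 98 tokens for the stacked pair
print(single, stacked)  # 49 98

The unicom model's pos_embed is sized for 49 tokens, hence the mismatch at dimension 1.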

I think the problem is that vits16_dino interpolates the positional embedding before adding it to the tokens, so the embedding is resampled to match the actual input shape:

def interpolate_pos_encoding(self, x, w: int, h: int):
    npatch = x.shape[1] - 1
    N = self.pos_embed.shape[1] - 1
    if npatch == N and w == h:
        return self.pos_embed
    class_pos_embed = self.pos_embed[:, 0]
    patch_pos_embed = self.pos_embed[:, 1:]
    dim = x.shape[-1]
    w0 = w // self.patch_embed.patch_size
    h0 = h // self.patch_embed.patch_size
    # we add a small number to avoid floating point error in the interpolation
    # see discussion at https://github.com/facebookresearch/dino/issues/8
    w0, h0 = w0 + 0.1, h0 + 0.1
    patch_pos_embed = nn.functional.interpolate(
        patch_pos_embed.reshape(1, int(math.sqrt(N)), int(math.sqrt(N)), dim).permute(0, 3, 1, 2),
        scale_factor=(w0 / math.sqrt(N), h0 / math.sqrt(N)),
        mode="bicubic",
    )
    assert int(w0) == patch_pos_embed.shape[-2] and int(h0) == patch_pos_embed.shape[-1]
    patch_pos_embed = patch_pos_embed.permute(0, 2, 3, 1).view(1, -1, dim)
    return torch.cat((class_pos_embed.unsqueeze(0), patch_pos_embed), dim=1)
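
For comparison, this is how DINO wires that interpolation into the forward pass right after patch embedding (paraphrased and abbreviated from the same vision_transformer.py), instead of adding a fixed-size pos_embed:

def prepare_tokens(self, x):
    B, nc, w, h = x.shape
    x = self.patch_embed(x)  # patch linear embedding
    cls_tokens = self.cls_token.expand(B, -1, -1)
    x = torch.cat((cls_tokens, x), dim=1)
    # resample pos_embed to the actual grid instead of assuming a fixed input size
    x = x + self.interpolate_pos_encoding(x, w, h)
    return self.pos_drop(x)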

But vitb32_unicom has no such interpolation, so the concatenated input produces twice as many patches as the positional encoding expects, and we have to manually replace the layers that depend on the number of patches (model.pos_embed and model.feature[0]).
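
For reference, here is a minimal sketch of the kind of surgery I mean, assuming unicom's pos_embed is a square grid with no class token (which matches x = x + self.pos_embed in the traceback); resize_pos_embed and the grid sizes are my own, not part of the library:

import math
import torch
import torch.nn as nn
import torch.nn.functional as F

def resize_pos_embed(pos_embed: torch.Tensor, grid_hw) -> torch.Tensor:
    # pos_embed: (1, N, dim) with N a perfect square; returns (1, h * w, dim)
    _, n, dim = pos_embed.shape
    side = int(math.sqrt(n))
    grid = pos_embed.reshape(1, side, side, dim).permute(0, 3, 1, 2)
    grid = F.interpolate(grid, size=grid_hw, mode="bicubic", align_corners=False)
    return grid.permute(0, 2, 3, 1).reshape(1, -1, dim)

# For two 224x224 images stacked along height with patch size 32: 7x7 -> 14x7
# model.pos_embed = nn.Parameter(resize_pos_embed(model.pos_embed.data, (14, 7)))
# model.feature[0] would need analogous treatment, since it also depends on
# the number of patches.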

I think this information should be added to the docs, or perhaps handled in ViTUnicomExtractor or ConcatSiamese with some kind of warning (see the sketch below). Also, this error reproduces with all the other ViTUnicomExtractor and ViTCLIPExtractor models.
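
Something like this hypothetical check (not existing code, just to illustrate the warning idea) inside ConcatSiamese would turn the cryptic shape error into an actionable message:

import torch

def check_extractor_supports_concat(extractor, image_hw=(224, 224)):
    # Probe with two images stacked along height (or whichever axis
    # ConcatSiamese uses) and fail early with a readable message.
    h, w = image_hw
    probe = torch.rand(1, 3, h * 2, w)
    try:
        extractor(probe)
    except RuntimeError as err:
        raise ValueError(
            f"{type(extractor).__name__} cannot process concatenated inputs, "
            "likely because it lacks positional encoding interpolation."
        ) from err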
