Description
Hi! I am using this example to train a ConcatSiamese model. If I use the extractor from the example (vits16_dino), it runs fine, but if I use a different model, for example vitb32_unicom, I get an error. Here is an example for reproducing it:

```python
import torch

from oml.models import ConcatSiamese, ViTUnicomExtractor
from oml.registry.transforms import get_transforms_for_pretrained

device = "cpu"
extractor = ViTUnicomExtractor.from_pretrained("vitb32_unicom").to(device)
transforms, _ = get_transforms_for_pretrained("vitb32_unicom")

pairwise_model = ConcatSiamese(extractor=extractor, mlp_hidden_dims=[100], device=device)
out = pairwise_model(x1=torch.rand(2, 3, 224, 224), x2=torch.rand(2, 3, 224, 224))
```

And here is the last traceback item:
```
File /home/korotas/projects/open-metric-learning/oml/models/vit_unicom/external/vision_transformer.py:181, in VisionTransformer.forward_features(self, x)
    179 B = x.shape[0]
    180 x = self.patch_embed(x)
--> 181 x = x + self.pos_embed
    182 for func in self.blocks:
    183     x = func(x)

RuntimeError: The size of tensor a (98) must match the size of tensor b (49) at non-singleton dimension 1
```

I think the problem is that in vits16_dino the positional embedding is interpolated to match the patch grid of the input before it is added, so the shapes always agree:
open-metric-learning/oml/models/vit_dino/external/vision_transformer.py, lines 260 to 280 in 05842c8:

```python
    def interpolate_pos_encoding(self, x, w: int, h: int):
        npatch = x.shape[1] - 1
        N = self.pos_embed.shape[1] - 1
        if npatch == N and w == h:
            return self.pos_embed
        class_pos_embed = self.pos_embed[:, 0]
        patch_pos_embed = self.pos_embed[:, 1:]
        dim = x.shape[-1]
        w0 = w // self.patch_embed.patch_size
        h0 = h // self.patch_embed.patch_size
        # we add a small number to avoid floating point error in the interpolation
        # see discussion at https://github.com/facebookresearch/dino/issues/8
        w0, h0 = w0 + 0.1, h0 + 0.1
        patch_pos_embed = nn.functional.interpolate(
            patch_pos_embed.reshape(1, int(math.sqrt(N)), int(math.sqrt(N)), dim).permute(0, 3, 1, 2),
            scale_factor=(w0 / math.sqrt(N), h0 / math.sqrt(N)),
            mode="bicubic",
        )
        assert int(w0) == patch_pos_embed.shape[-2] and int(h0) == patch_pos_embed.shape[-1]
        patch_pos_embed = patch_pos_embed.permute(0, 2, 3, 1).view(1, -1, dim)
        return torch.cat((class_pos_embed.unsqueeze(0), patch_pos_embed), dim=1)
```
But in vitb32_unicom there is no such interpolation, so the concatenated input produces twice as many patch tokens (98) as the positional encoding expects (49), and to make it work we would need to manually replace the layers whose shapes depend on the number of patches (model.pos_embed and model.feature[0]).
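To illustrate, here is a rough sketch of the kind of manual replacement I mean. It assumes the underlying unicom VisionTransformer is reachable (e.g. as extractor.model), that pos_embed has shape (1, 49, dim) with no class token, that feature[0] is the Linear layer flattened over all patch tokens, and that ConcatSiamese stacks the two images along the height axis; the helper name and the re-initialized head are mine, and this only makes the forward pass shape-consistent, the replaced head would still need training:

```python
import math

import torch
import torch.nn as nn


def patch_unicom_for_concat(vit: nn.Module) -> nn.Module:
    with torch.no_grad():
        pos = vit.pos_embed  # (1, N, dim), e.g. (1, 49, dim) for vitb32_unicom at 224x224
        _, n, dim = pos.shape
        side = int(math.sqrt(n))  # 7 for a 224 / 32 patch grid

        # Two images concatenated along one spatial axis -> the patch grid
        # doubles along that axis: 7x7 -> 14x7, i.e. 49 -> 98 tokens.
        grid = pos.reshape(1, side, side, dim).permute(0, 3, 1, 2)
        grid = nn.functional.interpolate(grid, size=(2 * side, side), mode="bicubic")
        vit.pos_embed = nn.Parameter(grid.permute(0, 2, 3, 1).reshape(1, 2 * n, dim))

        # The first layer of the head is flattened over all patch tokens, so its
        # input size has to be doubled as well (its weights are re-initialized here).
        old = vit.feature[0]
        vit.feature[0] = nn.Linear(2 * old.in_features, old.out_features, bias=old.bias is not None)
    return vit
```

(A cleaner fix might be to add an interpolation step like the dino one above to the unicom forward_features, but this is the minimal manual workaround I had in mind.)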
I think this information should be added to the docs, or perhaps handled in ViTUnicomExtractor or in ConcatSiamese with some kind of warning. Also, this error is reproducible with all the other ViTUnicomExtractor and ViTCLIPExtractor models.
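As for the warning, even a simple check at construction time would make the failure easier to understand. A minimal sketch, assuming only that the wrapped backbone exposes a fixed pos_embed (as the unicom ViT here does); the helper name and where it would be called from are hypothetical, not the actual OML API:

```python
import warnings

import torch.nn as nn


def warn_if_pos_embed_is_fixed(backbone: nn.Module, num_images: int = 2) -> None:
    # Hypothetical helper that ConcatSiamese (or the extractors) could call:
    # warn when the backbone adds a fixed-size positional embedding without any
    # interpolation step, since a concatenated input changes the token count.
    pos_embed = getattr(backbone, "pos_embed", None)
    if pos_embed is None or hasattr(backbone, "interpolate_pos_encoding"):
        return
    single = pos_embed.shape[1]
    warnings.warn(
        f"Backbone has a fixed pos_embed with {single} tokens and no interpolation; "
        f"feeding {num_images} concatenated images will produce {num_images * single} "
        f"patch tokens and fail at `x = x + self.pos_embed`."
    )
```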