Skip to content

Commit 6ccb7d6

Browse files
authored
Merge pull request #2111 from jamesljlster/enhance_vit_get_intermediate_layers
Vision Transformer (ViT) get_intermediate_layers: enhanced to support dynamic image sizes and to save computational cost by skipping unused blocks
2 parents 70ccf00 + db06b56 commit 6ccb7d6

File tree

1 file changed: +7 additions, −3 deletions

timm/models/vision_transformer.py

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -635,13 +635,14 @@ def _intermediate_layers(
635635
) -> List[torch.Tensor]:
636636
outputs, num_blocks = [], len(self.blocks)
637637
take_indices = set(range(num_blocks - n, num_blocks) if isinstance(n, int) else n)
638+
last_index_to_take = max(take_indices)
638639

639640
# forward pass
640641
x = self.patch_embed(x)
641642
x = self._pos_embed(x)
642643
x = self.patch_drop(x)
643644
x = self.norm_pre(x)
644-
for i, blk in enumerate(self.blocks):
645+
for i, blk in enumerate(self.blocks[: last_index_to_take + 1]):
645646
x = blk(x)
646647
if i in take_indices:
647648
outputs.append(x)
@@ -667,9 +668,12 @@ def get_intermediate_layers(
667668
outputs = [out[:, self.num_prefix_tokens:] for out in outputs]
668669

669670
if reshape:
670-
grid_size = self.patch_embed.grid_size
671+
patch_size = self.patch_embed.patch_size
672+
batch, _, height, width = x.size()
671673
outputs = [
672-
out.reshape(x.shape[0], grid_size[0], grid_size[1], -1).permute(0, 3, 1, 2).contiguous()
674+
out.reshape(batch, int(math.ceil(height / patch_size[0])), int(math.ceil(width / patch_size[1])), -1)
675+
.permute(0, 3, 1, 2)
676+
.contiguous()
673677
for out in outputs
674678
]
675679

Comments: 0