Skip to content

Commit 2975318

Browse files
committed
Update cvt.py
1 parent ee07e7c commit 2975318

File tree

1 file changed

+4
-4
lines changed

1 file changed

+4
-4
lines changed

timm/models/cvt.py

+4-4
Original file line numberDiff line numberDiff line change
@@ -388,18 +388,18 @@ def __init__(
388388
mlp_layer: nn.Module = Mlp,
389389
mlp_ratio: float = 4.,
390390
mlp_act_layer: nn.Module = QuickGELU,
391-
use_cls_token: Tuple[bool, ...] = (False, False, True),
391+
use_cls_token: bool = True,
392392
drop_rate: float = 0.,
393393
) -> None:
394394
super().__init__()
395395
num_stages = len(dims)
396396
assert num_stages == len(depths) == len(embed_kernel_size) == len(embed_stride)
397-
assert num_stages == len(embed_padding) == len(num_heads) == len(use_cls_token)
397+
assert num_stages == len(embed_padding) == len(num_heads)
398398
self.num_classes = num_classes
399399
self.num_features = dims[-1]
400400
self.feature_info = []
401401

402-
self.use_cls_token = use_cls_token[-1]
402+
self.use_cls_token = use_cls_token
403403

404404
dpr = [x.tolist() for x in torch.linspace(0, drop_path_rate, sum(depths)).split(depths)]
405405

@@ -437,7 +437,7 @@ def __init__(
437437
mlp_layer = mlp_layer,
438438
mlp_ratio = mlp_ratio,
439439
mlp_act_layer = mlp_act_layer,
440-
use_cls_token = use_cls_token[stage_idx],
440+
use_cls_token = use_cls_token and stage_idx == num_stages - 1,
441441
)
442442
in_chs = dim
443443
stages.append(stage)

0 commit comments

Comments
 (0)