dpt-original-to-smp.py
# Convert the original DPT checkpoint (dpt_large-ade20k) to the
# segmentation_models_pytorch DPT state-dict layout and push the result
# to the Hugging Face Hub.

import segmentation_models_pytorch as smp
import torch

MODEL_WEIGHTS_PATH = r"C:\Users\vedan\Downloads\dpt_large-ade20k-b12dca68.pt"
HF_HUB_PATH = "vedantdalimkar/DPT"
if __name__ == "__main__":
    smp_model = smp.DPT(encoder_name="tu-vit_large_patch16_384", classes=150)
    dpt_model_dict = torch.load(MODEL_WEIGHTS_PATH)

    for layer_index in range(0, 4):
        for param in [
            "running_mean",
            "running_var",
            "num_batches_tracked",
            "weight",
            "bias",
        ]:
            for block_index in [1, 2]:
                for bn_index in [1, 2]:
                    # Assign the weights of the original model's 4th fusion layer
                    # to the SMP DPT model's 1st layer, the 3rd to the 2nd, and
                    # so on, because the fusion layers are called in reverse
                    # order in the original DPT implementation.
                    dpt_model_dict[
                        f"decoder.fusion_blocks.{layer_index}.residual_conv_block{block_index}.batch_norm_{bn_index}.{param}"
                    ] = dpt_model_dict.pop(
                        f"scratch.refinenet{4 - layer_index}.resConfUnit{block_index}.bn{bn_index}.{param}"
                    )
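                    # For illustration, with layer_index=0, block_index=1,
                    # bn_index=1, and param="weight", the pop/assign above maps
                    #   scratch.refinenet4.resConfUnit1.bn1.weight
                    # to
                    #   decoder.fusion_blocks.0.residual_conv_block1.batch_norm_1.weight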
            if param in ["weight", "bias"]:
                if param == "weight":
                    for block_index in [1, 2]:
                        for conv_index in [1, 2]:
                            dpt_model_dict[
                                f"decoder.fusion_blocks.{layer_index}.residual_conv_block{block_index}.conv_{conv_index}.{param}"
                            ] = dpt_model_dict.pop(
                                f"scratch.refinenet{4 - layer_index}.resConfUnit{block_index}.conv{conv_index}.{param}"
                            )

                    dpt_model_dict[
                        f"decoder.reassemble_blocks.{layer_index}.project_to_feature_dim.{param}"
                    ] = dpt_model_dict.pop(f"scratch.layer{layer_index + 1}_rn.{param}")

                dpt_model_dict[
                    f"decoder.fusion_blocks.{layer_index}.project.{param}"
                ] = dpt_model_dict.pop(
                    f"scratch.refinenet{4 - layer_index}.out_conv.{param}"
                )

                dpt_model_dict[
                    f"decoder.readout_blocks.{layer_index}.project.0.{param}"
                ] = dpt_model_dict.pop(
                    f"pretrained.act_postprocess{layer_index + 1}.0.project.0.{param}"
                )

                dpt_model_dict[
                    f"decoder.reassemble_blocks.{layer_index}.project_to_out_channel.{param}"
                ] = dpt_model_dict.pop(
                    f"pretrained.act_postprocess{layer_index + 1}.3.{param}"
                )

                # The 3rd reassemble block (layer_index == 2) has no upsample
                # module in the original model, so there is nothing to copy.
                if layer_index != 2:
                    dpt_model_dict[
                        f"decoder.reassemble_blocks.{layer_index}.upsample.{param}"
                    ] = dpt_model_dict.pop(
                        f"pretrained.act_postprocess{layer_index + 1}.4.{param}"
                    )
    # Changing state dict keys for segmentation head
    dpt_model_dict = {
        (
            "segmentation_head.head" + name[len("scratch.output_conv") :]
            if name.startswith("scratch.output_conv")
            else name
        ): parameter
        for name, parameter in dpt_model_dict.items()
    }
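    # For example, a head key such as "scratch.output_conv.0.weight" (assuming
    # the original checkpoint stores the head under "scratch.output_conv")
    # becomes "segmentation_head.head.0.weight"; the suffix is kept as-is.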
    # Changing state dict keys for encoder layers
    dpt_model_dict = {
        (
            "encoder.model" + name[len("pretrained.model") :]
            if name.startswith("pretrained.model")
            else name
        ): parameter
        for name, parameter in dpt_model_dict.items()
    }
    # Removing key-value pairs associated with the auxiliary head
    dpt_model_dict = {
        name: parameter
        for name, parameter in dpt_model_dict.items()
        if not name.startswith("auxlayer")
    }

    smp_model.load_state_dict(dpt_model_dict, strict=True)

    smp_model.save_pretrained(HF_HUB_PATH, push_to_hub=True)
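
    # Optional sanity check, a minimal sketch that was not part of the original
    # script: run a dummy forward pass through the converted model and verify
    # the output shape. This assumes a 384x384 input (matching the
    # tu-vit_large_patch16_384 encoder), the 150 ADE20K classes set above, and
    # that the SMP model returns a (batch, classes, height, width) mask.
    smp_model.eval()
    with torch.no_grad():
        dummy_input = torch.rand(1, 3, 384, 384)
        mask = smp_model(dummy_input)
    assert mask.shape == (1, 150, 384, 384), f"unexpected output shape {mask.shape}"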