Skip to content

Commit 452129b

Browse files
author
wbw
committed
CHANGES
1 parent c6a67ea commit 452129b

20 files changed

+783
-109
lines changed

configs.py

Lines changed: 8 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -4,16 +4,18 @@
44

55
import argparse
66
parser = argparse.ArgumentParser(description="PyTorch implementation of cpt")
7-
parser.add_argument('--dataset', type=str, default="CUB200")
7+
parser.add_argument('--dataset', type=str, default="ImageNet")
88
parser.add_argument('--dataset_dir', type=str, default="/media/wbw/a7f02863-b441-49d0-b546-6ef6fefbbc7e")
99
parser.add_argument('--output_dir', type=str, default="saved_model")
1010
# ========================= Model Configs ==========================
11-
parser.add_argument('--num_classes', default=50, help='category for classification')
12-
parser.add_argument('--num_cpt', default=50, help='number of the concept')
11+
parser.add_argument('--num_classes', default=20, help='category for classification')
12+
parser.add_argument('--num_cpt', default=10, help='number of the concept')
1313
parser.add_argument('--base_model', default="resnet18", type=str)
1414
parser.add_argument('--img_size', default=224, help='size for input image')
1515
parser.add_argument('--pre_train', default=False, type=bool,
1616
help='whether pre-train the model')
17+
parser.add_argument('--aug', default=True, type=bool,
18+
help='whether use augmentation')
1719
parser.add_argument('--act_type', default="sigmoid", help='the activation for the slot attention')
1820
parser.add_argument('--num_retrieval', default=50, help='number of the top retrieval images')
1921
parser.add_argument('--weight_att', default=False, help='using fc weight for att visualization')
@@ -24,10 +26,10 @@
2426

2527
# ========================= Learning Configs ==========================
2628
parser.add_argument('--start_epoch', default=0, type=int, metavar='N', help='start epoch')
27-
parser.add_argument('--lr', default=0.0005, type=float)
29+
parser.add_argument('--lr', default=0.0001, type=float)
2830
parser.add_argument('--batch_size', default=128, type=int)
29-
parser.add_argument('--epoch', default=40, type=int)
30-
parser.add_argument('--lr_drop', default=20, type=float, nargs="+",
31+
parser.add_argument('--epoch', default=80, type=int)
32+
parser.add_argument('--lr_drop', default=60, type=float, nargs="+",
3133
metavar='LRSteps', help='epochs to decay learning rate by 10')
3234
# ========================= Machine Configs ==========================
3335
parser.add_argument('--num_workers', default=4, type=int)

cpt_compare.py

Lines changed: 119 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,119 @@
1+
from torchvision import datasets, transforms
2+
from model.model_main import MainModel
3+
from configs import parser
4+
import torch
5+
import os
6+
from PIL import Image
7+
import numpy as np
8+
from utils import apply_colormap_on_image
9+
from loaders.get_loader import load_all_imgs, get_transform
10+
from tools import for_retrival, attention_estimation
11+
import h5py
12+
from draw_tools import draw_bar, draw_plot
13+
import shutil
14+
import torch.nn.functional as F
15+
from tools import crop_center, shot_game
16+
import cv2
17+
import copy
18+
from captum.metrics import infidelity
19+
20+
21+
shutil.rmtree('vis/', ignore_errors=True)
22+
shutil.rmtree('vis_pp/', ignore_errors=True)
23+
os.makedirs('vis/', exist_ok=True)
24+
os.makedirs('vis_pp/', exist_ok=True)
25+
np.set_printoptions(suppress=True)
26+
27+
28+
def main():
29+
# load all imgs
30+
imgs_database, labels_database, imgs_val, labels_val, cat = load_all_imgs(args)
31+
print("All category:")
32+
print(cat)
33+
transform = get_transform(args)["val"]
34+
35+
# load model and weights
36+
model = MainModel(args, vis=True)
37+
device = torch.device("cuda:0")
38+
model.to(device)
39+
name = f"{args.dataset}_{args.base_model}_cls{args.num_classes}_cpt{args.num_cpt}_" + f"{'use_slot_' + args.act_type + '_' + args.cpt_activation if not args.pre_train else 'no_slot'}.pt"
40+
print(name)
41+
checkpoint = torch.load(os.path.join(args.output_dir, name), map_location="cuda:0")
42+
model.load_state_dict(checkpoint, strict=True)
43+
model.eval()
44+
record = []
45+
record_Dauc = []
46+
47+
for i in range(len(imgs_val)):
48+
print(i)
49+
model.vis = True
50+
data = imgs_val[i]
51+
print(data)
52+
label = labels_val[i]
53+
# print(i)
54+
# print(data)
55+
56+
image_orl = Image.open(data).convert('RGB').resize([256, 256], resample=Image.BILINEAR)
57+
if image_orl.mode == 'L':
58+
image_orl = image_orl.convert('RGB')
59+
image_orl = crop_center(image_orl, 224, 224)
60+
imggg = transform(image_orl).unsqueeze(0).to(device)
61+
w = model.state_dict()["cls.weight"][label]
62+
w2 = w.clone()
63+
w2 = torch.relu(w2)
64+
cpt, pred, att, update = model(imggg, w2)
65+
66+
pred = F.softmax(pred, dim=-1)
67+
pred_label = torch.argmax(pred).item()
68+
if pred_label != label:
69+
print("predict error")
70+
continue
71+
72+
# print("------------")
73+
# print("The Model Prediction is: ", pred_label)
74+
# print("True is", label)
75+
76+
# for id in range(args.num_cpt):
77+
# slot_image = np.array(Image.open(f'vis/0_slot_{id}.png'), dtype=np.uint8)
78+
# heatmap_only, heatmap_on_image = apply_colormap_on_image(img_orl2, slot_image, 'jet')
79+
# heatmap_on_image.save("vis/" + f'0_slot_mask_{id}.png')
80+
81+
# slot_image = np.array(Image.open(f'vis/overall.png'), dtype=np.uint8)
82+
# heatmap_only, heatmap_on_image = apply_colormap_on_image(img_orl2, slot_image, 'jet')
83+
# heatmap_on_image.save("vis/" + f'overall_mask.png')
84+
mask = cv2.imread("vis/" + f'overall.png', cv2.IMREAD_UNCHANGED) / 255
85+
hitted, segment = shot_game(mask, data)
86+
if hitted is None:
87+
continue
88+
record.append(hitted)
89+
90+
record_p = [pred[0][pred_label].item()]
91+
mask1 = mask.flatten()
92+
ids = np.argsort(-mask1, axis=0)
93+
model.vis = False
94+
for j in range(1, 101, 1):
95+
thresh = mask1[ids[j * 501]]
96+
mask_use = copy.deepcopy(mask)
97+
mask_use[mask_use >= thresh] = 0
98+
mask_use[mask_use != 0] = 1
99+
100+
mask_use = torch.from_numpy(mask_use).to(device, torch.float32)
101+
new_img = imggg * mask_use
102+
cpt, pred, att, update = model(new_img, None, None)
103+
output_c = F.softmax(pred, dim=-1)
104+
record_p.append(output_c[0][pred_label].item())
105+
record_p = np.array(record_p)
106+
record_p = (record_p - np.min(record_p)) / (np.max(record_p) - np.min(record_p))
107+
# print(record_p)
108+
print(record_p.mean())
109+
record_Dauc.append(record_p.mean())
110+
111+
print(np.mean(np.array(record)))
112+
print(record)
113+
print(np.mean(np.array(record_Dauc)))
114+
115+
116+
if __name__ == '__main__':
117+
args = parser.parse_args()
118+
args.pre_train = False
119+
main()

draw_graph/ablation.py

Lines changed: 177 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,177 @@
1+
import matplotlib.pyplot as plt
2+
import numpy as np
3+
4+
5+
plt.figure(figsize=(36, 10), dpi=80)
6+
plt.rcParams['font.family'] = 'serif'
7+
plt.rcParams['font.serif'] = ['Times New Roman'] + plt.rcParams['font.serif']
8+
plt.rcParams['axes.linewidth'] = 3
9+
10+
ax1 = plt.subplot(251)
11+
ax2 = plt.subplot(252)
12+
ax3 = plt.subplot(253)
13+
ax4 = plt.subplot(254)
14+
ax5 = plt.subplot(255)
15+
16+
ax6 = plt.subplot(256)
17+
ax7 = plt.subplot(257)
18+
ax8 = plt.subplot(258)
19+
ax9 = plt.subplot(259)
20+
ax10 = plt.subplot(2, 5, 10)
21+
22+
linewidth = 5
23+
font = 30
24+
25+
markers = ['o', 's', '^']
26+
colors = ['#edb03d', "#4dbeeb", "#77ac41"]
27+
index = [0, 1, 2, 3, 4]
28+
index3 = [0, 1, 2, 3, 4]
29+
x_cpt_mnist = ["5", "10", "20", "50", "100"]
30+
x_cpt_bird = ["20", "50", "100", "200", "300"]
31+
x_1 = [0, 0.1, 1, 2, 5]
32+
x_2 = [0, 0.1, 1, 5, 10]
33+
x_3 = [0, 0.1, 1, 5, 10]
34+
x_4 = [0, 0.1, 1, 2, 5]
35+
36+
ax1.set_xticks(index)
37+
ax1.set_xticklabels(x_cpt_mnist)
38+
ax2.set_xticks(index)
39+
ax2.set_xticklabels(x_1)
40+
ax3.set_xticks(index)
41+
ax3.set_xticklabels(x_2)
42+
ax4.set_xticks(index)
43+
ax4.set_xticklabels(x_3)
44+
ax5.set_xticks(index3)
45+
ax5.set_xticklabels(x_4)
46+
acc_1 = [0.857, 0.922, 0.962, 0.995, 0.992]
47+
inter_1 = [0.732, 0.791, 0.930, 0.965, 0.988]
48+
exter_1 = [0.531, 0.513, 0.618, 0.906, 0.960]
49+
50+
acc_2 = [0.940, 0.962, 0.922, 0.886, 0.780]
51+
inter_2 = [0.921, 0.930, 0.930, 0.938, 0.889]
52+
exter_2 = [0.731, 0.723, 0.718, 0.736, 0.849]
53+
54+
acc_3 = [0.960, 0.956, 0.962, 0.955, 0.963]
55+
inter_3 = [0.930, 0.926, 0.930, 0.945, 0.970]
56+
exter_3 = [0.711, 0.710, 0.718, 0.715, 0.789]
57+
58+
acc_4 = [0.957, 0.967, 0.962, 0.948, 0.900]
59+
inter_4 = [0.931, 0.935, 0.930, 0.938, 0.949]
60+
exter_4 = [0.719, 0.713, 0.718, 0.716, 0.700]
61+
62+
acc_5 = [0.927, 0.949, 0.962, 0.948, 0.952]
63+
inter_5 = [0.981, 0.935, 0.930, 0.948, 0.949]
64+
exter_5 = [0.979, 0.923, 0.718, 0.746, 0.750]
65+
66+
ax6.set_xticks(index)
67+
ax6.set_xticklabels(x_cpt_bird)
68+
ax7.set_xticks(index)
69+
ax7.set_xticklabels(x_1)
70+
ax8.set_xticks(index)
71+
ax8.set_xticklabels(x_2)
72+
ax9.set_xticks(index)
73+
ax9.set_xticklabels(x_3)
74+
ax10.set_xticks(index3)
75+
ax10.set_xticklabels(x_4)
76+
77+
acc_6 = [0.566, 0.668, 0.675, 0.680, 0.676]
78+
inter_6 = [0.702, 0.850, 0.930, 0.961, 0.965]
79+
exter_6 = [0.631, 0.353, 0.245, 0.248, 0.361]
80+
81+
acc_7 = [0.650, 0.668, 0.642, 0.588, 0.128]
82+
inter_7 = [0.752, 0.850, 0.880, 0.810, 0.402]
83+
exter_7 = [0.601, 0.353, 0.332, 0.371, 0.853]
84+
85+
acc_8 = [0.704, 0.685, 0.668, 0.662, 0.634]
86+
inter_8 = [0.652, 0.812, 0.850, 0.921, 0.960]
87+
exter_8 = [0.381, 0.373, 0.345, 0.398, 0.401]
88+
89+
acc_9 = [0.620, 0.656, 0.668, 0.660, 0.652]
90+
inter_9 = [0.812, 0.823, 0.850, 0.881, 0.880]
91+
exter_9 = [0.671, 0.583, 0.345, 0.298, 0.291]
92+
93+
acc_10 = [0.670, 0.656, 0.668, 0.660, 0.652]
94+
inter_10 = [0.782, 0.799, 0.890, 0.891, 0.900]
95+
exter_10 = [0.361, 0.343, 0.265, 0.268, 0.261]
96+
97+
ax1.axis(ymin=0, ymax=1)
98+
ax2.axis(ymin=0, ymax=1)
99+
ax3.axis(ymin=0, ymax=1)
100+
ax4.axis(ymin=0, ymax=1)
101+
ax5.axis(ymin=0, ymax=1)
102+
ax1.set_yticks(np.linspace(0, 1, 2, endpoint=True))
103+
ax2.set_yticks([])
104+
ax3.set_yticks([])
105+
ax4.set_yticks([])
106+
ax5.set_yticks([])
107+
108+
ax6.axis(ymin=0, ymax=1)
109+
ax7.axis(ymin=0, ymax=1)
110+
ax8.axis(ymin=0, ymax=1)
111+
ax9.axis(ymin=0, ymax=1)
112+
ax10.axis(ymin=0, ymax=1)
113+
ax6.set_yticks(np.linspace(0, 1, 2, endpoint=True))
114+
ax7.set_yticks([])
115+
ax8.set_yticks([])
116+
ax9.set_yticks([])
117+
ax10.set_yticks([])
118+
119+
ax1.tick_params(labelsize=font+5)
120+
ax2.tick_params(labelsize=font+5)
121+
ax3.tick_params(labelsize=font+5)
122+
ax4.tick_params(labelsize=font+5)
123+
ax5.tick_params(labelsize=font+5)
124+
ax6.tick_params(labelsize=font+5)
125+
ax7.tick_params(labelsize=font+5)
126+
ax8.tick_params(labelsize=font+5)
127+
ax9.tick_params(labelsize=font+5)
128+
ax10.tick_params(labelsize=font+5)
129+
130+
131+
size_1 = 10
132+
133+
ax1.set_ylabel("MNIST", fontsize=font+size_1)
134+
ax6.set_ylabel("CUB200", fontsize=font+size_1)
135+
ax6.set_xlabel("k", fontsize=font+size_1+5)
136+
ax7.set_xlabel("$\lambda_{qua}$", fontsize=font+size_1)
137+
ax8.set_xlabel("$\lambda_{con}$", fontsize=font+size_1)
138+
ax9.set_xlabel("$\lambda_{dis}$", fontsize=font+size_1)
139+
ax10.set_xlabel("$\lambda_R$", fontsize=font+size_1)
140+
141+
ax1.plot(index, acc_1, marker=markers[0], markevery=1, markersize=15, color=colors[0], linewidth=linewidth, linestyle="-")
142+
ax1.plot(index, inter_1, marker=markers[1], markevery=1, markersize=15, color=colors[1], linewidth=linewidth, linestyle="-")
143+
ax1.plot(index, exter_1, marker=markers[2], markevery=1, markersize=15, color=colors[2], linewidth=linewidth, linestyle="-")
144+
ax2.plot(index, acc_2, marker=markers[0], markevery=1, markersize=15, color=colors[0], linewidth=linewidth, linestyle="-")
145+
ax2.plot(index, inter_2, marker=markers[1], markevery=1, markersize=15, color=colors[1], linewidth=linewidth, linestyle="-")
146+
ax2.plot(index, exter_2, marker=markers[2], markevery=1, markersize=15, color=colors[2], linewidth=linewidth, linestyle="-")
147+
ax3.plot(index, acc_3, marker=markers[0], markevery=1, markersize=15, color=colors[0], linewidth=linewidth, linestyle="-")
148+
ax3.plot(index, inter_3, marker=markers[1], markevery=1, markersize=15, color=colors[1], linewidth=linewidth, linestyle="-")
149+
ax3.plot(index, exter_3, marker=markers[2], markevery=1, markersize=15, color=colors[2], linewidth=linewidth, linestyle="-")
150+
ax4.plot(index, acc_4, marker=markers[0], markevery=1, markersize=15, color=colors[0], linewidth=linewidth, linestyle="-")
151+
ax4.plot(index, inter_4, marker=markers[1], markevery=1, markersize=15, color=colors[1], linewidth=linewidth, linestyle="-")
152+
ax4.plot(index, exter_4, marker=markers[2], markevery=1, markersize=15, color=colors[2], linewidth=linewidth, linestyle="-")
153+
ax5.plot(index, acc_5, marker=markers[0], markevery=1, markersize=15, color=colors[0], linewidth=linewidth, linestyle="-")
154+
ax5.plot(index, inter_5, marker=markers[1], markevery=1, markersize=15, color=colors[1], linewidth=linewidth, linestyle="-")
155+
ax5.plot(index, exter_5, marker=markers[2], markevery=1, markersize=15, color=colors[2], linewidth=linewidth, linestyle="-")
156+
157+
158+
ax6.plot(index, acc_6, marker=markers[0], markevery=1, markersize=15, color=colors[0], linewidth=linewidth, linestyle="-")
159+
ax6.plot(index, inter_6, marker=markers[1], markevery=1, markersize=15, color=colors[1], linewidth=linewidth, linestyle="-")
160+
ax6.plot(index, exter_6, marker=markers[2], markevery=1, markersize=15, color=colors[2], linewidth=linewidth, linestyle="-")
161+
ax7.plot(index, acc_7, marker=markers[0], markevery=1, markersize=15, color=colors[0], linewidth=linewidth, linestyle="-")
162+
ax7.plot(index, inter_7, marker=markers[1], markevery=1, markersize=15, color=colors[1], linewidth=linewidth, linestyle="-")
163+
ax7.plot(index, exter_7, marker=markers[2], markevery=1, markersize=15, color=colors[2], linewidth=linewidth, linestyle="-")
164+
ax8.plot(index, acc_8, marker=markers[0], markevery=1, markersize=15, color=colors[0], linewidth=linewidth, linestyle="-")
165+
ax8.plot(index, inter_8, marker=markers[1], markevery=1, markersize=15, color=colors[1], linewidth=linewidth, linestyle="-")
166+
ax8.plot(index, exter_8, marker=markers[2], markevery=1, markersize=15, color=colors[2], linewidth=linewidth, linestyle="-")
167+
ax9.plot(index, acc_9, marker=markers[0], markevery=1, markersize=15, color=colors[0], linewidth=linewidth, linestyle="-")
168+
ax9.plot(index, inter_9, marker=markers[1], markevery=1, markersize=15, color=colors[1], linewidth=linewidth, linestyle="-")
169+
ax9.plot(index, exter_9, marker=markers[2], markevery=1, markersize=15, color=colors[2], linewidth=linewidth, linestyle="-")
170+
ax10.plot(index, acc_10, marker=markers[0], markevery=1, markersize=15, color=colors[0], linewidth=linewidth, linestyle="-")
171+
ax10.plot(index, inter_10, marker=markers[1], markevery=1, markersize=15, color=colors[1], linewidth=linewidth, linestyle="-")
172+
ax10.plot(index, exter_10, marker=markers[2], markevery=1, markersize=15, color=colors[2], linewidth=linewidth, linestyle="-")
173+
174+
175+
plt.tight_layout()
176+
plt.savefig("ablation.pdf")
177+
plt.show()

0 commit comments

Comments
 (0)