
Commit c6a67ea
Author: wbw
Commit message: CHANGES
Parent: 67b3d33

5 files changed: +15 −13 lines

draw_tools.py

Lines changed: 2 additions & 2 deletions
@@ -8,7 +8,7 @@
 
 
 def draw_bar(data):
-    plt.figure(figsize=(8, 6), dpi=80)
+    plt.figure(figsize=(20, 6), dpi=80)
     font = 22
     x_bar = np.arange(0, len(data), 1)
     plt.bar(x_bar, data)
@@ -22,7 +22,7 @@ def draw_bar(data):
 
 def draw_plot(data):
     font = 22
-    plt.figure(figsize=(8, 6), dpi=80)
+    plt.figure(figsize=(20, 6), dpi=80)
     b, c = data.shape
     for i in range(c):
         plt.boxplot(data[:, i], positions=[i*10], widths=5, showmeans=True)
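For scale, a minimal sketch of what the widened canvas changes; the random data array and the savefig call are illustrative assumptions, not part of the repository:

import numpy as np
import matplotlib.pyplot as plt

data = np.random.rand(50)            # stand-in for the per-concept values draw_bar receives

plt.figure(figsize=(20, 6), dpi=80)  # the new, wider canvas from this commit
plt.bar(np.arange(len(data)), data)
plt.savefig("bar.png")               # assumed output path; the repo may display instead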

engine.py

Lines changed: 5 additions & 5 deletions
@@ -48,17 +48,17 @@ def train(args, model, device, loader, optimizer, epoch):
         pred_acces.update(acc)
 
         if epoch >= args.lr_drop:
-            s = 0
-            k = 2
+            s = 5
+            k = 5
             q = 5
             t = 2
         else:
-            s = 0
+            s = 1
             k = 1
             q = 1
-            t = 0.5
+            t = 1
 
-        loss_total = retri_loss + s * attn_loss + t * quantity_loss + 0.5 * loss_pred + q * consistence_loss - k * batch_dis_loss + 0 * att_dis_loss
+        loss_total = retri_loss + s * attn_loss + t * quantity_loss + 0.5 * loss_pred - q * consistence_loss + k * batch_dis_loss
     else:
         cpt = model(data)
         retri_loss, quantity_loss = get_retrieval_loss(cpt, label, args.num_classes, device)
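The substantive change here is the sign flip: the consistence term is now subtracted and the batch discrimination term added, and the zero-weighted att_dis_loss term is dropped entirely. A minimal sketch of the new weighting schedule, with made-up scalar losses standing in for the values engine.py actually computes:

import torch

# Placeholder loss values; in engine.py these come from the model and loss.py.
retri_loss, attn_loss, quantity_loss = torch.tensor(1.0), torch.tensor(0.3), torch.tensor(0.2)
loss_pred, consistence_loss, batch_dis_loss = torch.tensor(0.5), torch.tensor(0.4), torch.tensor(0.6)

epoch, lr_drop = 10, 8  # assumed stand-ins for the current epoch and args.lr_drop
if epoch >= lr_drop:
    s, k, q, t = 5, 5, 5, 2   # heavier weights after the learning-rate drop
else:
    s, k, q, t = 1, 1, 1, 1

loss_total = (retri_loss + s * attn_loss + t * quantity_loss
              + 0.5 * loss_pred - q * consistence_loss + k * batch_dis_loss)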

loss.py

Lines changed: 5 additions & 4 deletions
@@ -1,5 +1,6 @@
 import torch
 from torch.autograd import Variable
+import torch.nn.functional as F
 
 
 def pairwise_loss(outputs1, outputs2, label1, label2, sigmoid_param=1.):
@@ -44,10 +45,10 @@ def batch_cpt_discriminate(data, att):
         current_att = att.sum(-1)[:, i]
         indices = current_att > current_att.mean()
         b, d = current_f[indices].shape
-        current_f = current_f[indices]  # / current_att[indices].unsqueeze(-1).expand(b, d)
+        current_f = current_f[indices] / current_att[indices].unsqueeze(-1).expand(b, d)
         record.append(torch.mean(current_f, dim=0, keepdim=True))
     record = torch.cat(record, dim=0)
-    sim = ((record[None, :, :] - record[:, None, :]) ** 2).sum(-1)
+    sim = F.cosine_similarity(record[None, :, :], record[:, None, :], dim=-1)
     return sim.mean()
 
 
@@ -76,8 +77,8 @@ def att_consistence(update, att):
         current_att = att[:, i, :].sum(-1)
         indices = current_att > current_att.mean()
         b, d = current_up[indices].shape
-        need = current_up[indices]  # / current_att[indices].unsqueeze(-1).expand(b, d)
-        consistence_loss += ((need[None, :, :] - need[:, None, :]) ** 2).sum(-1).mean()
+        need = current_up[indices] / current_att[indices].unsqueeze(-1).expand(b, d)
+        consistence_loss += F.cosine_similarity(need[None, :, :], need[:, None, :], dim=-1).mean()
     return consistence_loss/cpt
 
 
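Both hunks replace a squared Euclidean distance with cosine similarity and re-enable the attention normalization that was previously commented out. The record[None, :, :] versus record[:, None, :] indexing broadcasts to an all-pairs comparison; a small self-contained check with a random tensor (the shapes are assumptions for illustration):

import torch
import torch.nn.functional as F

record = torch.randn(4, 16)  # e.g. 4 concept prototypes of dimension 16

# (1, 4, 16) against (4, 1, 16) broadcasts to a (4, 4) pairwise similarity matrix.
sim = F.cosine_similarity(record[None, :, :], record[:, None, :], dim=-1)

assert sim.shape == (4, 4)
assert torch.allclose(sim.diagonal(), torch.ones(4))  # self-similarity is 1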

model/slots.py

Lines changed: 2 additions & 1 deletion
@@ -55,8 +55,9 @@ def forward(self, inputs_pe, inputs, weight=None, things=None):
         else:
             raise RuntimeError(f"unsupported input to tensor dot, got slot mode={self.slot_mode}")
 
+        # attn2 = attn / (attn.sum(dim=-1, keepdim=True) + self.eps)
         updates = torch.einsum('bjd,bij->bid', inputs, attn)
-        updates = updates / inputs.size(2)
+        # updates = updates / (attn.sum(-1).unsqueeze(-1) + 1)
 
         if self.vis:
             slots_vis_raw = attn.clone()
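The einsum pools the input features under the attention map, and the commit removes the fixed division by inputs.size(2), leaving both candidate renormalizations commented out. A shape sketch with made-up dimensions:

import torch

b, i, j, d = 2, 5, 49, 64    # batch, slots, spatial positions, feature dim (assumed)
inputs = torch.randn(b, j, d)
attn = torch.rand(b, i, j)

# Each slot's update is the attention-weighted sum over the j input positions.
updates = torch.einsum('bjd,bij->bid', inputs, attn)
assert updates.shape == (b, i, d)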

vis_retri.py

Lines changed: 1 addition & 1 deletion
@@ -90,7 +90,7 @@ def main():
     test_labels = f1["test_labels"]
 
     query = np.zeros((1, args.num_cpt)) - 1
-    location = 2
+    location = 14
     query[0][location] = 1
     ids = for_retrival(args, np.array(database_hash), query, location=location)
     print("-------------------------")
