Skip to content

Commit 76584e5

Browse files
author
wbw520
committed
modified mnist
1 parent 1b1facc commit 76584e5

File tree

5 files changed

+17
-3
lines changed

5 files changed

+17
-3
lines changed

configs.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -28,8 +28,8 @@
2828
parser.add_argument('--start_epoch', default=0, type=int, metavar='N', help='start epoch')
2929
parser.add_argument('--lr', default=0.0001, type=float)
3030
parser.add_argument('--batch_size', default=128, type=int)
31-
parser.add_argument('--epoch', default=80, type=int)
32-
parser.add_argument('--lr_drop', default=60, type=float, nargs="+",
31+
parser.add_argument('--epoch', default=100, type=int)
32+
parser.add_argument('--lr_drop', default=80, type=float, nargs="+",
3333
metavar='LRSteps', help='epochs to decay learning rate by 10')
3434
# ========================= Machine Configs ==========================
3535
parser.add_argument('--num_workers', default=4, type=int)

demo.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
from flask import Flask

main_retri.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,7 @@ def main():
1919
checkpoint = torch.load(os.path.join(args.output_dir,
2020
f"{args.dataset}_{args.base_model}_cls{args.num_classes}_cpt_no_slot.pt"), map_location=device)
2121
model.load_state_dict(checkpoint, strict=False)
22-
fix_parameter(model, ["layer1", "layer2", "layer3", "back_bone.conv1", "back_bone.bn1"], mode="fix")
22+
fix_parameter(model, ["layer1", "layer2", "back_bone.conv1", "back_bone.bn1"], mode="fix")
2323
print(colored('trainable parameter name: ', "blue"))
2424
print_param(model)
2525
print("load pre-trained model finished, start training")

process.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -36,5 +36,6 @@ def main():
3636
args = parser.parse_args()
3737
args.process = True
3838
args.pre_train = False
39+
# args.batch_size = 1
3940
main()
4041

utils/tools.py

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -69,10 +69,22 @@ def predict_hash_code(args, model, data_loader, device):
6969
data, label = data.to(device), label.to(device)
7070
if not args.pre_train:
7171
cpt, pred, att, update = model(data)
72+
# if args.process:
73+
# att = att[0]
74+
# record_cpt = []
75+
# for i in range(args.num_cpt):
76+
# current_mean = att[i].mean()
77+
# incidence = att[i] > current_mean
78+
# new_att = att[i][incidence]
79+
# record_cpt.append(new_att.sum().unsqueeze(0)/len(incidence))
80+
# cpt = torch.cat(record_cpt, dim=0)
81+
# cpt = cpt.unsqueeze(0).squeeze(-1)
82+
7283
acc = cal_acc(pred, label, False)
7384
accs += acc
7485
else:
7586
cpt = model(data)
87+
7688
if is_start:
7789
all_output = cpt.cpu().detach().float()
7890
all_label = label.unsqueeze(-1).cpu().detach().float()

0 commit comments

Comments
 (0)