Skip to content

Commit 080a6d9

Browse files
committed
Update train_mnist.py
1 parent 6772ead commit 080a6d9

File tree

1 file changed

+3
-3
lines changed

1 file changed

+3
-3
lines changed

train_mnist.py

+3-3
Original file line numberDiff line numberDiff line change
@@ -35,16 +35,16 @@ def printwrite(x):
3535
printwrite("Runing Attention Model...")
3636

3737
#learning rate is 1e-3
38-
for lr in [1e-5]:
38+
for lr in [1e-3]:
3939

4040
#penalty parameter = 1e3
41-
for p in [0, 1e2, 1e3, 1e4, 1e5]:
41+
for p in [1e3]:
4242
modelname = "%s_penalty_%s_lr_%s"%(run_id, p, lr)
4343
best_model = None
4444
best_score = -1
4545

4646
#run for a single trial
47-
for k in range(5):
47+
for k in range(1):
4848
printwrite("[MODELNAME %s TRIAL %s]"%(modelname, k))
4949
net = Net((28, 28 + 14*(args.n -1)), strength = args.strength).to(device)
5050
optimizer = torch.optim.Adam(net.parameters(), lr = lr)

0 commit comments

Comments
 (0)