-
Notifications
You must be signed in to change notification settings - Fork 24
/
Copy pathutil.py
49 lines (40 loc) · 1.47 KB
/
util.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
# -*- coding:utf-8 -*-
import copy
import numpy as np
import torch
from tqdm import tqdm
# Target device for modules/tensors in this file: CUDA when PyTorch can see a GPU, else CPU.
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
@torch.no_grad()
def test(model, data):
    """Evaluate ``model`` on ``data`` without tracking gradients.

    Runs one forward pass in eval mode and reports two numbers taken from
    different splits of the same output:

    * the cross-entropy loss over the rows selected by ``data.val_mask``
      (used by the training loop for model selection), and
    * the classification accuracy over the rows selected by
      ``data.test_mask``.

    Args:
        model: ``torch.nn.Module`` invoked as ``model(data)``. Its output is
            treated as per-row class logits -- assumes shape
            (num_nodes, num_classes); TODO confirm.
        data: object exposing ``y``, ``val_mask`` and ``test_mask``
            (presumably a PyG ``Data`` graph -- verify against caller).

    Returns:
        ``(val_loss, test_acc)`` as Python floats.

    Raises:
        ZeroDivisionError: if ``data.test_mask`` selects no rows.

    The module's train/eval mode from before the call is restored on exit,
    even if the forward pass raises.
    """
    was_training = model.training
    model.eval()
    try:
        out = model(data)
        loss_function = torch.nn.CrossEntropyLoss().to(device)
        loss = loss_function(out[data.val_mask], data.y[data.val_mask])
        # Predicted class = index of the largest logit in each row.
        pred = out.argmax(dim=1)
        correct = int(pred[data.test_mask].eq(data.y[data.test_mask]).sum())
        acc = correct / int(data.test_mask.sum())
    finally:
        # Restore the caller's mode instead of unconditionally forcing train mode.
        model.train(was_training)
    return loss.item(), acc
def train(model, data, *, epochs=200, lr=0.01, weight_decay=1e-4,
          min_epochs=5):
    """Train ``model`` on ``data.train_mask`` with Adam, keeping the best copy.

    Each epoch performs one full-batch forward/backward pass over the rows
    selected by ``data.train_mask``. It then calls ``test`` to obtain the
    validation loss and test accuracy. After the first ``min_epochs`` epochs,
    whenever the validation loss reaches a new minimum the model is
    deep-copied and that epoch's test accuracy is recorded.

    Args:
        model: ``torch.nn.Module`` invoked as ``model(data)``; trained in place.
        data: object exposing ``y`` and ``train_mask`` plus whatever ``test``
            reads (presumably a PyG ``Data`` graph -- verify against caller).
        epochs: number of training epochs (default 200).
        lr: Adam learning rate (default 0.01).
        weight_decay: Adam weight decay (default 1e-4).
        min_epochs: number of initial epochs excluded from model selection
            (default 5).

    Returns:
        ``(best_model, final_test_acc)``. ``best_model`` is a deep copy of
        ``model`` at the selected epoch, or ``None`` if no epoch qualified
        (e.g. every validation loss was NaN). ``final_test_acc`` is the test
        accuracy at that epoch, or ``0`` if none qualified.
    """
    optimizer = torch.optim.Adam(model.parameters(), lr=lr,
                                 weight_decay=weight_decay)
    loss_function = torch.nn.CrossEntropyLoss().to(device)
    # float('inf') rather than np.Inf: that alias was removed in NumPy 2.0.
    min_val_loss = float('inf')
    best_model = None
    final_test_acc = 0
    for epoch in tqdm(range(epochs)):
        # Enter train mode each epoch instead of relying on the evaluation
        # helper to switch it back after its eval-mode pass.
        model.train()
        out = model(data)
        optimizer.zero_grad()
        loss = loss_function(out[data.train_mask], data.y[data.train_mask])
        loss.backward()
        optimizer.step()
        # Validation: selection is driven by val loss; test acc is only recorded.
        val_loss, test_acc = test(model, data)
        if epoch >= min_epochs and val_loss < min_val_loss:
            min_val_loss = val_loss
            final_test_acc = test_acc
            best_model = copy.deepcopy(model)
        tqdm.write(f'Epoch {epoch:03d} train_loss {loss.item():.4f} '
                   f'val_loss {val_loss:.4f} test_acc {test_acc:.4f}')
    return best_model, final_test_acc