-
Notifications
You must be signed in to change notification settings - Fork 3k
/
Copy pathtrain.py
121 lines (104 loc) · 3.58 KB
/
train.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
import argparse
import dgl
import dgl.nn as dglnn
import torch
import torch.nn as nn
import torch.nn.functional as F
from dgl import AddSelfLoop
from dgl.data import CiteseerGraphDataset, CoraGraphDataset, PubmedGraphDataset
class GCN(nn.Module):
    """Two-layer graph convolutional network built from DGL ``GraphConv``.

    The first convolution maps ``in_size`` features to ``hid_size`` and applies
    ReLU; the second maps ``hid_size`` to ``out_size`` with no activation.
    Dropout (p=0.5) is applied to the input of every layer except the first.
    """

    def __init__(self, in_size, hid_size, out_size):
        super().__init__()
        # Build the layers in input -> output order so that parameter
        # registration (and random initialisation order) is stable.
        hidden_conv = dglnn.GraphConv(in_size, hid_size, activation=F.relu)
        output_conv = dglnn.GraphConv(hid_size, out_size)
        self.layers = nn.ModuleList([hidden_conv, output_conv])
        self.dropout = nn.Dropout(0.5)

    def forward(self, g, features):
        """Apply every layer in turn on graph ``g`` and return the final output."""
        h = features
        for depth, conv in enumerate(self.layers):
            # The raw input features are fed to the first layer without dropout.
            h = conv(g, h if depth == 0 else self.dropout(h))
        return h
def evaluate(g, features, labels, mask, model):
    """Return the accuracy of ``model`` on the nodes selected by ``mask``.

    The model is switched to evaluation mode (and is not switched back) and a
    gradient-free forward pass is run as ``model(g, features)``.

    Args:
        g: Graph object, passed through to ``model`` unchanged.
        features: Node features, passed through to ``model`` unchanged.
        labels: Tensor of ground-truth class indices; indexed with ``mask``.
        mask: Tensor used to index both the model output and ``labels``.
        model: Module called as ``model(g, features)``; the predicted class is
            the arg-max of its output along dimension 1.

    Returns:
        The fraction of selected nodes whose prediction equals the label, as a
        Python float. Returns ``0.0`` when ``mask`` selects no nodes.
    """
    model.eval()
    with torch.no_grad():
        selected_logits = model(g, features)[mask]
        selected_labels = labels[mask]
        total = len(selected_labels)
        if total == 0:
            # Nothing to score: report 0.0 rather than dividing by zero.
            return 0.0
        _, predicted = torch.max(selected_logits, dim=1)
        num_correct = torch.sum(predicted == selected_labels).item()
    return num_correct / total
def train(
    g,
    features,
    labels,
    masks,
    model,
    num_epochs=200,
    lr=1e-2,
    weight_decay=5e-4,
):
    """Train ``model`` full-batch with Adam and cross-entropy loss.

    After every epoch the validation accuracy is computed with ``evaluate``
    and one progress line (epoch, training loss, validation accuracy) is
    printed.

    Args:
        g: Graph object, passed through to ``model`` unchanged.
        features: Node features, passed through to ``model`` unchanged.
        labels: Tensor of ground-truth class indices.
        masks: Sequence whose element 0 is the training mask and element 1 is
            the validation mask; any further elements are ignored here.
        model: Module called as ``model(g, features)``; updated in place.
        num_epochs: Number of optimisation steps (one per epoch). Defaults to
            200, the value previously hard-coded.
        lr: Adam learning rate. Defaults to 1e-2.
        weight_decay: Adam weight decay. Defaults to 5e-4.

    Returns:
        None.
    """
    # define train/val samples, loss function and optimizer
    train_mask = masks[0]
    val_mask = masks[1]
    loss_fcn = nn.CrossEntropyLoss()
    optimizer = torch.optim.Adam(
        model.parameters(), lr=lr, weight_decay=weight_decay
    )
    # The training targets never change, so select them once up front.
    train_labels = labels[train_mask]

    # training loop
    for epoch in range(num_epochs):
        # evaluate() leaves the model in eval mode, so re-enable training
        # behaviour (e.g. dropout) at the start of every epoch.
        model.train()
        logits = model(g, features)
        loss = loss_fcn(logits[train_mask], train_labels)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        acc = evaluate(g, features, labels, val_mask, model)
        print(
            "Epoch {:05d} | Loss {:.4f} | Accuracy {:.4f} ".format(
                epoch, loss.item(), acc
            )
        )
if __name__ == "__main__":
    # Script entry point: parse options, load a citation dataset, train a
    # two-layer GCN on it and report the test accuracy.
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--dataset",
        type=str,
        default="cora",
        help="Dataset name ('cora', 'citeseer', 'pubmed').",
    )
    parser.add_argument(
        "--dt",
        type=str,
        default="float",
        # Reject unknown values up front; previously a typo such as "bf16"
        # was silently ignored and training ran in float.
        choices=["float", "bfloat16"],
        help="data type(float, bfloat16)",
    )
    args = parser.parse_args()
    print("Training with DGL built-in GraphConv module.")

    # load and preprocess dataset
    transform = (
        AddSelfLoop()
    )  # by default, it will first remove self-loops to prevent duplication
    if args.dataset == "cora":
        data = CoraGraphDataset(transform=transform)
    elif args.dataset == "citeseer":
        data = CiteseerGraphDataset(transform=transform)
    elif args.dataset == "pubmed":
        data = PubmedGraphDataset(transform=transform)
    else:
        raise ValueError("Unknown dataset: {}".format(args.dataset))
    g = data[0]
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    g = g.int().to(device)
    features = g.ndata["feat"]
    labels = g.ndata["label"]
    masks = g.ndata["train_mask"], g.ndata["val_mask"], g.ndata["test_mask"]

    # create GCN model
    in_size = features.shape[1]
    out_size = data.num_classes
    model = GCN(in_size, 16, out_size).to(device)

    # convert model and graph to bfloat16 if needed
    if args.dt == "bfloat16":
        g = dgl.to_bfloat16(g)
        # `features` was read from the graph before the conversion above, so
        # it is cast separately to match the model's dtype.
        features = features.to(dtype=torch.bfloat16)
        model = model.to(dtype=torch.bfloat16)

    # model training
    print("Training...")
    train(g, features, labels, masks, model)

    # test the model
    print("Testing...")
    acc = evaluate(g, features, labels, masks[2], model)
    print("Test accuracy {:.4f}".format(acc))