Replace NeighborSampler with NeighborLoader in mag240m #382

Draft · wants to merge 12 commits into base: master
304 changes: 128 additions & 176 deletions examples/lsc/mag240m/gnn.py
@@ -13,40 +13,32 @@
from torch.nn import ModuleList, Sequential, Linear, BatchNorm1d, ReLU, Dropout
from torch.optim.lr_scheduler import StepLR

from pytorch_lightning.metrics import Accuracy
from torchmetrics import Accuracy
from pytorch_lightning.callbacks import ModelCheckpoint
from pytorch_lightning import (LightningDataModule, LightningModule, Trainer,
seed_everything)

from torch_sparse import SparseTensor
from torch_geometric.nn import SAGEConv, GATConv
from torch_geometric.nn import SAGEConv, GATConv, to_hetero
from torch_geometric.data import NeighborSampler

from ogb.lsc import MAG240MDataset, MAG240MEvaluator
from root import ROOT


class Batch(NamedTuple):
x: Tensor
y: Tensor
adjs_t: List[SparseTensor]

def to(self, *args, **kwargs):
return Batch(
x=self.x.to(*args, **kwargs),
y=self.y.to(*args, **kwargs),
adjs_t=[adj_t.to(*args, **kwargs) for adj_t in self.adjs_t],
)


class MAG240M(LightningDataModule):
def __init__(self, data_dir: str, batch_size: int, sizes: List[int],
in_memory: bool = False):
super().__init__()
self.data_dir = data_dir
self.batch_size = batch_size
self.sizes = sizes
self.in_memory = in_memory
from torch_geometric.loader.neighbor_loader import NeighborLoader
from torch_geometric.typing import Adj
import torch_geometric.transforms as T
from torch_geometric.typing import EdgeType, NodeType
from typing import Dict, Tuple
from torch_geometric.data import Batch
from torch_geometric.data import LightningNodeData
import pathlib
from torch.profiler import ProfilerActivity, profile

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

class MAG240M(LightningNodeData):
def __init__(self, *args, **kwargs):
super(MAG240M, self).__init__(*args, **kwargs)

@property
def num_features(self) -> int:
@@ -56,187 +48,142 @@ def num_features(self) -> int:
def num_classes(self) -> int:
return 153

def prepare_data(self):
dataset = MAG240MDataset(self.data_dir)
path = f'{dataset.dir}/paper_to_paper_symmetric.pt'
if not osp.exists(path):
t = time.perf_counter()
print('Converting adjacency matrix...', end=' ', flush=True)
edge_index = dataset.edge_index('paper', 'cites', 'paper')
edge_index = torch.from_numpy(edge_index)
adj_t = SparseTensor(
row=edge_index[0], col=edge_index[1],
sparse_sizes=(dataset.num_papers, dataset.num_papers),
is_sorted=True)
torch.save(adj_t.to_symmetric(), path)
print(f'Done! [{time.perf_counter() - t:.2f}s]')

def setup(self, stage: Optional[str] = None):
t = time.perf_counter()
print('Reading dataset...', end=' ', flush=True)
dataset = MAG240MDataset(self.data_dir)

self.train_idx = torch.from_numpy(dataset.get_idx_split('train'))
self.train_idx = self.train_idx
self.train_idx.share_memory_()
self.val_idx = torch.from_numpy(dataset.get_idx_split('valid'))
self.val_idx.share_memory_()
self.test_idx = torch.from_numpy(dataset.get_idx_split('test-dev'))
self.test_idx.share_memory_()

if self.in_memory:
self.x = torch.from_numpy(dataset.all_paper_feat).share_memory_()
else:
self.x = dataset.paper_feat
self.y = torch.from_numpy(dataset.all_paper_label)

path = f'{dataset.dir}/paper_to_paper_symmetric.pt'
self.adj_t = torch.load(path)
print(f'Done! [{time.perf_counter() - t:.2f}s]')

def train_dataloader(self):
return NeighborSampler(self.adj_t, node_idx=self.train_idx,
sizes=self.sizes, return_e_id=False,
transform=self.convert_batch,
batch_size=self.batch_size, shuffle=True,
num_workers=4)

def val_dataloader(self):
return NeighborSampler(self.adj_t, node_idx=self.val_idx,
sizes=self.sizes, return_e_id=False,
transform=self.convert_batch,
batch_size=self.batch_size, num_workers=2)

def test_dataloader(self): # Test best validation model once again.
return NeighborSampler(self.adj_t, node_idx=self.val_idx,
sizes=self.sizes, return_e_id=False,
transform=self.convert_batch,
batch_size=self.batch_size, num_workers=2)

def hidden_test_dataloader(self):
return NeighborSampler(self.adj_t, node_idx=self.test_idx,
sizes=self.sizes, return_e_id=False,
transform=self.convert_batch,
batch_size=self.batch_size, num_workers=3)

def convert_batch(self, batch_size, n_id, adjs):
if self.in_memory:
x = self.x[n_id].to(torch.float)
else:
x = torch.from_numpy(self.x[n_id.numpy()]).to(torch.float)
y = self.y[n_id[:batch_size]].to(torch.long)
return Batch(x=x, y=y, adjs_t=[adj_t for adj_t, _, _ in adjs])


class GNN(LightningModule):
def metadata(self) -> Tuple[List[NodeType], List[EdgeType]]:
node_types = ['paper', 'author', 'institution']
edge_types = [
('author', 'affiliated_with', 'institution'),
('institution', 'rev_affiliated_with', 'author'),
('author', 'writes', 'paper'),
('paper', 'rev_writes', 'author'),
('paper', 'cites', 'paper'),
]
return node_types, edge_types

class GNN(torch.nn.Module):
def __init__(self, model: str, in_channels: int, out_channels: int,
hidden_channels: int, num_layers: int, heads: int = 4,
dropout: float = 0.5):
super().__init__()
self.save_hyperparameters()
self.model = model.lower()
self.dropout = dropout

self.convs = ModuleList()
self.norms = ModuleList()
self.skips = ModuleList()

if self.model == 'gat':
self.convs.append(
GATConv(in_channels, hidden_channels // heads, heads))
self.skips.append(Linear(in_channels, hidden_channels))
for _ in range(num_layers - 1):
self.convs.append(
GATConv(hidden_channels, hidden_channels // heads, heads))
self.skips.append(Linear(hidden_channels, hidden_channels))

elif self.model == 'graphsage':
self.convs.append(SAGEConv(in_channels, hidden_channels))
for _ in range(num_layers - 1):
self.convs.append(SAGEConv(hidden_channels, hidden_channels))

for _ in range(num_layers):
self.norms.append(BatchNorm1d(hidden_channels))

self.mlp = Sequential(
Linear(hidden_channels, hidden_channels),
BatchNorm1d(hidden_channels),
ReLU(inplace=True),
Dropout(p=self.dropout),
Linear(hidden_channels, out_channels),
)

self.num_layers = num_layers

self.conv1 = SAGEConv(in_channels, hidden_channels)
self.conv2 = SAGEConv(hidden_channels, hidden_channels)
self.lin = Linear(hidden_channels, out_channels)

def forward(self, x: Tensor, edge_index: Adj) -> Tensor:
x = x.to(torch.float)
x = self.conv1(x, edge_index).relu()
x = F.dropout(x, p=self.dropout, training=self.training)
x = self.conv2(x, edge_index).relu()
x = F.dropout(x, p=self.dropout, training=self.training)
return self.lin(x)

class HeteroGNN(LightningModule):
def __init__(self, model_name: str, metadata: Tuple[List[NodeType], List[EdgeType]], in_channels: int, out_channels: int,
hidden_channels: int, num_layers: int, heads: int = 4,
dropout: float = 0.5):
super().__init__()
self.save_hyperparameters()
model = GNN(model_name, in_channels, out_channels, hidden_channels, num_layers, heads=heads, dropout=dropout)
self.model = to_hetero(model, metadata, aggr='sum', debug=True).to(device)
self.train_acc = Accuracy()
self.val_acc = Accuracy()
self.test_acc = Accuracy()

def forward(self, x: Tensor, adjs_t: List[SparseTensor]) -> Tensor:
for i, adj_t in enumerate(adjs_t):
x_target = x[:adj_t.size(0)]
x = self.convs[i]((x, x_target), adj_t)
if self.model == 'gat':
x = x + self.skips[i](x_target)
x = F.elu(self.norms[i](x))
elif self.model == 'graphsage':
x = F.relu(self.norms[i](x))
x = F.dropout(x, p=self.dropout, training=self.training)

return self.mlp(x)

def training_step(self, batch, batch_idx: int):
y_hat = self(batch.x, batch.adjs_t)
train_loss = F.cross_entropy(y_hat, batch.y)
self.train_acc(y_hat.softmax(dim=-1), batch.y)
def forward(
self,
x_dict: Dict[NodeType, Tensor],
edge_index_dict: Dict[EdgeType, Tensor],
) -> Dict[NodeType, Tensor]:
return self.model(x_dict, edge_index_dict)

def common_step(self, batch: Batch) -> Tuple[Tensor, Tensor]:
batch_size = batch['paper'].batch_size
y_hat = self(batch.x_dict, batch.edge_index_dict)['paper'][:batch_size]
y = batch['paper'].y[:batch_size].to(torch.long)
return y_hat, y

def training_step(self, batch: Batch, batch_idx: int) -> Tensor:
y_hat, y = self.common_step(batch)
train_loss = F.cross_entropy(y_hat, y)
self.train_acc(y_hat.softmax(dim=-1), y)
self.log('train_acc', self.train_acc, prog_bar=True, on_step=False,
on_epoch=True)
return train_loss

def validation_step(self, batch, batch_idx: int):
y_hat = self(batch.x, batch.adjs_t)
self.val_acc(y_hat.softmax(dim=-1), batch.y)
def validation_step(self, batch: Batch, batch_idx: int):
y_hat, y = self.common_step(batch)
self.val_acc(y_hat.softmax(dim=-1), y)
self.log('val_acc', self.val_acc, on_step=False, on_epoch=True,
prog_bar=True, sync_dist=True)

def test_step(self, batch, batch_idx: int):
y_hat = self(batch.x, batch.adjs_t)
self.test_acc(y_hat.softmax(dim=-1), batch.y)
def test_step(self, batch: Batch, batch_idx: int):
y_hat, y = self.common_step(batch)
self.test_acc(y_hat.softmax(dim=-1), y)
self.log('test_acc', self.test_acc, on_step=False, on_epoch=True,
prog_bar=True, sync_dist=True)

def predict_step(self, batch: Batch, batch_idx: int):
y_hat, y = self.common_step(batch)
return y_hat

def configure_optimizers(self):
optimizer = torch.optim.Adam(self.parameters(), lr=0.001)
scheduler = StepLR(optimizer, step_size=25, gamma=0.25)
return [optimizer], [scheduler]

def trace_handler(p):
if torch.cuda.is_available():
profile_sort = 'self_cuda_time_total'
else:
profile_sort = 'self_cpu_time_total'
output = p.key_averages().table(sort_by=profile_sort)
print(output)
profile_dir = str(pathlib.Path.cwd()) + '/'
timeline_file = profile_dir + 'timeline' + '.json'
p.export_chrome_trace(timeline_file)

if __name__ == '__main__':
parser = argparse.ArgumentParser()
parser.add_argument('--hidden_channels', type=int, default=1024)
parser.add_argument('--batch_size', type=int, default=1024)
parser.add_argument('--dropout', type=float, default=0.5)
parser.add_argument('--epochs', type=int, default=100)
parser.add_argument('--epochs', type=int, default=1)
parser.add_argument('--model', type=str, default='gat',
choices=['gat', 'graphsage'])
parser.add_argument('--sizes', type=str, default='25-15')
parser.add_argument('--sizes', type=str, default='2')
parser.add_argument('--in-memory', action='store_true')
parser.add_argument('--device', type=str, default='0')
parser.add_argument('--evaluate', action='store_true')
parser.add_argument('--profile', action='store_true')
args = parser.parse_args()
args.sizes = [int(i) for i in args.sizes.split('-')]
print(args)

seed_everything(42)
datamodule = MAG240M(ROOT, args.batch_size, args.sizes, args.in_memory)
dataset = MAG240MDataset(ROOT)
data = dataset.to_pyg_hetero_data()
datamodule = MAG240M(data, ('paper', data['paper'].train_mask),
('paper', data['paper'].val_mask),
('paper', data['paper'].test_mask),
('paper', data['paper'].test_mask),
loader='neighbor', num_neighbors=args.sizes,
batch_size=args.batch_size, num_workers=2)
print(datamodule)

if not args.evaluate:
model = GNN(args.model, datamodule.num_features,
model = HeteroGNN(args.model, datamodule.metadata(), datamodule.num_features,
datamodule.num_classes, args.hidden_channels,
num_layers=len(args.sizes), dropout=args.dropout)
print(f'#Params {sum([p.numel() for p in model.parameters()])}')
checkpoint_callback = ModelCheckpoint(monitor='val_acc', mode = 'max', save_top_k=1)
trainer = Trainer(gpus=args.device, max_epochs=args.epochs,
trainer = Trainer(accelerator="cpu", max_epochs=args.epochs,
callbacks=[checkpoint_callback],
default_root_dir=f'logs/{args.model}')
default_root_dir=f'logs/{args.model}',
limit_train_batches=10, limit_test_batches=10,
limit_val_batches=10, limit_predict_batches=10)
trainer.fit(model, datamodule=datamodule)

if args.evaluate:
@@ -246,26 +193,31 @@ def configure_optimizers(self):
print(f'Evaluating saved model in {logdir}...')
ckpt = glob.glob(f'{logdir}/checkpoints/*')[0]

trainer = Trainer(gpus=args.device, resume_from_checkpoint=ckpt)
model = GNN.load_from_checkpoint(checkpoint_path=ckpt,
trainer = Trainer(accelerator="cpu", resume_from_checkpoint=ckpt)
model = HeteroGNN.load_from_checkpoint(checkpoint_path=ckpt,
hparams_file=f'{logdir}/hparams.yaml')

datamodule.batch_size = 16
datamodule.sizes = [160] * len(args.sizes) # (Almost) no sampling...

trainer.test(model=model, datamodule=datamodule)

evaluator = MAG240MEvaluator()
loader = datamodule.hidden_test_dataloader()

model.eval()
device = f'cuda:{args.device}' if torch.cuda.is_available() else 'cpu'
model.to(device)
y_preds = []
for batch in tqdm(loader):
batch = batch.to(device)
with torch.no_grad():
out = model(batch.x, batch.adjs_t).argmax(dim=-1).cpu()
y_preds.append(out)
res = {'y_pred': torch.cat(y_preds, dim=0)}
evaluator.save_test_submission(res, f'results/{args.model}', mode = 'test-dev')
trainer.predict(model=model, datamodule=datamodule)
if args.profile:
with profile(activities=[ProfilerActivity.CPU, ProfilerActivity.CUDA],
on_trace_ready=trace_handler) as p:
trainer.predict(model=model, datamodule=datamodule)
p.step()

# evaluator = MAG240MEvaluator()
# loader = datamodule.hidden_test_dataloader()

# model.eval()
# device = f'cuda:{args.device}' if torch.cuda.is_available() else 'cpu'
# model.to(device)
# y_preds = []
# for batch in tqdm(loader):
# batch = batch.to(device)
# with torch.no_grad():
# out = model(batch.x, batch.adjs_t).argmax(dim=-1).cpu()
# y_preds.append(out)
# res = {'y_pred': torch.cat(y_preds, dim=0)}
# evaluator.save_test_submission(res, f'results/{args.model}', mode = 'test-dev')
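For context: with `loader='neighbor'`, LightningNodeData constructs a NeighborLoader per split under the hood. Below is a minimal sketch of the equivalent training loader, assuming the PyG 2.x NeighborLoader API (illustrative only, not part of this diff):

```python
# Illustrative sketch: roughly the loader LightningNodeData builds
# internally for the training split when loader='neighbor'.
from torch_geometric.loader import NeighborLoader

train_loader = NeighborLoader(
    data,                                             # HeteroData from to_pyg_hetero_data()
    num_neighbors=args.sizes,                         # neighbors sampled per hop, e.g. [25, 15]
    input_nodes=('paper', data['paper'].train_mask),  # seed nodes to sample around
    batch_size=args.batch_size,
    shuffle=True,
    num_workers=2,
)
```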
72 changes: 51 additions & 21 deletions ogb/lsc/mag240m.py
@@ -9,6 +9,7 @@

from ogb.utils.url import decide_download, download_url, extract_zip, makedirs
from ogb.lsc.utils import split_test
from torch_geometric.data import HeteroData


class MAG240MDataset(object):
@@ -53,6 +54,40 @@ def download(self):
print('Stop download.')
exit(-1)

def to_pyg_hetero_data(self):
data = HeteroData()
path = osp.join(self.dir, 'processed', 'paper', 'node_feat.npy')
# Current is not in-memory
Collaborator:

Do you mean:

```suggestion
        # Currently in-memory only
```

Author:

data["paper"].x = torch.from_numpy(np.load(path, mmap_mode='r')) is from @property def paper_label(self)..., which is called when self.in_memory is False. So I comment here, to remind myself to enable in_memory part.

data['paper'].x = torch.from_numpy(np.load(path, mmap_mode='r'))
path = osp.join(self.dir, 'processed', 'paper', 'node_label.npy')
data['paper'].y = torch.from_numpy(np.load(path))
path = osp.join(self.dir, 'processed', 'paper', 'node_year.npy')
data['paper'].year = torch.from_numpy(np.load(path, mmap_mode='r'))

data['author'].num_nodes = self.__meta__['author']
path = osp.join(self.dir, 'processed', 'author', 'author.npy')
data['author'].x = np.memmap(path, mode='r', dtype="float16", shape=(data['author'].num_nodes, self.num_paper_features))
data['institution'].num_nodes = self.__meta__['institution']
path = osp.join(self.dir, 'processed', 'institution', 'inst.npy')
data['institution'].x = np.memmap(path, mode='r', dtype="float16", shape=(data['institution'].num_nodes, self.num_paper_features))

for edge_type in [('author', 'affiliated_with', 'institution'),
('author', 'writes', 'paper'),
('paper', 'cites', 'paper')]:
name = '___'.join(edge_type)
path = osp.join(self.dir, 'processed', name, 'edge_index.npy')
edge_index = torch.from_numpy(np.load(path))
data[edge_type].edge_index = edge_index
data[edge_type[2], f'rev_{edge_type[1]}', edge_type[0]].edge_index = edge_index.flip([0])

for f, v in [('train', 'train'), ('valid', 'val'), ('test-dev', 'test')]:
idx = self.get_idx_split(f)
idx = torch.from_numpy(idx)
mask = torch.zeros(data['paper'].num_nodes, dtype=torch.bool)
mask[idx] = True
data['paper'][f'{v}_mask'] = mask
return data

@property
def num_papers(self) -> int:
return self.__meta__['paper']
@@ -108,15 +143,6 @@ def all_paper_year(self) -> np.ndarray:
path = osp.join(self.dir, 'processed', 'paper', 'node_year.npy')
return np.load(path)

def edge_index(self, id1: str, id2: str,
id3: Optional[str] = None) -> np.ndarray:
src = id1
rel, dst = (id3, id2) if id3 is None else (id2, id3)
rel = self.__rels__[(src, dst)] if rel is None else rel
name = f'{src}___{rel}___{dst}'
path = osp.join(self.dir, 'processed', name, 'edge_index.npy')
return np.load(path)

def __repr__(self) -> str:
return f'{self.__class__.__name__}()'

@@ -164,6 +190,7 @@ def save_test_submission(self, input_dict: Dict, dir_path: str, mode: str):

if __name__ == '__main__':
dataset = MAG240MDataset()
data = dataset.to_pyg_hetero_data()
Collaborator:

Let's test this separately?

Author:

/home/user/yanbing/pyg/ogb/ogb/lsc/dataset is the dev root; I will remove it.

print(dataset)
print(dataset.num_papers)
print(dataset.num_authors)
@@ -196,22 +223,25 @@ def save_test_submission(self, input_dict: Dict, dir_path: str, mode: str):

exit(-1)

print(dataset.paper_feat.shape)
print(dataset.paper_year.shape)
print(dataset.paper_year[:100])
print(dataset.edge_index('author', 'paper').shape)
print(dataset.edge_index('author', 'writes', 'paper').shape)
print(dataset.edge_index('author', 'writes', 'paper')[:, :10])
print(data['paper'].x.shape)
print(data['paper'].year.shape)
print(data['paper'].year[:100])
print(data[('author', 'writes', 'paper')].edge_index.shape)
print(data[('author', 'affiliated_with', 'institution')].edge_index.shape)
print(data[('paper', 'cites', 'paper')].edge_index.shape)
print(data[('author', 'writes', 'paper')].edge_index[:, :10])
print(data[('author', 'affiliated_with', 'institution')].edge_index[:, :10])
print(data[('paper', 'cites', 'paper')].edge_index[:, :10])
print('-----------------')

train_idx = dataset.get_idx_split('train')
val_idx = dataset.get_idx_split('valid')
test_idx = dataset.get_idx_split('test-whole')
print(len(train_idx) + len(val_idx) + len(test_idx))

print(dataset.paper_label[train_idx][:10])
print(dataset.paper_label[val_idx][:10])
print(dataset.paper_label[test_idx][:10])
print(dataset.paper_year[train_idx][:10])
print(dataset.paper_year[val_idx][:10])
print(dataset.paper_year[test_idx][:10])
print(data['paper'].y[train_idx][:10])
print(data['paper'].y[val_idx][:10])
print(data['paper'].y[test_idx][:10])
print(data['paper'].year[train_idx][:10])
print(data['paper'].year[val_idx][:10])
print(data['paper'].year[test_idx][:10])
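As a final sanity check, one could draw a single mini-batch from a NeighborLoader built on this HeteroData (a hedged sketch reusing the hypothetical train_loader from the earlier sketch; common_step() in gnn.py assumes the seed nodes come first):

```python
# Hypothetical smoke test: seed 'paper' nodes occupy the first
# batch_size rows of every node-level tensor in the sampled subgraph.
batch = next(iter(train_loader))
bs = batch['paper'].batch_size
print(batch['paper'].x.shape)   # features of seeds + sampled neighbors
print(batch['paper'].y[:bs])    # labels of the seed nodes only
```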