Commit a221f38

Updated code comments.
1 parent 041d02f commit a221f38

4 files changed: +46 -56 lines changed

configs/data.yaml (+5 -3)
@@ -1,16 +1,16 @@
 # @package _global_
 
 # ! Dataset Preprocessing
+preprocess_device: gpu # Set to cpu if your GPU memory is below 32GB
 add_self_loop: false
 to_bidirected: true
 n_hops: 2
 
 # ! Train and Evaluation Dataset Lookup
 dataset: Debug
-preprocess_device: gpu # Set to cpu if your GPU memory is below 32GB
 train_datasets: ${oc.select:_dataset_lookup.${dataset}.train,${dataset}}
 eval_datasets: ${oc.select:_dataset_lookup.${dataset}.eval,${dataset}}
-_trans_datasets: [ Arxiv, Product, Cora, Wisconsin ]
+_trans_datasets: [ Arxiv, Product, Cora, Wisconsin ] # Used when identifying heldout datasets.
 
 _all_datasets: [
   Arxiv,
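
A note on the `oc.select` lines above: OmegaConf resolves `${oc.select:path,default}` to the value at `path` when it exists and to the default otherwise, so any dataset without a `_dataset_lookup` entry simply trains and evaluates on itself. A minimal sketch of that fallback behavior (the `Debug` lookup entry below is invented for illustration):

from omegaconf import OmegaConf

cfg = OmegaConf.create(
    """
dataset: Debug
_dataset_lookup:
  Debug:  # hypothetical lookup entry, for illustration only
    train: [Cora]
    eval: [Cora, Wisconsin]
train_datasets: ${oc.select:_dataset_lookup.${dataset}.train,${dataset}}
eval_datasets: ${oc.select:_dataset_lookup.${dataset}.eval,${dataset}}
"""
)
print(cfg.train_datasets)  # -> the lookup entry ['Cora']
cfg.dataset = "Arxiv"      # no lookup entry -> falls back to the dataset name
print(cfg.eval_datasets)   # -> Arxiv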
@@ -101,17 +101,19 @@ _ds_meta_data:
   # WikiTraffic: Nodes represent web pages and edges represent hyperlinks between them. Node features represent several informative nouns in the Wikipedia pages. The task is to predict the average daily traffic of the web page.
   Chameleon: pyg, WikipediaNetwork.chameleon # 5201, 217073, 2089, 5
   Squirrel: pyg, WikipediaNetwork.squirrel # 2277, 36101, 2325, 5
-
+  # Airport traffic graphs
   AirBrazil: pyg, Airports.Brazil # 131 1,038 131 4
   AirUS: pyg, Airports.USA # 1,190 13,599 1190 4
   AirEU: pyg, Airports.Europe # 399 5,995 399 4
 
   # ! HeterophilousGraphDataset
+  # See https://arxiv.org/abs/2302.11640 for details
   Roman: pyg, HeterophilousGraphDataset.Roman-empire # 22,662 32,927 300 18
   AmzRatings: pyg, HeterophilousGraphDataset.Amazon-ratings # 24,492 93,050 300 5
   Minesweeper: pyg, HeterophilousGraphDataset.Minesweeper # 10,000 39,402 7 2
   Tolokers: pyg, HeterophilousGraphDataset.Tolokers # 11,758 519,000 10 2
   Questions: pyg, HeterophilousGraphDataset.Questions # 48,921 153,540 301 2
+
   # Each node corresponds to an actor, and the edge between two nodes denotes co-occurrence on the same Wikipedia page. Node features correspond to some keywords in the Wikipedia pages. The task is to classify the nodes into five categories in terms of words of actor’s Wikipedia.
   Actor: pyg, Actor # 7,600 30,019 932 5
graphany/data.py (+35 -40)
@@ -12,7 +12,6 @@
 import numpy as np
 import pytorch_lightning as pl
 import torch
-import torch.nn.functional as F
 from hydra.utils import instantiate
 from omegaconf import OmegaConf
 from scipy.spatial.distance import pdist, squareform
@@ -25,15 +24,13 @@
 from graphany.utils import logger, timer
 
 
-def get_entropy_normed_cond_gaussian_prob(
-    X, entropy, beta_list=None, metric="euclidean", use_cpython=False, return_beta=False
-):
+def get_entropy_normed_cond_gaussian_prob(X, entropy, metric="euclidean"):
     """
     Parameters
     ----------
     X: The matrix for pairwise similarity
     entropy: Perplexity of the conditional prob distribution
-    Returns conditional probability
+    Returns the entropy-normalized conditional gaussian probability based on distances.
     -------
     """
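
For readers unfamiliar with the function being simplified above: it computes t-SNE-style conditional Gaussian probabilities, calibrating a per-row precision beta by binary search until each row's distribution reaches a target entropy. A rough standalone sketch of that idea (names, tolerances, and the entropy convention are illustrative assumptions, not the repo's exact implementation):

import numpy as np
from scipy.spatial.distance import pdist, squareform

def cond_gaussian_prob_sketch(X, target_entropy, metric="euclidean", tol=1e-5, max_iter=50):
    # Squared pairwise distances between rows of X
    D = squareform(pdist(X, metric=metric)) ** 2
    n = X.shape[0]
    P = np.zeros((n, n))
    for i in range(n):
        d = np.delete(D[i], i)          # distances to all other points
        lo, hi, beta = 0.0, np.inf, 1.0
        for _ in range(max_iter):
            p = np.exp(-beta * d)
            p /= p.sum()
            h = -np.sum(p * np.log(p + 1e-12))  # entropy of P(.|i)
            if abs(h - target_entropy) < tol:
                break
            if h > target_entropy:      # too flat: raise the precision
                lo = beta
                beta = beta * 2 if hi == np.inf else (lo + hi) / 2
            else:                       # too peaked: lower the precision
                hi = beta
                beta = (lo + hi) / 2
        P[i, np.arange(n) != i] = p
    return P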

@@ -175,30 +172,28 @@ def val_dataloader(self):
         sub_dataloaders = {
             name: ds.val_dataloader() for name, ds in self.eval_ds_dict.items()
         }
-        return pl.utilities.combined_loader.CombinedLoader(
-            sub_dataloaders, "max_size"
-        )  # Use max_size instead of max_size_cycle to avoid duplicates
+        # Use max_size instead of max_size_cycle to avoid repeated evaluation on small datasets
+        return pl.utilities.combined_loader.CombinedLoader(sub_dataloaders, "max_size")
 
     def test_dataloader(self):
         sub_dataloaders = {
             name: ds.test_dataloader() for name, ds in self.eval_ds_dict.items()
         }
-        return pl.utilities.combined_loader.CombinedLoader(
-            sub_dataloaders, "max_size"
-        )  # Use max_size instead of max_size_cycle to avoid duplicates
+        # Use max_size instead of max_size_cycle to avoid repeated evaluation on small datasets
+        return pl.utilities.combined_loader.CombinedLoader(sub_dataloaders, "max_size")
 
 
 class GraphDataset(pl.LightningDataModule):
     def __init__(
-            self,
-            cfg,
-            ds_name,
-            cache_dir,
-            train_batch_size=256,
-            val_test_batch_size=256,
-            n_hops=1,
-            preprocess_device=torch.device("cpu"),
-            permute_label=False,
+        self,
+        cfg,
+        ds_name,
+        cache_dir,
+        train_batch_size=256,
+        val_test_batch_size=256,
+        n_hops=1,
+        preprocess_device=torch.device("cpu"),
+        permute_label=False,
     ):
         super().__init__()
         self.cfg = cfg
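
The "max_size" mode chosen above matters for evaluation: "max_size_cycle" restarts exhausted loaders until the longest one finishes, so small datasets would be evaluated several times per epoch, whereas "max_size" yields None for loaders that have run out. A toy illustration (assuming a recent PyTorch Lightning; the loaders are made up):

from torch.utils.data import DataLoader
from pytorch_lightning.utilities.combined_loader import CombinedLoader

loaders = {
    "small": DataLoader(range(2), batch_size=1),
    "large": DataLoader(range(4), batch_size=1),
}
combined = CombinedLoader(loaders, "max_size")
for batch, batch_idx, _ in combined:
    # After "small" runs out, its entry is None rather than a recycled batch
    print(batch_idx, batch["small"], batch["large"])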
@@ -369,9 +364,9 @@ def to_mask(indices):
         label = dataset.y
 
         if (
-                hasattr(dataset, "train_mask")
-                and hasattr(dataset, "val_mask")
-                and hasattr(dataset, "test_mask")
+            hasattr(dataset, "train_mask")
+            and hasattr(dataset, "val_mask")
+            and hasattr(dataset, "test_mask")
         ):
             train_mask, val_mask, test_mask = (
                 dataset.train_mask,
@@ -399,9 +394,8 @@ def to_mask(indices):
             # ! Multiple splits
             # Modified: Use the ${seed} split if not specified!
             split_index = self.data_init_args.get("split", self.cfg.seed)
-            self.split_index = split_index = (
-                split_index % train_mask.ndim
-            )  # Avoid invalid seed value
+            # Avoid invalid split index
+            self.split_index = split_index = (split_index % train_mask.ndim)
             train_mask = train_mask[:, split_index].squeeze()
             val_mask = val_mask[:, split_index].squeeze()
             if test_mask.ndim == 2:
@@ -422,29 +416,31 @@ def to_mask(indices):
         return g, label, feat, train_mask, val_mask, test_mask, num_class
 
     def compute_linear_gnn_logits(
-            self, features, n_per_label_examples, visible_nodes, bootstrap=False
+        self, features, n_per_label_examples, visible_nodes, bootstrap=False
     ):
+        # Compute and save LinearGNN logits into a dict. Note the computation is on CPU as torch does not support
+        # the gelss driver on GPU currently.
         preds = {}
         label, num_class, device = self.label, self.num_class, torch.device("cpu")
         label = label.to(device)
         visible_nodes = visible_nodes.to(device)
-        for channel, X in features.items():
-            X = X.to(device)
+        for channel, F in features.items():
+            F = F.to(device)
             if bootstrap:
                 ref_nodes = sample_k_nodes_per_label(
                     label, visible_nodes, n_per_label_examples, num_class
                 )
             else:
                 ref_nodes = visible_nodes
-            Y_L = F.one_hot(label[ref_nodes], num_class).float()
+            Y_L = torch.nn.functional.one_hot(label[ref_nodes], num_class).float()
             with timer(
-                    f"Solving with CPU driver (N={len(ref_nodes)}, d={X.shape[1]}, k={num_class})",
-                    logger.debug,
+                f"Solving with CPU driver (N={len(ref_nodes)}, d={F.shape[1]}, k={num_class})",
+                logger.debug,
             ):
                 W = torch.linalg.lstsq(
-                    X[ref_nodes.cpu()].cpu(), Y_L.cpu(), driver="gelss"
+                    F[ref_nodes.cpu()].cpu(), Y_L.cpu(), driver="gelss"
                 )[0]
-            preds[channel] = X @ W
+            preds[channel] = F @ W
 
         return preds
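
On the new comment about the gelss driver: torch.linalg.lstsq only supports this SVD-based (rank-deficient-safe) LAPACK driver for CPU tensors, which is why the solve above moves everything to CPU. A self-contained sketch of the same closed-form LinearGNN fit, with made-up shapes:

import torch
import torch.nn.functional as F

N, d, k = 500, 32, 5                  # nodes, feature dim, classes (made up)
X = torch.randn(N, d)                 # propagated node features for one channel
y = torch.randint(0, k, (N,))         # node labels
ref_nodes = torch.randperm(N)[:100]   # visible/reference nodes

Y_L = F.one_hot(y[ref_nodes], k).float()
# Least-squares fit of features to one-hot labels; gelss requires CPU tensors
W = torch.linalg.lstsq(X[ref_nodes], Y_L, driver="gelss").solution
logits = X @ W                        # (N, k) predictions for all nodes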

@@ -466,8 +462,8 @@ def prepare_prop_features_logits_and_dist_features(self, g, input_feats, n_hops)
         if not os.path.exists(self.cache_f_name):
             g = g.to(self.preprocess_device)
             with timer(
-                    f"Computing {self.name} message passing and normalized predictions to file {self.cache_f_name}",
-                    logger.info,
+                f"Computing {self.name} message passing and normalized predictions to file {self.cache_f_name}",
+                logger.info,
             ):
                 dim = input_feats.size(1)
                 LP = torch.zeros(n_hops, g.number_of_nodes(), dim).to(
@@ -504,9 +500,9 @@ def prepare_prop_features_logits_and_dist_features(self, g, input_feats, n_hops)
         features, unmasked_pred = torch.load(self.cache_f_name, map_location="cpu")
         if not os.path.exists(self.dist_f_name):
             with timer(
-                    f"Computing {self.name} conditional gaussian distances "
-                    f"to file {self.dist_f_name}",
-                    logger.info,
+                f"Computing {self.name} conditional gaussian distances "
+                f"and save to {self.dist_f_name}",
+                logger.info,
             ):
                 # y_feat: n_nodes, n_channels, n_labels
                 y_feat = np.stack(
@@ -532,7 +528,6 @@ def prepare_prop_features_logits_and_dist_features(self, g, input_feats, n_hops)
                         dist[:, pair_index] = cond_gaussian_prob[:, c, c_prime]
                         pair_index += 1
 
-            # Convert dist to a PyTorch tensor and move it to the same device as y_feat
             dist = torch.from_numpy(dist)
             torch.save(dist, self.dist_f_name)
         else:
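
For orientation, the (n_hops, n_nodes, dim) buffer LP allocated in the message-passing hunk above caches one feature matrix per propagation round before everything is written to self.cache_f_name. A dense sketch of that propagation step (the adjacency and shapes here are synthetic; the real code runs sparse message passing on the graph g):

import torch

def propagate_sketch(adj_norm, feats, n_hops):
    # adj_norm: (n_nodes, n_nodes) normalized adjacency; feats: (n_nodes, dim)
    LP = torch.zeros(n_hops, *feats.shape)
    h = feats
    for hop in range(n_hops):
        h = adj_norm @ h   # one round of message passing
        LP[hop] = h        # cache the features after hop+1 rounds
    return LP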

graphany/model.py (+3 -5)
@@ -37,10 +37,9 @@ def compute_dist(self, y_feat):
                 y_feat[i, :, :].cpu().numpy(), self.entropy
             )
 
-        # Create dist as a numpy array
+        # Compute pairwise distances between channels n_channels(n_channels-1)/2 total features
         dist = np.zeros((bsz, self.dist_feat_dim), dtype=np.float32)
 
-        # Compute pairwise distances between channels n_channels(n_channels-1)/2 total features
         pair_index = 0
         for c in range(n_channel):
             for c_prime in range(n_channel):
@@ -52,14 +51,13 @@ def compute_dist(self, y_feat):
         return dist
 
     def forward(self, logit_dict, dist=None, **kwargs):
-        # Label logits tensor of shape (batch_size, n_channels, * n_classes)
+        # logit_dict: key: channel, value: prediction of shape (batch_size, n_classes)
         y_feat = torch.stack([logit_dict[c] for c in self.feat_channels], dim=1)
         y_pred = torch.stack([logit_dict[c] for c in self.pred_channels], dim=1)
 
         # ! Fuse y_pred with attentions
-        # Compute attention of (batch_size, n_channels)
         dist = self.compute_dist(y_feat) if dist is None else dist
-        # Project pairwise differences to the attention scores via MLP
+        # Project pairwise differences to the attention scores (batch_size, n_channels)
         attention = self.mlp(dist)
         attention = th.softmax(attention / self.att_temperature, dim=-1)
         fused_y = th.sum(
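
Putting the rewritten comments of this forward pass together: per-channel logits are stacked, the channel-distance features drive an MLP, and a temperature-scaled softmax yields per-channel fusion weights. A condensed sketch (the helper name and signature are hypothetical; shapes follow the comments above):

import torch

def fuse_channel_predictions(y_pred, dist, mlp, att_temperature=1.0):
    # y_pred: (batch_size, n_channels, n_classes) stacked per-channel logits
    # dist:   (batch_size, dist_feat_dim) pairwise channel-distance features
    attention = torch.softmax(mlp(dist) / att_temperature, dim=-1)  # (batch, n_channels)
    # Attention-weighted sum over channels -> (batch_size, n_classes)
    return torch.sum(attention.unsqueeze(-1) * y_pred, dim=1)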

graphany/run.py (+3 -8)
@@ -18,7 +18,7 @@
 mean = lambda input: np.round(np.mean(input).item(), 2)
 
 
-class InductiveLabelPred(pl.LightningModule):
+class InductiveNodeClassification(pl.LightningModule):
     def __init__(self, cfg, combined_dataset, checkpoint=None):
         super().__init__()
         self.cfg = cfg
@@ -73,8 +73,6 @@ def get_metric_name(self, ds_name, split):
         return f"ind/{ds_name.lower()[:4]}_{split}_acc"
 
     def configure_optimizers(self):
-        num_devices = self.cfg.gpus if self.cfg.gpus > 0 else 1
-
         # start with all the candidate parameters
         param_dict = {pn: p for pn, p in self.named_parameters()}
         # filter out those that do not require grad
10199
else: # AdamW
102100
optimizer = torch.optim.AdamW(
103101
optim_groups,
104-
lr=self.cfg.lr * num_devices,
102+
lr=self.cfg.lr,
105103
weight_decay=self.cfg.weight_decay,
106104
)
107105
return optimizer
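
With the device-count scaling removed, AdamW now receives cfg.lr unchanged. The optim_groups built from param_dict in the unchanged lines typically separate decayed from non-decayed parameters; a sketch of that common pattern (the decay rule here is an assumption, not necessarily the repo's exact code):

import torch

def build_optim_groups(model, weight_decay):
    # start with all the candidate parameters, filter out frozen ones
    param_dict = {pn: p for pn, p in model.named_parameters() if p.requires_grad}
    # common convention: decay matrix-shaped weights, not biases/norm params
    decay = [p for p in param_dict.values() if p.dim() >= 2]
    no_decay = [p for p in param_dict.values() if p.dim() < 2]
    return [
        {"params": decay, "weight_decay": weight_decay},
        {"params": no_decay, "weight_decay": 0.0},
    ]

# optimizer = torch.optim.AdamW(build_optim_groups(model, 0.01), lr=1e-3)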
@@ -117,9 +115,6 @@ def move_metrics_to_device(self):
         for metrics_dict in self.metrics.values():
             for metric in metrics_dict.values():
                 metric.to(self.device)
-        # Example for a direct metric attribute
-        if hasattr(self, "accuracy"):
-            self.accuracy.to(self.device)
 
     def predict(self, ds, nodes, input, is_training=False):
         # Use preprocessed distance during evaluation
@@ -259,7 +254,7 @@ def construct_ds_dict(datasets):
 
     combined_dataset = CombinedDataset(train_ds_dict, eval_ds_dict, cfg)
 
-    model = InductiveLabelPred(cfg, combined_dataset, cfg.get("prev_ckpt"))
+    model = InductiveNodeClassification(cfg, combined_dataset, cfg.get("prev_ckpt"))
     # Set up the checkpoint callback to save only at the end of training
     checkpoint_callback = pl.callbacks.ModelCheckpoint(
         dirpath=cfg.dirs.output,  # specify where to save
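
Regarding the comment about saving only at the end of training: one way to get that behavior from Lightning's ModelCheckpoint is to disable metric-based top-k saving and keep only the last checkpoint. A possible configuration sketch (not necessarily the arguments the repo actually passes):

import pytorch_lightning as pl

checkpoint_callback = pl.callbacks.ModelCheckpoint(
    dirpath="output/",   # placeholder for cfg.dirs.output
    save_top_k=0,        # no metric-based intermediate checkpoints
    save_last=True,      # write last.ckpt when training ends
)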
