@@ -12,7 +12,6 @@
 import numpy as np
 import pytorch_lightning as pl
 import torch
-import torch.nn.functional as F
 from hydra.utils import instantiate
 from omegaconf import OmegaConf
 from scipy.spatial.distance import pdist, squareform
@@ -25,15 +24,13 @@
 from graphany.utils import logger, timer


-def get_entropy_normed_cond_gaussian_prob(
-    X, entropy, beta_list=None, metric="euclidean", use_cpython=False, return_beta=False
-):
+def get_entropy_normed_cond_gaussian_prob(X, entropy, metric="euclidean"):
     """
     Parameters
     ----------
     X: The matrix for pairwise similarity
     entropy: Perplexity of the conditional prob distribution
-    Returns conditional probability
+    Returns the entropy-normalized conditional gaussian probability based on distances.
     -------
     """

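For reference, the calibration this helper performs is the t-SNE-style perplexity matching: for each row of the pairwise distance matrix, a Gaussian precision beta is binary-searched until the conditional distribution over the remaining points reaches the requested entropy. A minimal self-contained sketch of that procedure (names, defaults, and iteration limits are illustrative, not the repo's exact implementation):

```python
import numpy as np
from scipy.spatial.distance import pdist, squareform


def cond_gaussian_prob_sketch(X, entropy=2.0, metric="euclidean", n_iter=50, tol=1e-5):
    # Squared pairwise distances between rows of X.
    D = squareform(pdist(X, metric=metric)) ** 2
    n = len(D)
    P = np.zeros((n, n))
    target = np.log(entropy)  # desired Shannon entropy of each row
    for i in range(n):
        d = np.delete(D[i], i)  # exclude the self-distance
        lo, hi, beta = 0.0, np.inf, 1.0
        for _ in range(n_iter):  # binary search on the precision beta
            p = np.exp(-beta * d)
            p /= p.sum()
            H = -np.sum(p * np.log(p + 1e-12))
            if abs(H - target) < tol:
                break
            if H > target:  # too flat: sharpen with a larger beta
                lo, beta = beta, beta * 2 if np.isinf(hi) else (beta + hi) / 2
            else:  # too peaked: flatten with a smaller beta
                hi, beta = beta, (lo + beta) / 2
        P[i, np.arange(n) != i] = p
    return P
```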
@@ -175,30 +172,28 @@ def val_dataloader(self):
         sub_dataloaders = {
             name: ds.val_dataloader() for name, ds in self.eval_ds_dict.items()
         }
-        return pl.utilities.combined_loader.CombinedLoader(
-            sub_dataloaders, "max_size"
-        )  # Use max_size instead of max_size_cycle to avoid duplicates
+        # Use max_size instead of max_size_cycle to avoid repeated evaluation on small datasets
+        return pl.utilities.combined_loader.CombinedLoader(sub_dataloaders, "max_size")

     def test_dataloader(self):
         sub_dataloaders = {
             name: ds.test_dataloader() for name, ds in self.eval_ds_dict.items()
         }
-        return pl.utilities.combined_loader.CombinedLoader(
-            sub_dataloaders, "max_size"
-        )  # Use max_size instead of max_size_cycle to avoid duplicates
+        # Use max_size instead of max_size_cycle to avoid repeated evaluation on small datasets
+        return pl.utilities.combined_loader.CombinedLoader(sub_dataloaders, "max_size")


 class GraphDataset(pl.LightningDataModule):
     def __init__(
-        self,
-        cfg,
-        ds_name,
-        cache_dir,
-        train_batch_size=256,
-        val_test_batch_size=256,
-        n_hops=1,
-        preprocess_device=torch.device("cpu"),
-        permute_label=False,
+        self,
+        cfg,
+        ds_name,
+        cache_dir,
+        train_batch_size=256,
+        val_test_batch_size=256,
+        n_hops=1,
+        preprocess_device=torch.device("cpu"),
+        permute_label=False,
     ):
         super().__init__()
         self.cfg = cfg
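The mode change matters because the evaluation datasets differ widely in size. With "max_size_cycle", CombinedLoader restarts exhausted loaders until the largest one finishes, so small datasets are evaluated several times per epoch and their metrics get skewed; with "max_size", exhausted loaders simply yield None. A toy illustration, assuming pytorch_lightning >= 2.0 (where the "max_size" mode exists) and plain lists standing in for dataloaders:

```python
from pytorch_lightning.utilities.combined_loader import CombinedLoader

# Stand-ins for a small and a large validation dataloader.
loaders = {"small": [1, 2], "large": [10, 20, 30]}

combined = CombinedLoader(loaders, "max_size")
for batch, batch_idx, dataloader_idx in combined:
    print(batch)
# {'small': 1, 'large': 10}
# {'small': 2, 'large': 20}
# {'small': None, 'large': 30}   <- 'small' is not recycled
```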
@@ -369,9 +364,9 @@ def to_mask(indices):
         label = dataset.y

         if (
-            hasattr(dataset, "train_mask")
-            and hasattr(dataset, "val_mask")
-            and hasattr(dataset, "test_mask")
+            hasattr(dataset, "train_mask")
+            and hasattr(dataset, "val_mask")
+            and hasattr(dataset, "test_mask")
         ):
             train_mask, val_mask, test_mask = (
                 dataset.train_mask,
@@ -399,9 +394,8 @@ def to_mask(indices):
         # ! Multiple splits
         # Modified: Use the ${seed} split if not specified!
         split_index = self.data_init_args.get("split", self.cfg.seed)
-        self.split_index = split_index = (
-            split_index % train_mask.ndim
-        )  # Avoid invalid seed value
+        # Avoid invalid split index
+        self.split_index = split_index = (split_index % train_mask.ndim)
         train_mask = train_mask[:, split_index].squeeze()
         val_mask = val_mask[:, split_index].squeeze()
         if test_mask.ndim == 2:
@@ -422,29 +416,31 @@ def to_mask(indices):
         return g, label, feat, train_mask, val_mask, test_mask, num_class

     def compute_linear_gnn_logits(
-        self, features, n_per_label_examples, visible_nodes, bootstrap=False
+        self, features, n_per_label_examples, visible_nodes, bootstrap=False
     ):
+        # Compute and save LinearGNN logits into a dict. Note the computation is on CPU as torch does not support
+        # the gelss driver on GPU currently.
         preds = {}
         label, num_class, device = self.label, self.num_class, torch.device("cpu")
         label = label.to(device)
         visible_nodes = visible_nodes.to(device)
-        for channel, X in features.items():
-            X = X.to(device)
+        for channel, F in features.items():
+            F = F.to(device)
             if bootstrap:
                 ref_nodes = sample_k_nodes_per_label(
                     label, visible_nodes, n_per_label_examples, num_class
                 )
             else:
                 ref_nodes = visible_nodes
-            Y_L = F.one_hot(label[ref_nodes], num_class).float()
+            Y_L = torch.nn.functional.one_hot(label[ref_nodes], num_class).float()
             with timer(
-                f"Solving with CPU driver (N={len(ref_nodes)}, d={X.shape[1]}, k={num_class})",
-                logger.debug,
+                f"Solving with CPU driver (N={len(ref_nodes)}, d={F.shape[1]}, k={num_class})",
+                logger.debug,
             ):
                 W = torch.linalg.lstsq(
-                    X[ref_nodes.cpu()].cpu(), Y_L.cpu(), driver="gelss"
+                    F[ref_nodes.cpu()].cpu(), Y_L.cpu(), driver="gelss"
                 )[0]
-            preds[channel] = X @ W
+            preds[channel] = F @ W

         return preds

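For intuition, each LinearGNN channel is fit in closed form: regress one-hot labels of the visible (training) nodes on that channel's propagated features, then apply the resulting weights to every node. A standalone sketch with made-up shapes; the gelss driver is an SVD-based LAPACK routine, robust to rank-deficient features but CPU-only, hence the .cpu() calls above:

```python
import torch

n, d, k = 200, 32, 5             # nodes, feature dim, classes (made up)
feat = torch.randn(n, d)         # one propagated feature channel
label = torch.randint(k, (n,))
ref_nodes = torch.arange(100)    # the "visible" (training) nodes

Y_L = torch.nn.functional.one_hot(label[ref_nodes], k).float()
# Closed-form fit: W minimizes ||feat[ref_nodes] @ W - Y_L||_F.
W = torch.linalg.lstsq(feat[ref_nodes], Y_L, driver="gelss").solution
logits = feat @ W                # (n, k) logits for every node
```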
@@ -466,8 +462,8 @@ def prepare_prop_features_logits_and_dist_features(self, g, input_feats, n_hops)
         if not os.path.exists(self.cache_f_name):
             g = g.to(self.preprocess_device)
             with timer(
-                f"Computing {self.name} message passing and normalized predictions to file {self.cache_f_name}",
-                logger.info,
+                f"Computing {self.name} message passing and normalized predictions to file {self.cache_f_name}",
+                logger.info,
             ):
                 dim = input_feats.size(1)
                 LP = torch.zeros(n_hops, g.number_of_nodes(), dim).to(
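The buffer initialized above plausibly accumulates the input features after 1..n_hops rounds of message passing, one slice per hop. A hedged sketch of that accumulation pattern on a toy dense graph; the actual propagation operators in the repo may differ:

```python
import torch

n_hops, num_nodes, dim = 2, 5, 3
adj = (torch.rand(num_nodes, num_nodes) > 0.5).float()  # toy adjacency
P = adj / adj.sum(1, keepdim=True).clamp(min=1)         # mean aggregation
feats = torch.randn(num_nodes, dim)

LP = torch.zeros(n_hops, num_nodes, dim)
h = feats
for hop in range(n_hops):
    h = P @ h         # propagate one more hop
    LP[hop] = h       # LP[k] = features after k + 1 hops
```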
@@ -504,9 +500,9 @@ def prepare_prop_features_logits_and_dist_features(self, g, input_feats, n_hops)
         features, unmasked_pred = torch.load(self.cache_f_name, map_location="cpu")
         if not os.path.exists(self.dist_f_name):
             with timer(
-                f"Computing {self.name} conditional gaussian distances "
-                f"to file {self.dist_f_name}",
-                logger.info,
+                f"Computing {self.name} conditional gaussian distances "
+                f"and save to {self.dist_f_name}",
+                logger.info,
             ):
                 # y_feat: n_nodes, n_channels, n_labels
                 y_feat = np.stack(
@@ -532,7 +528,6 @@ def prepare_prop_features_logits_and_dist_features(self, g, input_feats, n_hops)
                         dist[:, pair_index] = cond_gaussian_prob[:, c, c_prime]
                         pair_index += 1

-                # Convert dist to a PyTorch tensor and move it to the same device as y_feat
                 dist = torch.from_numpy(dist)
                 torch.save(dist, self.dist_f_name)
         else:
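The flattening loop above packs each node's (channel, channel) probability matrix into one feature column per channel pair. Assuming the loop visits every ordered pair in row-major order, it reduces to a reshape; if the actual loop skips some pairs (e.g. the diagonal), select columns accordingly:

```python
import numpy as np

n_nodes, n_channels = 8, 3
cond_gaussian_prob = np.random.rand(n_nodes, n_channels, n_channels)

dist = np.zeros((n_nodes, n_channels * n_channels))
pair_index = 0
for c in range(n_channels):
    for c_prime in range(n_channels):
        dist[:, pair_index] = cond_gaussian_prob[:, c, c_prime]
        pair_index += 1

# Same result in one call when no pairs are skipped.
assert np.allclose(dist, cond_gaussian_prob.reshape(n_nodes, -1))
```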