-
Notifications
You must be signed in to change notification settings - Fork 404
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Replace NeighborSampler with NeighborLoader in mag240m #382
Draft
yanbing-j
wants to merge
12
commits into
snap-stanford:master
Choose a base branch
from
yanbing-j:yanbing/enable
base: master
Could not load branches
Branch not found: {{ refName }}
Loading
Could not load tags
Nothing to show
Loading
Are you sure you want to change the base?
Some commits from the old base branch may be removed from the timeline,
and old review comments may become outdated.
+179
−197
Draft
Changes from all commits
Commits
Show all changes
12 commits
Select commit
Hold shift + click to select a range
71e07f5
Enable
yanbing-j 1264086
MAG240M to HeteroData
yanbing-j 10ee86a
Try convert model to_hetero
yanbing-j cfdf217
Add author and institution as node types
yanbing-j 7480eea
Use LightningNodeData
yanbing-j 1232606
Add inst.npy
yanbing-j 70f149a
Add author.npy
yanbing-j e93c89c
Add 3 nodes and remove relu/dropout in model
yanbing-j fd26879
add edges
yanbing-j 2dfa0fb
reverse edge types
yanbing-j a842429
reverse edge types and convert y to long
yanbing-j d6d0fd0
Use trainer.predict to run inference
yanbing-j File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -9,6 +9,7 @@ | |
|
||
from ogb.utils.url import decide_download, download_url, extract_zip, makedirs | ||
from ogb.lsc.utils import split_test | ||
from torch_geometric.data import HeteroData | ||
|
||
|
||
class MAG240MDataset(object): | ||
|
@@ -53,6 +54,40 @@ def download(self): | |
print('Stop download.') | ||
exit(-1) | ||
|
||
def to_pyg_hetero_data(self): | ||
data = HeteroData() | ||
path = osp.join(self.dir, 'processed', 'paper', 'node_feat.npy') | ||
# Current is not in-memory | ||
data['paper'].x = torch.from_numpy(np.load(path, mmap_mode='r')) | ||
path = osp.join(self.dir, 'processed', 'paper', 'node_label.npy') | ||
data['paper'].y = torch.from_numpy(np.load(path)) | ||
path = osp.join(self.dir, 'processed', 'paper', 'node_year.npy') | ||
data['paper'].year = torch.from_numpy(np.load(path, mmap_mode='r')) | ||
|
||
data['author'].num_nodes = self.__meta__['author'] | ||
path = osp.join(self.dir, 'processed', 'author', 'author.npy') | ||
data['author'].x = np.memmap(path, mode='r', dtype="float16", shape=(data['author'].num_nodes, self.num_paper_features)) | ||
data['institution'].num_nodes = self.__meta__['institution'] | ||
path = osp.join(self.dir, 'processed', 'institution', 'inst.npy') | ||
data['institution'].x = np.memmap(path, mode='r', dtype="float16", shape=(data['institution'].num_nodes, self.num_paper_features)) | ||
|
||
for edge_type in [('author', 'affiliated_with', 'institution'), | ||
('author', 'writes', 'paper'), | ||
('paper', 'cites', 'paper')]: | ||
name = '___'.join(edge_type) | ||
path = osp.join(self.dir, 'processed', name, 'edge_index.npy') | ||
edge_index = torch.from_numpy(np.load(path)) | ||
data[edge_type].edge_index = edge_index | ||
data[edge_type[2], f'rev_{edge_type[1]}', edge_type[0]].edge_index = edge_index.flip([0]) | ||
|
||
for f, v in [('train', 'train'), ('valid', 'val'), ('test-dev', 'test')]: | ||
idx = self.get_idx_split(f) | ||
idx = torch.from_numpy(idx) | ||
mask = torch.zeros(data['paper'].num_nodes, dtype=torch.bool) | ||
mask[idx] = True | ||
data['paper'][f'{v}_mask'] = mask | ||
return data | ||
|
||
@property | ||
def num_papers(self) -> int: | ||
return self.__meta__['paper'] | ||
|
@@ -108,15 +143,6 @@ def all_paper_year(self) -> np.ndarray: | |
path = osp.join(self.dir, 'processed', 'paper', 'node_year.npy') | ||
return np.load(path) | ||
|
||
def edge_index(self, id1: str, id2: str, | ||
id3: Optional[str] = None) -> np.ndarray: | ||
src = id1 | ||
rel, dst = (id3, id2) if id3 is None else (id2, id3) | ||
rel = self.__rels__[(src, dst)] if rel is None else rel | ||
name = f'{src}___{rel}___{dst}' | ||
path = osp.join(self.dir, 'processed', name, 'edge_index.npy') | ||
return np.load(path) | ||
|
||
def __repr__(self) -> str: | ||
return f'{self.__class__.__name__}()' | ||
|
||
|
@@ -164,6 +190,7 @@ def save_test_submission(self, input_dict: Dict, dir_path: str, mode: str): | |
|
||
if __name__ == '__main__': | ||
dataset = MAG240MDataset() | ||
data = dataset.to_pyg_hetero_data() | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Let's test this separately? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
|
||
print(dataset) | ||
print(dataset.num_papers) | ||
print(dataset.num_authors) | ||
|
@@ -196,22 +223,25 @@ def save_test_submission(self, input_dict: Dict, dir_path: str, mode: str): | |
|
||
exit(-1) | ||
|
||
print(dataset.paper_feat.shape) | ||
print(dataset.paper_year.shape) | ||
print(dataset.paper_year[:100]) | ||
print(dataset.edge_index('author', 'paper').shape) | ||
print(dataset.edge_index('author', 'writes', 'paper').shape) | ||
print(dataset.edge_index('author', 'writes', 'paper')[:, :10]) | ||
print(data['paper'].x.shape) | ||
print(data['paper'].year.shape) | ||
print(data['paper'].year[:100]) | ||
print(data[(('author', 'writes', 'paper'))].edge_index.shape) | ||
print(data[('author', 'affiliated_with', 'institution')].edge_index.shape) | ||
print(data[('paper', 'cites', 'paper')].edge_index.shape) | ||
print(data[('author', 'writes', 'paper')].edge_index[:, :10]) | ||
print(data[('author', 'affiliated_with', 'institution')].edge_index[:, :10]) | ||
print(data[('paper', 'cites', 'paper')].edge_index[:, :10]) | ||
print('-----------------') | ||
|
||
train_idx = dataset.get_idx_split('train') | ||
val_idx = dataset.get_idx_split('valid') | ||
test_idx = dataset.get_idx_split('test-whole') | ||
print(len(train_idx) + len(val_idx) + len(test_idx)) | ||
|
||
print(dataset.paper_label[train_idx][:10]) | ||
print(dataset.paper_label[val_idx][:10]) | ||
print(dataset.paper_label[test_idx][:10]) | ||
print(dataset.paper_year[train_idx][:10]) | ||
print(dataset.paper_year[val_idx][:10]) | ||
print(dataset.paper_year[test_idx][:10]) | ||
print(data['paper'].y[train_idx][:10]) | ||
print(data['paper'].y[val_idx][:10]) | ||
print(data['paper'].y[test_idx][:10]) | ||
print(data['paper'].year[train_idx][:10]) | ||
print(data['paper'].year[val_idx][:10]) | ||
print(data['paper'].year[test_idx][:10]) |
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Do you mean:
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
data["paper"].x = torch.from_numpy(np.load(path, mmap_mode='r'))
is from@property def paper_label(self)...
, which is called whenself.in_memory
isFalse
. So I comment here, to remind myself to enablein_memory
part.