|
| 1 | +import ogb |
| 2 | +from ogb.lsc import MAG240MDataset |
| 3 | +import tqdm |
| 4 | +import numpy as np |
| 5 | +import torch |
| 6 | +import dgl |
| 7 | +import dgl.function as fn |
| 8 | +import argparse |
| 9 | +import os |
| 10 | + |
# Command-line interface.
# The author, institution and graph output paths are mandatory: the script
# writes to each of them unconditionally, so a missing value is rejected here
# instead of surfacing only after the expensive preprocessing has run.
parser = argparse.ArgumentParser()
parser.add_argument('--rootdir', type=str, default='.', help='Directory to download the OGB dataset.')
parser.add_argument('--author-output-path', type=str, required=True,
                    help='Path to store the author features.')
parser.add_argument('--inst-output-path', type=str, required=True,
                    help='Path to store the institution features.')
parser.add_argument('--graph-output-path', type=str, required=True, help='Path to store the graph.')
parser.add_argument('--graph-format', type=str, default='csc', choices=['coo', 'csr', 'csc'],
                    help='Graph format (coo, csr or csc).')
parser.add_argument('--graph-as-homogeneous', action='store_true', help='Store the graph as DGL homogeneous graph.')
parser.add_argument('--full-output-path', type=str,
                    help='Path to store features of all nodes. Effective only when graph is homogeneous.')
args = parser.parse_args()

# The homogeneous branch memory-maps --full-output-path, so it must be given
# whenever --graph-as-homogeneous is requested.
if args.graph_as_homogeneous and args.full_output_path is None:
    parser.error('--full-output-path is required when --graph-as-homogeneous is given.')
| 22 | + |
print('Building graph')
dataset = MAG240MDataset(root=args.rootdir)
ei_writes = dataset.edge_index('author', 'writes', 'paper')
ei_cites = dataset.edge_index('paper', 'paper')
ei_affiliated = dataset.edge_index('author', 'institution')

# Row offsets of each node type in the combined node ordering:
# authors come first, then institutions, then papers.
author_offset = 0
inst_offset = author_offset + dataset.num_authors
paper_offset = inst_offset + dataset.num_institutions

# Source/destination rows of the two directed relations (views, no copies).
writes_src, writes_dst = ei_writes[0], ei_writes[1]
affil_src, affil_dst = ei_affiliated[0], ei_affiliated[1]

# Each author relation is added together with its reverse edge type.
# Citations use a single 'cite' edge type, so both directions are concatenated
# into the same edge list.
g = dgl.heterograph({
    ('author', 'write', 'paper'): (writes_src, writes_dst),
    ('paper', 'write-by', 'author'): (writes_dst, writes_src),
    ('author', 'affiliate-with', 'institution'): (affil_src, affil_dst),
    ('institution', 'affiliate', 'author'): (affil_dst, affil_src),
    ('paper', 'cite', 'paper'): (
        np.concatenate([ei_cites[0], ei_cites[1]]),
        np.concatenate([ei_cites[1], ei_cites[0]]),
    ),
})
| 41 | + |
paper_feat = dataset.paper_feat
# Disk-backed float16 outputs, one row per author / institution.
author_feat = np.memmap(
    args.author_output_path, mode='w+', dtype='float16',
    shape=(dataset.num_authors, dataset.num_paper_features))
inst_feat = np.memmap(
    args.inst_output_path, mode='w+', dtype='float16',
    shape=(dataset.num_institutions, dataset.num_paper_features))

# Derive author and institution features from the paper features, handling
# only BLOCK_COLS feature columns per iteration.
BLOCK_COLS = 16
with tqdm.trange(0, dataset.num_paper_features, BLOCK_COLS) as tq:
    for start in tq:
        cols = slice(start, start + BLOCK_COLS)

        tq.set_postfix_str('Reading paper features...')
        g.nodes['paper'].data['x'] = torch.FloatTensor(paper_feat[:, cols].astype('float32'))

        # Authors: mean of the features arriving over 'write-by' edges.
        tq.set_postfix_str('Computing author features...')
        g.update_all(fn.copy_u('x', 'm'), fn.mean('m', 'x'), etype='write-by')

        # Institutions: mean of the author features just computed, arriving
        # over 'affiliate-with' edges.
        tq.set_postfix_str('Computing institution features...')
        g.update_all(fn.copy_u('x', 'm'), fn.mean('m', 'x'), etype='affiliate-with')

        tq.set_postfix_str('Writing author features...')
        author_feat[:, cols] = g.nodes['author'].data['x'].numpy().astype('float16')
        tq.set_postfix_str('Writing institution features...')
        inst_feat[:, cols] = g.nodes['institution'].data['x'].numpy().astype('float16')

        # Drop this block's node data before the next block is loaded.
        for ntype in ('paper', 'author', 'institution'):
            del g.nodes[ntype].data['x']

# Push the completed feature matrices to disk.
author_feat.flush()
inst_feat.flush()
| 67 | + |
# Convert to homogeneous if needed. (The RGAT baseline needs homogeneous graph)
if args.graph_as_homogeneous:
    # Process graph
    g = dgl.to_homogeneous(g)

    # The feature copy below assumes the homogeneous node order is: all
    # authors, then all institutions, then all papers, each in original ID
    # order. Verify that layout before relying on it.
    num_nodes_per_type = (dataset.num_authors, dataset.num_institutions, dataset.num_papers)
    # dtype is pinned to int64 so the comparison does not depend on how the
    # installed PyTorch infers the dtype of an integer fill value.
    assert torch.equal(
        g.ndata[dgl.NTYPE],
        torch.cat([torch.full((n,), type_id, dtype=torch.int64)
                   for type_id, n in enumerate(num_nodes_per_type)])), \
        'Unexpected node type layout in the homogeneous graph.'
    assert torch.equal(
        g.ndata[dgl.NID],
        torch.cat([torch.arange(n) for n in num_nodes_per_type])), \
        'Unexpected per-type node ID order in the homogeneous graph.'

    # Keep only a compact (uint8) edge type field; drop the other bookkeeping.
    g.edata['etype'] = g.edata[dgl.ETYPE].byte()
    del g.edata[dgl.ETYPE]
    del g.ndata[dgl.NTYPE]
    del g.ndata[dgl.NID]

    # Process feature: stack author, institution and paper features into a
    # single disk-backed matrix whose rows follow the node order checked above.
    full_feat = np.memmap(
        args.full_output_path, mode='w+', dtype='float16',
        shape=(sum(num_nodes_per_type), dataset.num_paper_features))
    BLOCK_ROWS = 100000
    # Copy in row chunks so only BLOCK_ROWS rows are handled per assignment.
    for offset, num_rows, feat in (
            (author_offset, dataset.num_authors, author_feat),
            (inst_offset, dataset.num_institutions, inst_feat),
            (paper_offset, dataset.num_papers, paper_feat)):
        for start in tqdm.trange(0, num_rows, BLOCK_ROWS):
            end = min(num_rows, start + BLOCK_ROWS)
            full_feat[offset + start:offset + end] = feat[start:end]
    # Flush explicitly, as is done for the author and institution matrices.
    full_feat.flush()
| 103 | + |
# Restrict the graph to the requested sparse format, then write it to disk.
# (The RGAT baseline needs CSC graph)
g = g.formats(args.graph_format)
dgl.save_graphs(args.graph_output_path, g)
0 commit comments