Skip to content

Commit de34e15

Browse files
authored
fix metapath2vec on custom datasets (dmlc#1499)
1 parent 16561a2 commit de34e15

File tree

4 files changed

+38
-14
lines changed

4 files changed

+38
-14
lines changed

examples/pytorch/metapath2vec/README.md

+20-8
Original file line numberDiff line numberDiff line change
@@ -10,14 +10,26 @@ Dependencies
1010

1111
How to run the code
1212
-----
13-
Run with the following procedures:
14-
15-
1, Run sampler.py on your graph dataset. Note that: the input text file should be list of mappings so you probably need to preprocess your graph dataset. Files with sample format are available in "net_dbis" file. Of course you could also use your own metapath sampler implementation.
16-
17-
2, Run the following command:
18-
```bash
19-
python metapath2vec.py --download "where/you/want/to/download" --output_file "your_output_file_path"
20-
```
13+
Run with either of the following procedures:
14+
15+
* Running with default AMiner dataset:
16+
1. Directly run the following command:
17+
18+
```bash
19+
python metapath2vec.py --aminer --path "where/you/want/to/download" --output_file "your_model_output_path"
20+
```
21+
* Running with another AMiner-like dataset
22+
1. Prepare the data in the same format as the ones of AMiner and DBIS in Section B of [Author's code repo](https://ericdongyx.github.io/metapath2vec/m2v.html).
23+
2. Run `sampler.py` on your graph dataset with, for instance,
24+
25+
```bash
26+
python sampler.py net_dbis
27+
```
28+
3. Run the following command:
29+
30+
```bash
31+
python metapath2vec.py --path net_dbis/output_path.txt --output_file "your_model_output_path"
32+
```
2133
2234
Tips: Change num_workers based on your GPU instances; Running 3 or 4 epochs is actually enough.
2335

examples/pytorch/metapath2vec/download.py

+8
Original file line numberDiff line numberDiff line change
@@ -44,3 +44,11 @@ def _download_and_extract(self, path, filename):
4444
with zipfile.ZipFile(fn) as zf:
4545
zf.extractall(path)
4646
print('Unzip finished.')
47+
48+
49+
class CustomDataset(object):
50+
"""
51+
Custom dataset generated by sampler.py (e.g. NetDBIS)
52+
"""
53+
def __init__(self, path):
54+
self.fn = path

examples/pytorch/metapath2vec/metapath2vec.py

+8-2
Original file line numberDiff line numberDiff line change
@@ -7,11 +7,16 @@
77

88
from reading_data import DataReader, Metapath2vecDataset
99
from model import SkipGramModel
10+
from download import AminerDataset, CustomDataset
1011

1112

1213
class Metapath2VecTrainer:
1314
def __init__(self, args):
14-
self.data = DataReader(args.download, args.min_count, args.care_type)
15+
if args.aminer:
16+
dataset = AminerDataset(args.path)
17+
else:
18+
dataset = CustomDataset(args.path)
19+
self.data = DataReader(dataset, args.min_count, args.care_type)
1520
dataset = Metapath2vecDataset(self.data, args.window_size)
1621
self.dataloader = DataLoader(dataset, batch_size=args.batch_size,
1722
shuffle=True, num_workers=args.num_workers, collate_fn=dataset.collate)
@@ -60,7 +65,8 @@ def train(self):
6065
if __name__ == '__main__':
6166
parser = argparse.ArgumentParser(description="Metapath2vec")
6267
#parser.add_argument('--input_file', type=str, help="input_file")
63-
parser.add_argument('--download', type=str, help="download_path")
68+
parser.add_argument('--aminer', action='store_true', help='Use AMiner dataset')
69+
parser.add_argument('--path', type=str, help="input_path")
6470
parser.add_argument('--output_file', type=str, help='output_file')
6571
parser.add_argument('--dim', default=128, type=int, help="embedding dimensions")
6672
parser.add_argument('--window_size', default=7, type=int, help="context window size")

examples/pytorch/metapath2vec/reading_data.py

+2-4
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@
77
class DataReader:
88
NEGATIVE_TABLE_SIZE = 1e8
99

10-
def __init__(self, download, min_count, care_type):
10+
def __init__(self, dataset, min_count, care_type):
1111

1212
self.negatives = []
1313
self.discards = []
@@ -18,9 +18,7 @@ def __init__(self, download, min_count, care_type):
1818
self.sentences_count = 0
1919
self.token_count = 0
2020
self.word_frequency = dict()
21-
self.download = download
22-
FB = AminerDataset(self.download)
23-
self.inputFileName = FB.fn
21+
self.inputFileName = dataset.fn
2422
self.read_words(min_count)
2523
self.initTableNegatives()
2624
self.initTableDiscards()

0 commit comments

Comments
 (0)