Skip to content

Commit 3173f29

Browse files
author
ZhuangXiang
committed
add files
1 parent e5741c7 commit 3173f29

20 files changed

+1872
-1
lines changed

Diff for: .gitignore

+5
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,5 @@
1+
2+
**/processed/
3+
data/
4+
dump/
5+
__pycache__/

Diff for: README.md

+85-1
Original file line numberDiff line numberDiff line change
@@ -1 +1,85 @@
1-
# GS-Meta
1+
# GS-Meta
2+
3+
This repository is the official implementation of **GS-Meta** proposed in: **Graph Sampling-based Meta-Learning for Molecular Property Prediction, IJCAI 2023**.
4+
5+
6+
<div align=center><img src="./fig/framework.png" style="zoom:50%;" />
7+
</div>
8+
9+
10+
## Environment
11+
To run the code successfully, the following dependencies need to be installed:
12+
```
13+
python 3.7
14+
torch 1.7.1
15+
rdkit 2022.9.3
16+
learn2learn 0.1.6
17+
torch_geometric 1.6.3
18+
torch_scatter 2.0.7
19+
```
20+
21+
## Step-by-step guidelines
22+
23+
### Datasets
24+
For data used in the experiments, please download `data.zip` from the [release](https://github.com/HICAI-ZJU/GS-Meta/releases) page in this repo, then extract the downloaded file and save the contents in the `data` directory.
25+
26+
### Project Overview
27+
28+
This project mainly contains the following parts.
29+
30+
```
31+
├── data # dataset files
32+
│   ├── sider
33+
│   │ ├── sider.csv
34+
│   ├── tox21
35+
│   │ ├── tox21.csv
36+
│ └── ...
37+
├── dataset
38+
│   ├── __init__.py
39+
│ └── ...
40+
├── models
41+
│   ├── __init__.py
42+
│ └── ...
43+
├── pretrained # pretrained GNN
44+
│   ├── supervised_contextpred.pth
45+
├── args_parser.py
46+
├── explight.py
47+
├── meta_learner.py
48+
└── run.py
49+
```
50+
51+
### Running Script
52+
53+
```sh
54+
python run.py --dataset sider --n_support 10 --gpu 0
55+
```
56+
Running parameters and descriptions are as follows:
57+
58+
| Parameter | Description | Default Value | Choices |
59+
| --- | --- | --- | --- |
60+
| dataset | name of dataset | sider | `tox21, sider, muv, pcba, toxcast-APR, toxcast-ATG, toxcast-BSK, toxcast-CEETOX, toxcast-CLD, toxcast-NVS, toxcast-OT, toxcast-Tanguay, toxcast-TOX21` |
61+
| n_support | number of support molecules | 10 | `1, 10`
62+
| gpu | which GPU to use | 0 | \ |
63+
| exp_name | experiment name | None | \ |
64+
| exp_id | experiment ID | None | \ |
65+
| eval_step | evaluation interval | 100 | \ |
66+
67+
For Pre-GS-Meta, which is initialized with a pretrained GNN, the running script is:
68+
69+
70+
```sh
71+
python run.py --dataset sider --n_support 10 --gpu 0 --mol_pretrain_load_path pretrained/supervised_contextpred.pth
72+
```
73+
74+
75+
## References
76+
If you use or extend our work, please cite the paper as follows:
77+
78+
```bibtex
79+
@InProceedings{zhuang2023graph,
80+
title={Graph Sampling-based Meta-Learning for Molecular Property Prediction},
81+
author={Xiang Zhuang and Qiang Zhang and Bin Wu and Keyan Ding and Yin Fang and Huajun Chen},
82+
booktitle={IJCAI},
83+
year={2023}
84+
}
85+
```

Diff for: args_parser.py

+74
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,74 @@
1+
import argparse
2+
3+
4+
def args_parser():
    """Build and parse the command-line arguments for GS-Meta.

    Returns
    -------
    argparse.Namespace
        Parsed arguments. If ``--rel_top_k`` is not given it is derived
        from ``--n_support`` after parsing (see below).
    """
    parser = argparse.ArgumentParser()

    # experiment bookkeeping
    parser.add_argument("--exp_name", default="run", type=str,
                        help="Experiment name")
    parser.add_argument("--dump_path", default="dump/", type=str,
                        help="Experiment dump path")
    parser.add_argument("--exp_id", default="", type=str,
                        help="Experiment ID")
    parser.add_argument("--gpu", default='0', type=str)
    parser.add_argument("--random_seed", default=0, type=int)

    # dataset
    parser.add_argument("--data_root", default='data', type=str)
    parser.add_argument("--dataset", default='sider', type=str)

    # molecule encoder
    parser.add_argument("--mol_num_layer", default=5, type=int)
    parser.add_argument("--emb_dim", default=300, type=int)
    parser.add_argument("--JK", default='last', type=str)
    parser.add_argument("--mol_dropout", default=0.1, type=float)
    parser.add_argument("--mol_graph_pooling", default='mean', type=str)
    parser.add_argument("--mol_gnn_type", default='gin', type=str)
    parser.add_argument("--mol_batch_norm", default=1, type=int)
    parser.add_argument("--mol_pretrain_load_path", default=None)

    # relation net
    parser.add_argument("--rel_layer", default=2, type=int)
    parser.add_argument("--rel_edge_n_layer", default=2, type=int)
    parser.add_argument("--rel_top_k", default=None, type=int)
    parser.add_argument("--rel_edge_hidden_dim", default=100, type=int)
    parser.add_argument("--rel_dropout", default=0.1, type=float)
    parser.add_argument("--rel_pre_dropout", default=0.1, type=float)
    parser.add_argument("--rel_nan_w", default=1., type=float)
    parser.add_argument("--rel_nan_type", default='nan', type=str, choices=['nan', '0', '1'])
    parser.add_argument("--rel_batch_norm", default=1, type=int)
    parser.add_argument("--rel_edge_type", default=1, type=int)

    # maml
    parser.add_argument("--inner_lr", default=0.5, type=float)
    parser.add_argument("--meta_lr", default=1e-3, type=float)
    parser.add_argument("--weight_decay", default=5e-5, type=float)
    parser.add_argument("--second_order", default=1, type=int)
    parser.add_argument("--inner_update_step", default=1, type=int)
    parser.add_argument("--inner_tasks", default=10, type=int)

    # few-shot
    parser.add_argument("--episode", default=2000, type=int)
    parser.add_argument("--n_support", default=10, type=int)
    parser.add_argument("--n_query", default=16, type=int)
    parser.add_argument("--n_test_tasks", default=200, type=int)
    parser.add_argument("--eval_step", default=100, type=int)
    parser.add_argument("--test_batch_size", default=128, type=int)
    parser.add_argument("--train_auxi_task_num", default=None, type=int)
    parser.add_argument("--test_auxi_task_num", default=None, type=int)

    # contrastive
    parser.add_argument("--nce_t", default=0.08, type=float)
    parser.add_argument("--contr_w", default=0.05, type=float)

    # selector
    # pool_num is a count; it was declared type=float, which made the value an
    # int when defaulted but a float when passed on the CLI. Parse as int.
    parser.add_argument("--pool_num", default=10, type=int)
    parser.add_argument("--task_lr", default=5e-4, type=float)
    parser.add_argument("--task_hid_dim", default=10, type=int)
    parser.add_argument("--task_t", default=1, type=float)
    args = parser.parse_args()

    if args.rel_top_k is None:
        # Default top-k to one fewer than the support size (min 1), so the
        # relation graph never has to rank more neighbors than exist.
        args.rel_top_k = args.n_support - 1 if args.n_support > 1 else 1
    return args

Diff for: dataset/__init__.py

+2
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,2 @@
1+
from .dataset import FewshotMolDataset
2+
from .sampler import dataset_sampler

Diff for: dataset/dataset.py

+157
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,157 @@
1+
import os
2+
import os.path as osp
3+
import re
4+
import pickle
5+
import random
6+
import tqdm
7+
8+
import torch
9+
from torch_geometric.data import (InMemoryDataset, Data, download_url,
10+
extract_zip)
11+
12+
import numpy as np
13+
14+
from rdkit import Chem, RDConfig
15+
from rdkit.Chem import AllChem, ChemicalFeatures
16+
17+
from .mol_features import allowable_features
18+
19+
20+
class FewshotMolDataset(InMemoryDataset):
    """In-memory PyG dataset of molecules for few-shot property prediction.

    Each entry in ``names`` configures one benchmark CSV: which column holds
    the SMILES string, which column slice holds the task labels, and how the
    tasks split into meta-train / meta-test.
    """

    # Format: name: [display_name, url_name, csv_name, smiles_idx, y_idx, train_tasks, test_tasks]
    names = {
        'pcba': ['PCBA', 'pcba', 'pcba', -1, slice(0, 128), 118, 10],
        'muv': ['MUV', 'muv', 'muv', -1, slice(0, 17), 12, 5],
        'tox21': ['Tox21', 'tox21', 'tox21', -1, slice(0, 12), 9, 3],
        'sider': ['SIDER', 'sider', 'sider', 0, slice(1, 28), 21, 6],

        # toxcast subtask
        'toxcast-APR': ['ToxCast-APR', 'toxcast-APR', 'toxcast-APR', 0, slice(1, 44), 33, 10],
        'toxcast-ATG': ['ToxCast-ATG', 'toxcast-ATG', 'toxcast-ATG', 0, slice(1, 147), 106, 40],
        'toxcast-BSK': ['ToxCast-BSK', 'toxcast-BSK', 'toxcast-BSK', 0, slice(1, 116), 84, 31],
        'toxcast-CEETOX': ['ToxCast-CEETOX', 'toxcast-CEETOX', 'toxcast-CEETOX', 0, slice(1, 15), 10, 4],
        'toxcast-CLD': ['ToxCast-CLD', 'toxcast-CLD', 'toxcast-CLD', 0, slice(1, 20), 14, 5],
        'toxcast-NVS': ['ToxCast-NVS', 'toxcast-NVS', 'toxcast-NVS', 0, slice(1, 140), 100, 39],
        'toxcast-OT': ['ToxCast-OT', 'toxcast-OT', 'toxcast-OT', 0, slice(1, 16), 11, 4],
        'toxcast-TOX21': ['ToxCast-TOX21', 'toxcast-TOX21', 'toxcast-TOX21', 0, slice(1, 101), 80, 20],
        'toxcast-Tanguay': ['ToxCast-Tanguay', 'toxcast-Tanguay', 'toxcast-Tanguay', 0, slice(1, 19), 13, 5],
    }

    def __init__(self, root, name, transform=None, pre_transform=None,
                 pre_filter=None):
        """Load (processing on first use) the dataset named ``name`` under ``root``.

        Raises
        ------
        ImportError
            If rdkit is not installed (``Chem`` resolves to None).
        AssertionError
            If ``name`` is not one of the configured datasets.
        """
        if Chem is None:
            raise ImportError('`MoleculeNet` requires `rdkit`.')

        self.name = name
        assert self.name in self.names.keys()
        # InMemoryDataset.__init__ triggers process() if the processed files
        # are missing, so it must run before the torch.load below.
        super(FewshotMolDataset, self).__init__(root, transform, pre_transform, pre_filter)
        self.n_task_train, self.n_task_test = self.names[self.name][5], self.names[self.name][6]
        self.total_tasks = self.n_task_train + self.n_task_test
        if name != 'pcba':
            # Default split: first n_task_train columns train, the rest test.
            self.train_task_range = list(range(self.n_task_train))
            self.test_task_range = list(range(self.n_task_train, self.n_task_train + self.n_task_test))
        else:
            # pcba special case: the first 5 and last 5 task columns are
            # held out for meta-test, the middle columns are meta-train.
            self.train_task_range = list(range(5, self.total_tasks - 5))
            self.test_task_range = list(range(5)) + list(range(self.total_tasks - 5, self.total_tasks))

        self.data, self.slices = torch.load(self.processed_paths[0])
        # index_list[task][cls] -> list of molecule row indices with that label.
        self.index_list = pickle.load(open(self.processed_paths[1], 'rb'))
        # y_matrix: (n_molecules, n_tasks) float array, NaN for missing labels.
        self.y_matrix = np.load(open(self.processed_paths[2], 'rb'))

    @property
    def raw_dir(self):
        # e.g. <root>/sider
        return osp.join(self.root, self.name)

    @property
    def processed_dir(self):
        # e.g. <root>/sider/processed
        return osp.join(self.root, self.name, 'processed')

    @property
    def raw_file_names(self):
        return f'{self.names[self.name][2]}.csv'

    @property
    def processed_file_names(self):
        # (collated graphs, per-task index lists, label matrix)
        return 'data.pt', 'index_list.pt', 'label_matrix.npz'

    def process(self):
        """Parse the raw CSV into PyG ``Data`` graphs and label index files.

        Writes three artifacts: the collated graph tensor, the pickled
        per-task class index lists, and the label matrix (numpy).
        """
        with open(self.raw_paths[0], 'r') as f:
            # Drop the header row and the trailing empty line.
            dataset = f.read().split('\n')[1:-1]
        dataset = [x for x in dataset if len(x) > 0]  # Filter empty lines.

        data_list = []
        y_list = []
        data_id = 0
        for line in tqdm.tqdm(dataset):
            # NOTE(review): ad-hoc CSV handling — quoted fields are deleted
            # outright rather than parsed; assumes SMILES/label columns are
            # never quoted. Verify against the raw CSVs.
            line = re.sub(r'\".*\"', '', line)  # Replace ".*" strings.
            line = line.split(',')
            smiles = line[self.names[self.name][3]]
            mol = Chem.MolFromSmiles(smiles)
            if mol is None:
                # Unparseable SMILES: skip the molecule entirely.
                continue
            Chem.Kekulize(mol)

            ys = line[self.names[self.name][4]]
            ys = ys if isinstance(ys, list) else [ys]

            # Empty cells become NaN (missing label).
            ys = [float(y) if len(y) > 0 else float('NaN') for y in ys]
            y = torch.tensor(ys, dtype=torch.float).view(1, -1)
            y_list.append(ys)

            # Node features: [atomic-number index, chirality index] per atom.
            xs = []
            for atom in mol.GetAtoms():
                x = []
                x.append(allowable_features['possible_atomic_num_list'].index(atom.GetAtomicNum()))
                x.append(allowable_features['possible_chirality_list'].index(atom.GetChiralTag()))
                xs.append(x)

            x = torch.tensor(xs, dtype=torch.long).view(-1, 2)

            # Edge features: [bond-type index, bond-direction index]; each
            # bond is stored in both directions for an undirected graph.
            edge_indices, edge_attrs = [], []
            for bond in mol.GetBonds():
                i = bond.GetBeginAtomIdx()
                j = bond.GetEndAtomIdx()

                e = []
                e.append(allowable_features['possible_bonds'].index(bond.GetBondType()))
                e.append(allowable_features['possible_bond_dirs'].index(bond.GetBondDir()))

                edge_indices += [[i, j], [j, i]]
                edge_attrs += [e, e]

            edge_index = torch.tensor(edge_indices)
            edge_index = edge_index.t().to(torch.long).view(2, -1)
            edge_attr = torch.tensor(edge_attrs, dtype=torch.long).view(-1, 2)

            # Sort indices.
            if edge_index.numel() > 0:
                perm = (edge_index[0] * x.size(0) + edge_index[1]).argsort()
                edge_index, edge_attr = edge_index[:, perm], edge_attr[perm]

            data = Data(x=x, edge_index=edge_index, edge_attr=edge_attr, y=y, smiles=smiles, id=data_id)
            data_id += 1

            # NOTE(review): y_list was appended and data_id incremented before
            # these filters, so if pre_filter drops a molecule the label
            # matrix and graph list fall out of sync — confirm pre_filter is
            # never used, or move the bookkeeping below these checks.
            if self.pre_filter is not None and not self.pre_filter(data):
                continue

            if self.pre_transform is not None:
                data = self.pre_transform(data)

            data_list.append(data)

        y_matrix = np.array(y_list)
        index_list = []  # [[[],[]], [[],[]]], task-label-index
        for task_i in range(y_matrix.shape[1]):
            task_i_label_values = y_matrix[:, task_i]
            # NaN fails both comparisons, so unlabeled molecules land in
            # neither class list.
            class1_index = np.nonzero(task_i_label_values > 0.5)[0].tolist()
            class0_index = np.nonzero(task_i_label_values < 0.5)[0].tolist()
            index_list.append([class0_index, class1_index])

        torch.save(self.collate(data_list), self.processed_paths[0])
        pickle.dump(index_list, open(self.processed_paths[1], 'wb'))
        np.save(open(self.processed_paths[2], 'wb'), y_matrix)

    def __repr__(self):
        return '{}({})'.format(self.names[self.name][0], len(self))

Diff for: dataset/mol_features.py

+22
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,22 @@
1+
from rdkit import Chem
2+
3+
# Vocabulary of atom/bond attributes used to build integer node and edge
# features: dataset code encodes each attribute as its position in these
# lists (via list.index), so the ORDER of entries is part of the data format
# and must not change once datasets have been processed.
allowable_features = {
    # Atomic numbers 1..118; the feature value is (atomic_num - 1).
    'possible_atomic_num_list': list(range(1, 119)),
    'possible_chirality_list': [
        Chem.rdchem.ChiralType.CHI_UNSPECIFIED,
        Chem.rdchem.ChiralType.CHI_TETRAHEDRAL_CW,
        Chem.rdchem.ChiralType.CHI_TETRAHEDRAL_CCW,
        Chem.rdchem.ChiralType.CHI_OTHER
    ],
    'possible_bonds': [
        Chem.rdchem.BondType.SINGLE,
        Chem.rdchem.BondType.DOUBLE,
        Chem.rdchem.BondType.TRIPLE,
        Chem.rdchem.BondType.AROMATIC
    ],
    'possible_bond_dirs': [  # only for double bond stereo information
        Chem.rdchem.BondDir.NONE,
        Chem.rdchem.BondDir.ENDUPRIGHT,
        Chem.rdchem.BondDir.ENDDOWNRIGHT
    ]
}

Diff for: dataset/sampler.py

+33
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,33 @@
1+
import numpy as np
2+
3+
4+
def sample_from_candi(candi_list, num):
    """Uniformly sample ``num`` distinct entries from ``candi_list``.

    Returns a plain Python list (``.tolist()`` converts numpy scalars back
    to native ints so the values can index the dataset).
    """
    return np.random.choice(candi_list, num, replace=False).tolist()


def dataset_sampler(dataset, n_support, n_query, tgt_id, inductive=False):
    """Sample support and query sets for the target task ``tgt_id``.

    The support set aims for ``n_support`` molecules per class. When one
    class is too small, all but one of its molecules go to support (one is
    kept back so a query molecule of that class still exists) and the
    shortfall is filled from the other class, keeping the total at
    ``2 * n_support``.

    Parameters
    ----------
    dataset : dataset object
        Must expose ``index_list[tgt_id] == [class0_indices, class1_indices]``
        and support indexing by a list of molecule indices.
    n_support : int
        Desired number of support molecules per class.
    n_query : int
        Query-set size (transductive mode only); at least one query molecule
        is drawn from each class.
    tgt_id : int
        Target task index into ``dataset.index_list``.
    inductive : bool, optional
        If True, the query set is every labeled molecule not in the support
        set instead of a fixed-size sample.

    Returns
    -------
    tuple
        ``(dataset[support_indices], dataset[query_indices])``.
    """
    tgt_index_list = dataset.index_list[tgt_id]
    class0_num, class1_num = len(tgt_index_list[0]), len(tgt_index_list[1])
    if class0_num > n_support and class1_num > n_support:
        # Both classes are large enough: balanced support set.
        support_list_i_0 = sample_from_candi(tgt_index_list[0], n_support)
        support_list_i_1 = sample_from_candi(tgt_index_list[1], n_support)
    elif class0_num <= n_support < class1_num:
        # class 0 not enough: take all but one, top up from class 1.
        support_list_i_0 = sample_from_candi(tgt_index_list[0], class0_num - 1)
        support_list_i_1 = sample_from_candi(tgt_index_list[1], 2 * n_support - class0_num + 1)
    else:
        # class 1 not enough: symmetric to the branch above.
        # NOTE(review): assumes class 0 then has at least
        # 2 * n_support - class1_num + 1 molecules — not checked here.
        support_list_i_0 = sample_from_candi(tgt_index_list[0], 2 * n_support - class1_num + 1)
        support_list_i_1 = sample_from_candi(tgt_index_list[1], class1_num - 1)
    support_list = support_list_i_0 + support_list_i_1
    # Set for O(1) membership tests; the original list-based `not in` made
    # the exclusion scans below quadratic in the task size.
    support_set = set(support_list)

    if not inductive:
        query_candi_i_0 = [idx for idx in tgt_index_list[0] if idx not in support_set]
        query_candi_i_1 = [idx for idx in tgt_index_list[1] if idx not in support_set]
        # Guarantee at least one query molecule from each class, then fill
        # the remaining n_query - 2 slots from the combined pool.
        query_list = sample_from_candi(query_candi_i_0, 1) + sample_from_candi(query_candi_i_1, 1)
        taken = set(query_list)
        query_candi = [idx for idx in query_candi_i_0 + query_candi_i_1 if idx not in taken]
        query_list += sample_from_candi(query_candi, n_query - 2)
    else:
        query_list = [idx for idx in tgt_index_list[0] + tgt_index_list[1]
                      if idx not in support_set]
    return dataset[support_list], dataset[query_list]

0 commit comments

Comments
 (0)