|
| 1 | +import os |
| 2 | +import os.path as osp |
| 3 | +import re |
| 4 | +import pickle |
| 5 | +import random |
| 6 | +import tqdm |
| 7 | + |
| 8 | +import torch |
| 9 | +from torch_geometric.data import (InMemoryDataset, Data, download_url, |
| 10 | + extract_zip) |
| 11 | + |
| 12 | +import numpy as np |
| 13 | + |
| 14 | +from rdkit import Chem, RDConfig |
| 15 | +from rdkit.Chem import AllChem, ChemicalFeatures |
| 16 | + |
| 17 | +from .mol_features import allowable_features |
| 18 | + |
| 19 | + |
class FewshotMolDataset(InMemoryDataset):
    """In-memory molecular graph dataset organised for few-shot task splits.

    Each raw CSV row holds one SMILES string plus one label column per task.
    ``process`` converts every parsable molecule into a ``Data`` graph and
    additionally stores, next to the collated graphs:

    * ``index_list`` -- for every task, ``[negative_rows, positive_rows]``
      (row indices whose label is below / above 0.5; NaN labels are in
      neither list).
    * ``y_matrix`` -- the ``(num_molecules, num_tasks)`` label matrix, with
      NaN marking a missing label.

    Args:
        root: Directory that contains one sub-directory per dataset name.
        name: Key into ``names`` selecting the dataset.
        transform: Forwarded unchanged to ``InMemoryDataset``.
        pre_transform: Applied to every retained ``Data`` object in
            ``process`` before it is stored.
        pre_filter: Predicate evaluated in ``process``; samples for which it
            returns a falsy value are dropped.
    """

    # Format: name: [display_name, url_name, csv_name, smiles_idx, y_idx, train_tasks, test_tasks]
    names = {
        'pcba': ['PCBA', 'pcba', 'pcba', -1, slice(0, 128), 118, 10],
        'muv': ['MUV', 'muv', 'muv', -1, slice(0, 17), 12, 5],
        'tox21': ['Tox21', 'tox21', 'tox21', -1, slice(0, 12), 9, 3],
        'sider': ['SIDER', 'sider', 'sider', 0, slice(1, 28), 21, 6],

        # toxcast subtask
        'toxcast-APR': ['ToxCast-APR', 'toxcast-APR', 'toxcast-APR', 0, slice(1, 44), 33, 10],
        'toxcast-ATG': ['ToxCast-ATG', 'toxcast-ATG', 'toxcast-ATG', 0, slice(1, 147), 106, 40],
        'toxcast-BSK': ['ToxCast-BSK', 'toxcast-BSK', 'toxcast-BSK', 0, slice(1, 116), 84, 31],
        'toxcast-CEETOX': ['ToxCast-CEETOX', 'toxcast-CEETOX', 'toxcast-CEETOX', 0, slice(1, 15), 10, 4],
        'toxcast-CLD': ['ToxCast-CLD', 'toxcast-CLD', 'toxcast-CLD', 0, slice(1, 20), 14, 5],
        'toxcast-NVS': ['ToxCast-NVS', 'toxcast-NVS', 'toxcast-NVS', 0, slice(1, 140), 100, 39],
        'toxcast-OT': ['ToxCast-OT', 'toxcast-OT', 'toxcast-OT', 0, slice(1, 16), 11, 4],
        'toxcast-TOX21': ['ToxCast-TOX21', 'toxcast-TOX21', 'toxcast-TOX21', 0, slice(1, 101), 80, 20],
        'toxcast-Tanguay': ['ToxCast-Tanguay', 'toxcast-Tanguay', 'toxcast-Tanguay', 0, slice(1, 19), 13, 5],
    }

    def __init__(self, root, name, transform=None, pre_transform=None,
                 pre_filter=None):
        # NOTE(review): this guard only fires if the module-level rdkit
        # import is made optional (i.e. `Chem` can be bound to None).
        if Chem is None:
            raise ImportError('`FewshotMolDataset` requires `rdkit`.')

        # `name` must be set before the parent constructor runs because
        # `raw_dir` / `processed_dir` / `raw_file_names` all read it.
        self.name = name
        assert self.name in self.names, (
            f'unknown dataset name {name!r}; '
            f'expected one of {sorted(self.names)}')
        super().__init__(root, transform, pre_transform, pre_filter)

        self.n_task_train, self.n_task_test = self.names[self.name][5:7]
        self.total_tasks = self.n_task_train + self.n_task_test
        if name != 'pcba':
            # The first `n_task_train` columns train, the remainder test.
            self.train_task_range = list(range(self.n_task_train))
            self.test_task_range = list(range(self.n_task_train, self.total_tasks))
        else:
            # PCBA tests on the first five and last five tasks and trains on
            # everything in between.
            self.train_task_range = list(range(5, self.total_tasks - 5))
            self.test_task_range = (list(range(5))
                                    + list(range(self.total_tasks - 5, self.total_tasks)))

        self.data, self.slices = torch.load(self.processed_paths[0])
        with open(self.processed_paths[1], 'rb') as f:
            self.index_list = pickle.load(f)
        # Despite the `.npz` file name this is written by `np.save` in
        # `process`, so it holds a single array that is read eagerly here.
        with open(self.processed_paths[2], 'rb') as f:
            self.y_matrix = np.load(f)

    @property
    def raw_dir(self):
        """Directory holding the raw CSV: ``<root>/<name>``."""
        return osp.join(self.root, self.name)

    @property
    def processed_dir(self):
        """Directory holding the processed files: ``<root>/<name>/processed``."""
        return osp.join(self.root, self.name, 'processed')

    @property
    def raw_file_names(self):
        """Name of the raw CSV file for the selected dataset."""
        return f'{self.names[self.name][2]}.csv'

    @property
    def processed_file_names(self):
        """Collated graphs, per-task class index lists, and the label matrix."""
        return 'data.pt', 'index_list.pt', 'label_matrix.npz'

    @staticmethod
    def _featurize(mol):
        """Return ``(x, edge_index, edge_attr)`` long tensors for an RDKit mol.

        ``x`` has two columns (atomic-number index, chirality index) and
        ``edge_attr`` has two columns (bond-type index, bond-direction
        index); the indices are positions in the ``allowable_features``
        lists. Every bond is emitted in both directions and the edges are
        sorted by (source, target).
        """
        atom_rows = [
            [allowable_features['possible_atomic_num_list'].index(atom.GetAtomicNum()),
             allowable_features['possible_chirality_list'].index(atom.GetChiralTag())]
            for atom in mol.GetAtoms()
        ]
        x = torch.tensor(atom_rows, dtype=torch.long).view(-1, 2)

        edge_indices, edge_attrs = [], []
        for bond in mol.GetBonds():
            i = bond.GetBeginAtomIdx()
            j = bond.GetEndAtomIdx()
            e = [allowable_features['possible_bonds'].index(bond.GetBondType()),
                 allowable_features['possible_bond_dirs'].index(bond.GetBondDir())]
            edge_indices += [[i, j], [j, i]]
            edge_attrs += [e, e]

        edge_index = torch.tensor(edge_indices)
        # The explicit views keep the (2, 0) / (0, 2) shapes for bond-less
        # molecules, where the tensors above are empty.
        edge_index = edge_index.t().to(torch.long).view(2, -1)
        edge_attr = torch.tensor(edge_attrs, dtype=torch.long).view(-1, 2)

        # Sort indices.
        if edge_index.numel() > 0:
            perm = (edge_index[0] * x.size(0) + edge_index[1]).argsort()
            edge_index, edge_attr = edge_index[:, perm], edge_attr[perm]

        return x, edge_index, edge_attr

    def process(self):
        """Parse the raw CSV and write the three processed files.

        Rows whose SMILES cannot be parsed by RDKit, or that ``pre_filter``
        rejects, are skipped entirely, so the position of a sample in the
        collated dataset, its ``id`` attribute and its row in the label
        matrix always coincide.
        """
        smiles_idx, y_idx = self.names[self.name][3:5]

        with open(self.raw_paths[0], 'r') as f:
            # Drop the header row. Empty lines -- including the empty string
            # produced by a trailing newline -- are filtered out, so the last
            # data row is kept whether or not the file ends with '\n'.
            dataset = [x for x in f.read().split('\n')[1:] if len(x) > 0]

        quoted = re.compile(r'\".*\"')

        data_list = []
        y_list = []
        for line in tqdm.tqdm(dataset):
            line = quoted.sub('', line)  # Replace ".*" strings.
            line = line.split(',')
            smiles = line[smiles_idx]
            mol = Chem.MolFromSmiles(smiles)
            if mol is None:
                continue
            Chem.Kekulize(mol)

            ys = line[y_idx]
            ys = ys if isinstance(ys, list) else [ys]
            # An empty cell means "no label for this task".
            ys = [float(y) if len(y) > 0 else float('NaN') for y in ys]
            y = torch.tensor(ys, dtype=torch.float).view(1, -1)

            x, edge_index, edge_attr = self._featurize(mol)

            # `id` is the position this sample will take if it is retained.
            data = Data(x=x, edge_index=edge_index, edge_attr=edge_attr, y=y,
                        smiles=smiles, id=len(data_list))

            if self.pre_filter is not None and not self.pre_filter(data):
                continue

            if self.pre_transform is not None:
                data = self.pre_transform(data)

            # Record graph and labels together so filtered-out samples never
            # leave an orphan row in the label matrix.
            data_list.append(data)
            y_list.append(ys)

        y_matrix = np.array(y_list)
        index_list = []  # [[[],[]], [[],[]]], task-label-index
        for task_i in range(y_matrix.shape[1]):
            task_i_label_values = y_matrix[:, task_i]
            # NaN compares False on both sides, so unlabeled rows land in
            # neither class.
            class1_index = np.nonzero(task_i_label_values > 0.5)[0].tolist()
            class0_index = np.nonzero(task_i_label_values < 0.5)[0].tolist()
            index_list.append([class0_index, class1_index])

        torch.save(self.collate(data_list), self.processed_paths[0])
        with open(self.processed_paths[1], 'wb') as f:
            pickle.dump(index_list, f)
        # Writing through an open handle keeps the exact file name (np.save
        # only appends `.npy` when given a path).
        with open(self.processed_paths[2], 'wb') as f:
            np.save(f, y_matrix)

    def __repr__(self):
        return f'{self.names[self.name][0]}({len(self)})'
| 157 | + |
0 commit comments