-
Notifications
You must be signed in to change notification settings - Fork 10
/
Copy pathextract_format_dataset.py
289 lines (240 loc) · 12.2 KB
/
extract_format_dataset.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
import argparse
import numpy as np
import pandas as pd
from rdkit import Chem
from rdkit.Chem import rdFingerprintGenerator
from sqlalchemy import create_engine
from sqlalchemy.sql import text
import tables as tb
from tables.atom import ObjectAtom
import json
def parse_args(argv=None):
    """
    Parse the command-line options of the ChEMBL multitask extraction script.

    Args:
        argv (list[str] | None): Argument strings to parse. When None (the
            default) argparse reads them from ``sys.argv[1:]``, so a plain
            ``parse_args()`` call behaves as a normal command-line entry point.

    Returns:
        argparse.Namespace: Parsed options with the attributes
        ``chembl_version``, ``fp_size``, ``radius``, ``active_mols``,
        ``inactive_mols``, ``output_dir`` and ``protein_family``.
    """
    parser = argparse.ArgumentParser(description="Process ChEMBL data for multitask learning.")
    parser.add_argument(
        "--chembl_version", type=int, required=True, help="Version of ChEMBL database. (Required)"
    )
    parser.add_argument(
        "--fp_size", type=int, default=1024, help="Size of the fingerprints. Default: 1024"
    )
    parser.add_argument(
        "--radius", type=int, default=2, help="Radius for Morgan fingerprints. Default: 2"
    )
    parser.add_argument(
        "--active_mols", type=int, default=100, help="Minimum number of active molecules per target. Default: 100"
    )
    parser.add_argument(
        "--inactive_mols", type=int, default=100, help="Minimum number of inactive molecules per target. Default: 100"
    )
    parser.add_argument(
        "--output_dir", type=str, default=".", help="Directory to save output files. Default: current directory"
    )
    parser.add_argument(
        "--protein_family", choices=['kinase', 'gpcr'], default=None,
        help="Protein family to include in the data extraction. Options: kinase, gpcr. Default: include all families."
    )
    # Passing None makes argparse fall back to sys.argv[1:].
    return parser.parse_args(argv)
def fetch_data(engine, protein_family):
    """
    Load activity records joined with their molecule and target information.

    The query keeps only nM measurements of the listed standard types with no
    validity comment, an '=' or '<' relation, no potential-duplicate flag, an
    assay confidence score of at least 8 and a 'SINGLE PROTEIN' target.
    After loading, one row is kept per (molregno, tid) pair: the one with the
    lowest standard_value.

    Args:
        engine: A SQLAlchemy database engine to connect to the database.
        protein_family (str | None): 'kinase' or 'gpcr' to restrict the query
            to that family; any falsy value leaves the query unrestricted.
            Any other non-empty value raises KeyError.

    Returns:
        pd.DataFrame: A DataFrame containing the following columns:
            - doc_id: The document ID for the activity data.
            - standard_value: The activity measurement value in nanomolar (nM) units.
            - molregno: The parent molecule registry number.
            - canonical_smiles: The canonical SMILES of the parent molecule.
            - tid: The target ID.
            - target_chembl_id: The ChEMBL ID for the target protein.
            - protein_class_desc: A description of the protein class for the target.
    """
    # Substring of protein_class_desc used to select each supported family.
    # NOTE(review): LIKE matches these literals verbatim, so their spacing must
    # equal the delimiter stored in protein_class_desc - confirm against the DB.
    family_patterns = {
        'kinase': 'enzyme kinase protein kinase',
        'gpcr': 'membrane receptor 7tm',
    }

    sql = """
SELECT
activities.doc_id AS doc_id,
activities.standard_value AS standard_value,
molecule_hierarchy.parent_molregno AS molregno,
compound_structures.canonical_smiles AS canonical_smiles,
target_dictionary.tid AS tid,
target_dictionary.chembl_id AS target_chembl_id,
protein_classification.protein_class_desc AS protein_class_desc
FROM activities
JOIN assays ON activities.assay_id = assays.assay_id
JOIN target_dictionary ON assays.tid = target_dictionary.tid
JOIN target_components ON target_dictionary.tid = target_components.tid
JOIN component_class ON target_components.component_id = component_class.component_id
JOIN protein_classification ON component_class.protein_class_id = protein_classification.protein_class_id
JOIN molecule_dictionary ON activities.molregno = molecule_dictionary.molregno
JOIN molecule_hierarchy ON molecule_dictionary.molregno = molecule_hierarchy.molregno
JOIN compound_structures ON molecule_hierarchy.parent_molregno = compound_structures.molregno
WHERE
activities.standard_units = 'nM' AND
activities.standard_type IN ('EC50', 'IC50', 'Ki', 'Kd', 'XC50', 'AC50', 'Potency') AND
activities.data_validity_comment IS NULL AND
activities.standard_relation IN ('=', '<') AND
activities.potential_duplicate = 0 AND
assays.confidence_score >= 8 AND
target_dictionary.target_type = 'SINGLE PROTEIN'
"""
    if protein_family:
        # Append the family restriction to the WHERE clause.
        pattern = family_patterns[protein_family]
        sql += f" AND protein_classification.protein_class_desc LIKE '%{pattern}%'"

    with engine.connect() as conn:
        records = pd.read_sql(text(sql), conn, dtype_backend="pyarrow")

    # Ascending sort puts the lowest standard_value first, so keep="first"
    # retains it for every molecule-target pair.
    ordered = records.sort_values(by=["standard_value", "molregno", "tid"], ascending=True)
    return ordered.drop_duplicates(subset=["molregno", "tid"], keep="first")
def set_active(row):
    """
    Determine if a molecule is active based on activity thresholds for protein families.

    Uses IDG protein family activity thresholds:
        - Kinases: <= 30 nM
        - GPCRs (G-Protein-Coupled Receptors): <= 100 nM
        - Nuclear Receptors: <= 100 nM
        - Ion Channels: <= 10 μM
        - Non-IDG Family Targets: <= 1 μM
    See: https://druggablegenome.net/IDGProteinFamilies

    A missing ``standard_value`` (pd.NA, None or NaN) is labelled inactive. A
    missing ``protein_class_desc`` is treated as matching no family, so only
    the general threshold applies.

    Args:
        row (pd.Series): A row from a DataFrame containing 'standard_value' and 'protein_class_desc'.

    Returns:
        int: 1 if active, 0 if inactive.
    """
    standard_value = row["standard_value"]
    # pd.isna covers pd.NA, None and NaN; an identity check against pd.NA
    # alone lets None through and the comparisons below then raise TypeError.
    if pd.isna(standard_value):
        return 0

    protein_class = row["protein_class_desc"]
    if pd.isna(protein_class):
        # No class description: the substring tests below must not raise.
        protein_class = ""

    # General threshold for activity
    active = 1 if standard_value <= 1000 else 0

    # Ion channels use a looser threshold than the general one.
    if "ion channel" in protein_class and standard_value <= 10000:
        active = 1

    # Families with a stricter threshold override a positive label.
    stricter_limits = (
        ("enzyme kinase protein kinase", 30),
        ("transcription factor nuclear receptor", 100),
        ("membrane receptor 7tm", 100),
    )
    for family, limit in stricter_limits:
        if family in protein_class and standard_value > limit:
            active = 0
    return active
def filter_targets(df, active_mols, inactive_mols):
    """
    Keep only the targets that have enough labelled data.

    Every row is first labelled with `set_active`; the label is stored in a
    new "active" column on the DataFrame that was passed in. A target is
    retained when all three conditions hold:
        1. at least `active_mols` rows labelled active,
        2. at least `inactive_mols` rows labelled inactive,
        3. rows coming from at least 2 distinct documents (doc_id).

    Args:
        df (pd.DataFrame): DataFrame containing molecule and target data.
        active_mols (int): Minimum number of active molecules required for a target.
        inactive_mols (int): Minimum number of inactive molecules required for a target.

    Returns:
        Tuple[pd.DataFrame, set]:
            - The rows of `df` whose target passes every condition.
            - Set of target ChEMBL IDs that pass the filters.
    """

    def targets_reaching(frame, column, minimum):
        # Targets whose count of non-null `column` values reaches `minimum`.
        per_target = frame.groupby("target_chembl_id")[column].count()
        return set(per_target[per_target >= minimum].index)

    # Label all rows (this writes the "active" column into df itself).
    df["active"] = df.apply(set_active, axis=1)

    enough_actives = targets_reaching(df[df["active"] == 1], "molregno", active_mols)
    enough_inactives = targets_reaching(df[df["active"] == 0], "molregno", inactive_mols)

    # One row per (document, target) pair, so the count is the number of documents.
    one_per_doc = df.drop_duplicates(subset=["doc_id", "target_chembl_id"])
    enough_docs = targets_reaching(one_per_doc, "doc_id", 2.0)

    # A target must satisfy all criteria at once.
    t_keep = enough_actives & enough_inactives & enough_docs

    selected = df["target_chembl_id"].isin(t_keep)
    return df[selected], t_keep
def calc_fp(smiles, mfpgen):
    """
    Calculate the molecular fingerprint for a given SMILES string.

    Args:
        smiles (str): The SMILES representation of the molecule.
        mfpgen: An RDKit fingerprint generator object.

    Returns:
        np.ndarray: A float32 NumPy array representing the molecular fingerprint.

    Raises:
        ValueError: If RDKit cannot parse `smiles` into a molecule.
    """
    mol = Chem.MolFromSmiles(smiles)
    if mol is None:
        # MolFromSmiles signals a parse failure by returning None; fail here
        # with the offending input instead of deep inside GetFingerprint.
        raise ValueError(f"RDKit could not parse SMILES: {smiles!r}")
    fp = mfpgen.GetFingerprint(mol)
    # Start from an empty array; ConvertToNumpyArray fills it in place.
    arr = np.zeros((0,), dtype=np.float32)
    Chem.DataStructs.ConvertToNumpyArray(fp, arr)
    return arr
def save_to_h5(mt_df, descs, output_file):
    """
    Save molecular data, fingerprints, and labels into an HDF5 file.

    This function stores the following nodes under the file root:
        - "molregnos": Molecule IDs (molregno)
        - "fps": Molecular fingerprints
        - "target_chembl_ids": Target ChEMBL IDs (the label column names)
        - "labels": Label matrix
        - "weights": Task weights for multi-task training

    The input DataFrame is not modified.

    Args:
        mt_df (pd.DataFrame): A DataFrame with a "molregno" column, a
            "canonical_smiles" column and one label column per target ChEMBL ID.
        descs (np.ndarray): Array of molecular fingerprints. Its dtype is also
            used for the stored label and weight arrays.
        output_file (str): Path to the output HDF5 file.

    Returns:
        None

    Raises:
        KeyError: If "molregno" or "canonical_smiles" is missing from `mt_df`.
        ZeroDivisionError: If a target column has no value >= 0.
    """
    # Label-only frame: every remaining column is one target. Built with drop()
    # rather than `del` so the caller's DataFrame keeps its columns.
    labels_df = mt_df.drop(columns=["molregno", "canonical_smiles"])

    with tb.open_file(output_file, mode="w") as t_file:
        # Compression filters for efficient storage
        filters = tb.Filters(complib="blosc", complevel=5)

        # Save molecule IDs (molregno) as a variable-length array
        cids = t_file.create_vlarray(t_file.root, "molregnos", atom=ObjectAtom())
        for cid in mt_df["molregno"].values:
            cids.append(cid)

        # Save molecular fingerprints (fps) as a compressed array
        fatom = tb.Atom.from_dtype(descs.dtype)
        fps = t_file.create_carray(t_file.root, "fps", fatom, descs.shape, filters=filters)
        fps[:] = descs

        # Save target ChEMBL IDs as a variable-length array
        tcids = t_file.create_vlarray(t_file.root, "target_chembl_ids", atom=ObjectAtom())
        for tcid in labels_df.columns.values:
            tcids.append(tcid)

        # Save label matrix (task labels) as a compressed array
        label_matrix = labels_df.values
        labs = t_file.create_carray(t_file.root, "labels", fatom, label_matrix.shape, filters=filters)
        labs[:] = label_matrix

        # Each task's loss will be weighted inversely proportional to the number of data points for that task
        # (rows whose label is >= 0).
        # Reference: https://ml.jku.at/publications/2014/NIPS2014f.pdf
        weights = np.array(
            [1 / labels_df[labels_df[col] >= 0.0].shape[0] for col in labels_df.columns.values]
        )
        ws = t_file.create_carray(t_file.root, "weights", fatom, weights.shape)
        ws[:] = weights
if __name__ == "__main__":
    # Script entry point: query ChEMBL, filter targets, build the multitask
    # label table and fingerprints, then write CSV, HDF5 and JSON outputs.
    from pathlib import Path

    args = parse_args()
    pf_name = args.protein_family if args.protein_family else "all"

    # Create the output directory before the expensive query so a missing
    # directory cannot abort the run at the first write.
    out_dir = Path(args.output_dir)
    out_dir.mkdir(parents=True, exist_ok=True)
    csv_prefix = f"chembl_{args.chembl_version}_{pf_name}"
    # Passed as str in case the installed PyTables does not accept Path objects.
    h5_path = str(out_dir / f"mt_data_{args.chembl_version}_{pf_name}.h5")
    targets_path = out_dir / f"targets_{args.chembl_version}_{pf_name}.json"

    engine = create_engine(f"sqlite:///chembl_{args.chembl_version}.db")
    df = fetch_data(engine, args.protein_family)
    df.to_csv(out_dir / f"{csv_prefix}_activity_data.csv", index=False)

    activities, t_keep = filter_targets(df, args.active_mols, args.inactive_mols)
    activities.to_csv(out_dir / f"{csv_prefix}_activity_data_filtered.csv", index=False)

    # One row per molecule, one column per target; -1 marks "no measurement".
    mt_df = activities.pivot(index="molregno", columns="target_chembl_id", values="active")
    mt_df = mt_df.fillna(-1).reset_index()

    structs = activities[["molregno", "canonical_smiles"]].drop_duplicates(subset="molregno")
    # keep only compounds that RDKit can parse
    structs = structs[structs["canonical_smiles"].apply(lambda smi: Chem.MolFromSmiles(smi) is not None)]
    mt_df = pd.merge(structs, mt_df, how="inner", on="molregno")
    mt_df.to_csv(out_dir / f"{csv_prefix}_multi_task_data.csv", index=False)

    mfpgen = rdFingerprintGenerator.GetMorganGenerator(radius=args.radius,fpSize=args.fp_size)
    descs = np.asarray([calc_fp(smi, mfpgen) for smi in mt_df["canonical_smiles"]], dtype=np.float32)
    save_to_h5(mt_df, descs, h5_path)

    # Check that h5 file opens and save targets to a json file
    with tb.open_file(h5_path, mode="r") as t_file:
        with open(targets_path, "w") as f:
            json.dump(t_file.root.target_chembl_ids[:], f)