Skip to content

Commit

Permalink
Added check to keep consistency with alphabet (#143)
Browse files Browse the repository at this point in the history
* Added check to keep consistency with alphabet

* Fixed pre-commit maybe

* Now we play find the correct langchain version
  • Loading branch information
whitead authored Dec 4, 2023
1 parent 6787389 commit 89d7626
Show file tree
Hide file tree
Showing 11 changed files with 88 additions and 66 deletions.
14 changes: 6 additions & 8 deletions .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -7,13 +7,11 @@ repos:
- id: end-of-file-fixer
- id: mixed-line-ending
- repo: https://github.com/psf/black
rev: "22.3.0"
rev: "23.11.0"
hooks:
- id: black
- repo: https://github.com/tomcatling/black-nb
rev: "0.7"
hooks:
- id: black-nb
description: strip output and black source
- id: black
additional_dependencies: ['black[jupyter]']
args: ["--clear-output"]
- repo: https://github.com/kynan/nbstripout
rev: "0.6.0"
hooks:
- id: nbstripout
48 changes: 47 additions & 1 deletion exmol/exmol.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from functools import reduce
from functools import reduce, lru_cache
import inspect
from typing import *
import io
Expand All @@ -9,6 +9,7 @@
from matplotlib.patches import Rectangle, FancyBboxPatch # type: ignore
from matplotlib.offsetbox import AnnotationBbox # type: ignore
import matplotlib as mpl # type: ignore
import re
import selfies as sf # type: ignore
import tqdm # type: ignore
import textwrap # type: ignore
Expand Down Expand Up @@ -367,6 +368,38 @@ def get_basic_alphabet() -> Set[str]:
return a


# The code below checks for accidental addition of symbols outside of the alphabet.
# the only way this can happen is if the character indicating a ring
# length is mutated to appear elsewhere. The ring length symbol
# is always a plain uncharged element.


def _alphabet_to_elements(alphabet: List[str]) -> Set[str]:
"""Converts SELFIES alphabet to element symbols"""
symbols = []
for s in alphabet:
s = s.replace("[", "").replace("]", "")
if s.isalpha():
symbols.append(s)
return set(symbols)


def _check_alphabet_consistency(
    smiles: str, alphabet_symbols: Set[str], check: bool = False
) -> bool:
    """Check that a SMILES string only contains elements from the alphabet.

    :param smiles: SMILES string to inspect.
    :param alphabet_symbols: SELFIES alphabet tokens (or pre-extracted
        element symbols); normalized via :func:`_alphabet_to_elements`.
    :param check: if True, raise instead of returning False on mismatch.
    :return: True when every element in ``smiles`` is in the alphabet.
    :raises ValueError: when ``check`` is True and unknown elements appear.
    """
    alphabet_symbols = _alphabet_to_elements(set(alphabet_symbols))
    # find all elements in smiles (upper alpha optionally followed by lower alpha)
    smiles_symbols = set(re.findall(r"[A-Z][a-z]?", smiles))
    consistent = smiles_symbols.issubset(alphabet_symbols)
    if check and not consistent:
        # show which symbols are not in alphabet; join into a string —
        # concatenating str + set would raise TypeError, not the intended error
        extra = ", ".join(sorted(smiles_symbols.difference(alphabet_symbols)))
        raise ValueError("symbols not in alphabet: " + extra)
    return consistent


def run_stoned(
start_smiles: str,
fp_type: str = "ECFP4",
Expand All @@ -392,6 +425,9 @@ def run_stoned(
alphabet = get_basic_alphabet()
if type(alphabet) == set:
alphabet = list(alphabet)
alphabet_symbols = _alphabet_to_elements(alphabet)
# make sure starting smiles is consistent with alphabet
_ = _check_alphabet_consistency(start_smiles, alphabet_symbols, check=True)
num_mutation_ls = list(range(min_mutations, max_mutations + 1))

start_mol = smi2mol(start_smiles)
Expand All @@ -418,6 +454,16 @@ def run_stoned(
)
# Convert back to SMILES:
smiles_back = [sf.decoder(x) for x in selfies_mut]
# check if smiles are consistent with alphabet and downselect
selfies_mut, smiles_back = zip(
*[
(s, sm)
for s, sm in zip(selfies_mut, smiles_back)
if _check_alphabet_consistency(sm, alphabet_symbols)
]
)
selfies_mut, smiles_back = list(selfies_mut), list(smiles_back)

all_smiles_collect = all_smiles_collect + smiles_back
all_selfies_collect = all_selfies_collect + selfies_mut
if _pbar:
Expand Down
1 change: 0 additions & 1 deletion exmol/stoned/stoned.py
Original file line number Diff line number Diff line change
Expand Up @@ -446,7 +446,6 @@ def get_mutated_SELFIES(selfies_ls, num_mutations, alphabet):
for _ in range(num_mutations):
selfie_ls_mut_ls = []
for str_ in selfies_ls:

str_chars = get_selfie_chars(str_)
max_molecules_len = len(str_chars) + num_mutations

Expand Down
2 changes: 1 addition & 1 deletion exmol/version.py
Original file line number Diff line number Diff line change
@@ -1 +1 @@
__version__ = "3.0.3"
__version__ = "3.0.4"
8 changes: 2 additions & 6 deletions paper2_LIME/RF-lime.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -63,9 +63,7 @@
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"scrolled": true
},
"metadata": {},
"outputs": [],
"source": [
"# make object that can compute descriptors\n",
Expand Down Expand Up @@ -249,9 +247,7 @@
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"scrolled": false
},
"metadata": {},
"outputs": [],
"source": [
"x = wls_attr.keys()\n",
Expand Down
44 changes: 11 additions & 33 deletions paper2_LIME/Solubility-RNN.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -104,9 +104,7 @@
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"scrolled": true
},
"metadata": {},
"outputs": [],
"source": [
"smiles = list(soldata[\"SMILES\"])\n",
Expand Down Expand Up @@ -234,9 +232,7 @@
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"scrolled": true
},
"metadata": {},
"outputs": [],
"source": [
"# now get sequences\n",
Expand Down Expand Up @@ -270,9 +266,7 @@
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"scrolled": false
},
"metadata": {},
"outputs": [],
"source": [
"model = tf.keras.Sequential()\n",
Expand All @@ -299,9 +293,7 @@
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"scrolled": true
},
"metadata": {},
"outputs": [],
"source": [
"model.compile(tf.optimizers.Adam(1e-3), loss=\"mean_squared_error\")\n",
Expand Down Expand Up @@ -411,9 +403,7 @@
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"scrolled": true
},
"metadata": {},
"outputs": [],
"source": [
"# Make sure SMILES doesn't contain multiple fragments\n",
Expand All @@ -434,9 +424,7 @@
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"scrolled": false
},
"metadata": {},
"outputs": [],
"source": [
"from IPython.display import display, SVG\n",
Expand Down Expand Up @@ -509,9 +497,7 @@
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"scrolled": false
},
"metadata": {},
"outputs": [],
"source": [
"# Inspect space\n",
Expand All @@ -533,9 +519,7 @@
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"scrolled": false
},
"metadata": {},
"outputs": [],
"source": [
"fkw = {\"figsize\": (6, 4)}\n",
Expand Down Expand Up @@ -686,9 +670,7 @@
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"scrolled": true
},
"metadata": {},
"outputs": [],
"source": [
"# Get subsets and calculate lime importances - subsample - get rank correlation\n",
Expand Down Expand Up @@ -747,9 +729,7 @@
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"scrolled": false
},
"metadata": {},
"outputs": [],
"source": [
"# Mutation\n",
Expand All @@ -773,9 +753,7 @@
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"scrolled": false
},
"metadata": {},
"outputs": [],
"source": [
"# Alphabet\n",
Expand Down
4 changes: 1 addition & 3 deletions paper2_LIME/Tutorial.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -145,9 +145,7 @@
"cell_type": "code",
"execution_count": null,
"id": "ed252205",
"metadata": {
"scrolled": false
},
"metadata": {},
"outputs": [],
"source": [
"exmol.lime_explain(space)\n",
Expand Down
10 changes: 3 additions & 7 deletions paper3_Scents/GNNModelTrainingAndEvaluation.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -21,8 +21,7 @@
"base_uri": "https://localhost:8080/"
},
"id": "gZeyyFPJbYxi",
"outputId": "2ac9b9a9-ca8b-4608-d3a4-a2c8f937f71f",
"scrolled": true
"outputId": "2ac9b9a9-ca8b-4608-d3a4-a2c8f937f71f"
},
"outputs": [],
"source": [
Expand Down Expand Up @@ -562,9 +561,7 @@
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"scrolled": false
},
"metadata": {},
"outputs": [],
"source": [
"# Train model\n",
Expand Down Expand Up @@ -1041,8 +1038,7 @@
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "BBgB_3tPbYxx",
"scrolled": true
"id": "BBgB_3tPbYxx"
},
"outputs": [],
"source": [
Expand Down
1 change: 0 additions & 1 deletion requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -2,5 +2,4 @@ pre-commit
mypy
pytest
pytest-cov
click
openai
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@
"skunk >= 0.4.0",
"importlib-resources",
"synspace",
"langchain",
"langchain==0.0.343",
],
test_suite="tests",
long_description=long_description,
Expand Down
20 changes: 16 additions & 4 deletions tests/test_exmol.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,6 @@ def test_sanitize_smiles_chiral():
assert "@" in result[1]


# TODO let STONED people write these when they finish their repo
def test_run_stoned():
result = exmol.run_stoned(
"N#CC=CC(C(=O)NCC1=CC=CC=C1C(=O)N)(C)CC2=CC=C(F)C=C2CC",
Expand Down Expand Up @@ -199,14 +198,27 @@ def test_run_custom():

def test_run_stones_alphabet():
result = exmol.run_stoned(
"N#CC=CC(C(=O)NCC1=CC=CC=C1C(=O)N)(C)CC2=CC=C(F)C=C2CC",
num_samples=10,
max_mutations=1,
"C1=CC=C(C=C1)C2=CC=CC=C2",
num_samples=25,
max_mutations=3,
alphabet=["[C]", "[O]"],
)
# Can get duplicates
assert len(result[0]) >= 0

# check no other characters
for mol in result[0]:
print(["C", "O", "#", "(", ")"] + list([str(i) for i in range(10)]))
print(mol)
assert not any(
[
c
not in ["C", "O", "#", "(", ")", "="]
+ list([str(i) for i in range(10)])
for c in mol
]
)


def test_sample():
def model(s, se):
Expand Down

0 comments on commit 89d7626

Please sign in to comment.