From 89d7626fe160de1d7dd556dba79c403468c94980 Mon Sep 17 00:00:00 2001 From: Andrew White Date: Mon, 4 Dec 2023 10:03:57 -0800 Subject: [PATCH] Added check to keep consistency with alphabet (#143) * Added check to keep consistency with alphabet * Fixed pre-commit maybe * Now we play find the correct langchain version --- .pre-commit-config.yaml | 14 +++--- exmol/exmol.py | 48 ++++++++++++++++++- exmol/stoned/stoned.py | 1 - exmol/version.py | 2 +- paper2_LIME/RF-lime.ipynb | 8 +--- paper2_LIME/Solubility-RNN.ipynb | 44 +++++------------ paper2_LIME/Tutorial.ipynb | 4 +- .../GNNModelTrainingAndEvaluation.ipynb | 10 ++-- requirements.txt | 1 - setup.py | 2 +- tests/test_exmol.py | 20 ++++++-- 11 files changed, 88 insertions(+), 66 deletions(-) diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index bc2db0d9..3416cbc7 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -7,13 +7,11 @@ repos: - id: end-of-file-fixer - id: mixed-line-ending - repo: https://github.com/psf/black - rev: "22.3.0" + rev: "23.11.0" hooks: - - id: black - - repo: https://github.com/tomcatling/black-nb - rev: "0.7" - hooks: - - id: black-nb - description: strip output and black source + - id: black additional_dependencies: ['black[jupyter]'] - args: ["--clear-output"] + - repo: https://github.com/kynan/nbstripout + rev: "0.6.0" + hooks: + - id: nbstripout diff --git a/exmol/exmol.py b/exmol/exmol.py index a1279048..eabe072a 100644 --- a/exmol/exmol.py +++ b/exmol/exmol.py @@ -1,4 +1,4 @@ -from functools import reduce +from functools import reduce, lru_cache import inspect from typing import * import io @@ -9,6 +9,7 @@ from matplotlib.patches import Rectangle, FancyBboxPatch # type: ignore from matplotlib.offsetbox import AnnotationBbox # type: ignore import matplotlib as mpl # type: ignore +import re import selfies as sf # type: ignore import tqdm # type: ignore import textwrap # type: ignore @@ -367,6 +368,38 @@ def get_basic_alphabet() -> Set[str]: return a +# The code below checks for accidental addition of symbols outside of the alphabet. +# the only way this can happen is if the character indicating a ring +# length is mutated to appear eleswhere. The ring length symbol +# is always a plain uncharged element. + + +def _alphabet_to_elements(alphabet: List[str]) -> Set[str]: + """Converts SELFIES alphabet to element symbols""" + symbols = [] + for s in alphabet: + s = s.replace("[", "").replace("]", "") + if s.isalpha(): + symbols.append(s) + return set(symbols) + + +def _check_alphabet_consistency( + smiles: str, alphabet_symbols: Set[str], check=False +) -> True: + """Checks if SMILES only contains tokens from alphabet""" + + alphabet_symbols = _alphabet_to_elements(set(alphabet_symbols)) + # find all elements in smiles (Upper alpha or upper alpha followed by lower alpha) + smiles_symbols = set(re.findall(r"[A-Z][a-z]?", smiles)) + if check and not smiles_symbols.issubset(alphabet_symbols): + # show which symbols are not in alphabet + raise ValueError( + "symbols not in alphabet" + smiles_symbols.difference(alphabet_symbols) + ) + return smiles_symbols.issubset(alphabet_symbols) + + def run_stoned( start_smiles: str, fp_type: str = "ECFP4", @@ -392,6 +425,9 @@ def run_stoned( alphabet = get_basic_alphabet() if type(alphabet) == set: alphabet = list(alphabet) + alphabet_symbols = _alphabet_to_elements(alphabet) + # make sure starting smiles is consistent with alphabet + _ = _check_alphabet_consistency(start_smiles, alphabet_symbols, check=True) num_mutation_ls = list(range(min_mutations, max_mutations + 1)) start_mol = smi2mol(start_smiles) @@ -418,6 +454,16 @@ def run_stoned( ) # Convert back to SMILES: smiles_back = [sf.decoder(x) for x in selfies_mut] + # check if smiles are consistent with alphabet and downslect + selfies_mut, smiles_back = zip( + *[ + (s, sm) + for s, sm in zip(selfies_mut, smiles_back) + if _check_alphabet_consistency(sm, alphabet_symbols) + ] + ) + selfies_mut, smiles_back = list(selfies_mut), list(smiles_back) + all_smiles_collect = all_smiles_collect + smiles_back all_selfies_collect = all_selfies_collect + selfies_mut if _pbar: diff --git a/exmol/stoned/stoned.py b/exmol/stoned/stoned.py index 8c10e2c0..27198aba 100644 --- a/exmol/stoned/stoned.py +++ b/exmol/stoned/stoned.py @@ -446,7 +446,6 @@ def get_mutated_SELFIES(selfies_ls, num_mutations, alphabet): for _ in range(num_mutations): selfie_ls_mut_ls = [] for str_ in selfies_ls: - str_chars = get_selfie_chars(str_) max_molecules_len = len(str_chars) + num_mutations diff --git a/exmol/version.py b/exmol/version.py index ea0e5f86..36ab2067 100644 --- a/exmol/version.py +++ b/exmol/version.py @@ -1 +1 @@ -__version__ = "3.0.3" +__version__ = "3.0.4" diff --git a/paper2_LIME/RF-lime.ipynb b/paper2_LIME/RF-lime.ipynb index b3cca5c5..e14a2fa3 100644 --- a/paper2_LIME/RF-lime.ipynb +++ b/paper2_LIME/RF-lime.ipynb @@ -63,9 +63,7 @@ { "cell_type": "code", "execution_count": null, - "metadata": { - "scrolled": true - }, + "metadata": {}, "outputs": [], "source": [ "# make object that can compute descriptors\n", @@ -249,9 +247,7 @@ { "cell_type": "code", "execution_count": null, - "metadata": { - "scrolled": false - }, + "metadata": {}, "outputs": [], "source": [ "x = wls_attr.keys()\n", diff --git a/paper2_LIME/Solubility-RNN.ipynb b/paper2_LIME/Solubility-RNN.ipynb index fa4f87e2..450e2355 100644 --- a/paper2_LIME/Solubility-RNN.ipynb +++ b/paper2_LIME/Solubility-RNN.ipynb @@ -104,9 +104,7 @@ { "cell_type": "code", "execution_count": null, - "metadata": { - "scrolled": true - }, + "metadata": {}, "outputs": [], "source": [ "smiles = list(soldata[\"SMILES\"])\n", @@ -234,9 +232,7 @@ { "cell_type": "code", "execution_count": null, - "metadata": { - "scrolled": true - }, + "metadata": {}, "outputs": [], "source": [ "# now get sequences\n", @@ -270,9 +266,7 @@ { "cell_type": "code", "execution_count": null, - "metadata": { - "scrolled": false - }, + "metadata": {}, "outputs": [], "source": [ "model = tf.keras.Sequential()\n", @@ -299,9 +293,7 @@ { "cell_type": "code", "execution_count": null, - "metadata": { - "scrolled": true - }, + "metadata": {}, "outputs": [], "source": [ "model.compile(tf.optimizers.Adam(1e-3), loss=\"mean_squared_error\")\n", @@ -411,9 +403,7 @@ { "cell_type": "code", "execution_count": null, - "metadata": { - "scrolled": true - }, + "metadata": {}, "outputs": [], "source": [ "# Make sure SMILES doesn't contain multiple fragments\n", @@ -434,9 +424,7 @@ { "cell_type": "code", "execution_count": null, - "metadata": { - "scrolled": false - }, + "metadata": {}, "outputs": [], "source": [ "from IPython.display import display, SVG\n", @@ -509,9 +497,7 @@ { "cell_type": "code", "execution_count": null, - "metadata": { - "scrolled": false - }, + "metadata": {}, "outputs": [], "source": [ "# Inspect space\n", @@ -533,9 +519,7 @@ { "cell_type": "code", "execution_count": null, - "metadata": { - "scrolled": false - }, + "metadata": {}, "outputs": [], "source": [ "fkw = {\"figsize\": (6, 4)}\n", @@ -686,9 +670,7 @@ { "cell_type": "code", "execution_count": null, - "metadata": { - "scrolled": true - }, + "metadata": {}, "outputs": [], "source": [ "# Get subsets and calculate lime importances - subsample - get rank correlation\n", @@ -747,9 +729,7 @@ { "cell_type": "code", "execution_count": null, - "metadata": { - "scrolled": false - }, + "metadata": {}, "outputs": [], "source": [ "# Mutation\n", @@ -773,9 +753,7 @@ { "cell_type": "code", "execution_count": null, - "metadata": { - "scrolled": false - }, + "metadata": {}, "outputs": [], "source": [ "# Alphabet\n", diff --git a/paper2_LIME/Tutorial.ipynb b/paper2_LIME/Tutorial.ipynb index e73e9d63..7a8cddb5 100644 --- a/paper2_LIME/Tutorial.ipynb +++ b/paper2_LIME/Tutorial.ipynb @@ -145,9 +145,7 @@ "cell_type": "code", "execution_count": null, "id": "ed252205", - "metadata": { - "scrolled": false - }, + "metadata": {}, "outputs": [], "source": [ "exmol.lime_explain(space)\n", diff --git a/paper3_Scents/GNNModelTrainingAndEvaluation.ipynb b/paper3_Scents/GNNModelTrainingAndEvaluation.ipynb index 062bbd71..50793712 100644 --- a/paper3_Scents/GNNModelTrainingAndEvaluation.ipynb +++ b/paper3_Scents/GNNModelTrainingAndEvaluation.ipynb @@ -21,8 +21,7 @@ "base_uri": "https://localhost:8080/" }, "id": "gZeyyFPJbYxi", - "outputId": "2ac9b9a9-ca8b-4608-d3a4-a2c8f937f71f", - "scrolled": true + "outputId": "2ac9b9a9-ca8b-4608-d3a4-a2c8f937f71f" }, "outputs": [], "source": [ @@ -562,9 +561,7 @@ { "cell_type": "code", "execution_count": null, - "metadata": { - "scrolled": false - }, + "metadata": {}, "outputs": [], "source": [ "# Train model\n", @@ -1041,8 +1038,7 @@ "cell_type": "code", "execution_count": null, "metadata": { - "id": "BBgB_3tPbYxx", - "scrolled": true + "id": "BBgB_3tPbYxx" }, "outputs": [], "source": [ diff --git a/requirements.txt b/requirements.txt index c4363f6b..0edc5d0c 100644 --- a/requirements.txt +++ b/requirements.txt @@ -2,5 +2,4 @@ pre-commit mypy pytest pytest-cov -click openai diff --git a/setup.py b/setup.py index 5ac96dd1..2333f616 100644 --- a/setup.py +++ b/setup.py @@ -28,7 +28,7 @@ "skunk >= 0.4.0", "importlib-resources", "synspace", - "langchain", + "langchain==0.0.343", ], test_suite="tests", long_description=long_description, diff --git a/tests/test_exmol.py b/tests/test_exmol.py index 3269e13b..9c9bf236 100644 --- a/tests/test_exmol.py +++ b/tests/test_exmol.py @@ -45,7 +45,6 @@ def test_sanitize_smiles_chiral(): assert "@" in result[1] -# TODO let STONED people write these when they finish their repo def test_run_stoned(): result = exmol.run_stoned( "N#CC=CC(C(=O)NCC1=CC=CC=C1C(=O)N)(C)CC2=CC=C(F)C=C2CC", @@ -199,14 +198,27 @@ def test_run_custom(): def test_run_stones_alphabet(): result = exmol.run_stoned( - "N#CC=CC(C(=O)NCC1=CC=CC=C1C(=O)N)(C)CC2=CC=C(F)C=C2CC", - num_samples=10, - max_mutations=1, + "C1=CC=C(C=C1)C2=CC=CC=C2", + num_samples=25, + max_mutations=3, alphabet=["[C]", "[O]"], ) # Can get duplicates assert len(result[0]) >= 0 + # check no other characters + for mol in result[0]: + print(["C", "O", "#", "(", ")"] + list([str(i) for i in range(10)])) + print(mol) + assert not any( + [ + c + not in ["C", "O", "#", "(", ")", "="] + + list([str(i) for i in range(10)]) + for c in mol + ] + ) + def test_sample(): def model(s, se):