Skip to content

Commit

Permalink
adding tests to make sure smiles atom graphs are generated correctly.
Browse files Browse the repository at this point in the history
  • Loading branch information
InnocentBug committed Jan 31, 2025
1 parent bef2ca1 commit 8035bf0
Show file tree
Hide file tree
Showing 5 changed files with 215 additions and 37 deletions.
173 changes: 141 additions & 32 deletions src/gbigsmiles/atom_graph.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,14 @@
import copy
import json

Check failure on line 2 in src/gbigsmiles/atom_graph.py

View workflow job for this annotation

GitHub Actions / Trunk Check

flake8(F401)

[new] 'json' imported but unused

Check failure on line 2 in src/gbigsmiles/atom_graph.py

View workflow job for this annotation

GitHub Actions / Trunk Check

ruff(F401)

[new] `json` imported but unused
import warnings
from dataclasses import dataclass

Check failure on line 4 in src/gbigsmiles/atom_graph.py

View workflow job for this annotation

GitHub Actions / Trunk Check

flake8(F401)

[new] 'dataclasses.dataclass' imported but unused

Check failure on line 4 in src/gbigsmiles/atom_graph.py

View workflow job for this annotation

GitHub Actions / Trunk Check

ruff(F401)

[new] `dataclasses.dataclass` imported but unused
from typing import Any

import networkx as nx
import numpy as np

from .chem_resource import atomic_masses
from .distribution import StochasticDistribution
from .exception import InvalidGenerationSource, UnvalidatedGenerationSource

Check failure on line 12 in src/gbigsmiles/atom_graph.py

View workflow job for this annotation

GitHub Actions / Trunk Check

cspell(error)

[new] Unknown word (Unvalidated) Suggestions: [Invalidate, Unanimated, Unalienated, Unevaluated, Unpaginated]
from .generating_graph import (

Check failure on line 13 in src/gbigsmiles/atom_graph.py

View workflow job for this annotation

GitHub Actions / Trunk Check

flake8(F401)

[new] '.generating_graph._STATIC_NAME' imported but unused
_AROMATIC_NAME,
Expand All @@ -15,12 +19,53 @@
from .util import get_global_rng


@dataclass(frozen=True)
class _HalfAtomBond:
atom_idx: int
gen_bond_attr: str
target_gen_idx: str
bond_type: str
def __init__(self, atom_idx: int, node_idx: str, graph, bond_attr_converter):
    """Collect the non-static outgoing edges of one generating-graph node.

    Args:
        atom_idx: Index of the atom this half bond belongs to.
        node_idx: Key of the corresponding node in ``graph``.
        graph: Graph providing ``nodes[node_idx]["gen_weight"]`` and
            ``out_edges(node_idx, data=True)``.
        bond_attr_converter: Callable that turns an edge-data dict into the
            bond attributes stored for each mode.
    """
    self.atom_idx: int = atom_idx
    self.node_idx: str = node_idx
    self.weight: float = graph.nodes[node_idx]["gen_weight"]

    # Three parallel maps keyed by the entries of _NON_STATIC_ATTR: position i
    # of each list describes the same outgoing edge.
    self._mode_attr_map = {}
    self._mode_target_map = {}
    self._mode_weight_map = {}

    # The edge source is always `node_idx`, so it is deliberately ignored.
    for _source, target, edge_data in graph.out_edges(node_idx, data=True):
        if edge_data["static"]:
            continue
        for mode in _NON_STATIC_ATTR:
            mode_weight = edge_data[mode]
            if mode_weight > 0:
                bond_attr = bond_attr_converter(edge_data)
                self._mode_attr_map.setdefault(mode, []).append(bond_attr)
                self._mode_target_map.setdefault(mode, []).append(target)
                self._mode_weight_map.setdefault(mode, []).append(mode_weight)

def has_any_bonds(self):
    """Return True if at least one mode holds a non-empty list of bonds."""
    # any() stops at the first non-empty list instead of scanning every mode.
    return any(len(bonds) > 0 for bonds in self._mode_attr_map.values())

def has_mode_bonds(self, mode):
    """Return True if ``mode`` is known and has at least one bond recorded."""
    # A single lookup; an unknown mode behaves like an empty bond list.
    return len(self._mode_attr_map.get(mode, ())) > 0

def get_mode_bonds(self, mode):
    """Return ``(bond_attrs, targets)`` for ``mode``, or two empty lists if unknown."""
    attr_map = self._mode_attr_map
    target_map = self._mode_target_map
    if mode in attr_map and mode in target_map:
        return attr_map[mode], target_map[mode]
    return [], []

def __str__(self):
    """Render the indices, the weight and the three per-mode maps for debugging."""
    return (
        f"HalfAtomBond({self.atom_idx}, {self.node_idx}, {self.weight}, "
        f"{self._mode_attr_map}, {self._mode_target_map}, {self._mode_weight_map})"
    )


class _PartialAtomGraph:
Expand All @@ -31,48 +76,87 @@ def __init__(self, generating_graph, static_graph, source_node):
self._atom_id = 0
self.generating_graph = generating_graph
self.static_graph = static_graph
self._stochastic_vector: list[float] = generating_graph.nodes[source_node][
"stochastic_generation"
]
self._stochastic_generation: None | StochasticDistribution = (
StochasticDistribution.from_serial_vector(self._stochastic_vector)
)

self.atom_graph = nx.Graph()
self._open_half_bonds = []
self._mol_weight: float = 0.0

self.add_static_sub_graph(source_node)

def merge(self, other, self_idx, other_idx, bond_attr):
    """Absorb ``other`` into this partial graph and bond the two together.

    Args:
        other: Partial atom graph to merge in; it is not modified.
        self_idx: Atom index in this graph that receives the new bond.
        other_idx: Atom index in ``other`` (before relabeling) that receives
            the new bond.
        bond_attr: Mapping of edge attributes for the connecting bond.

    After the merge this graph's stochastic vector and distribution are
    replaced by those of ``other``.
    """
    # Shift every atom index of `other` past the indices already used here.
    # The offset is the private counter `_atom_id`, the same one used for
    # `other_idx` below.
    offset = self._atom_id
    remapping_dict = {idx: idx + offset for idx in other.atom_graph.nodes}
    other_graph = nx.relabel_nodes(other.atom_graph, remapping_dict, copy=True)

    other_open_half_bonds = []
    for half_bond in other._open_half_bonds:
        # Shallow copy: only atom_idx is rebound, so `other` keeps its own index.
        new_half_bond = copy.copy(half_bond)
        new_half_bond.atom_idx += offset
        other_open_half_bonds.append(new_half_bond)
    other_idx += offset

    # Now we can do the actual merging
    self._atom_id += other._atom_id
    self._stochastic_vector = other.stochastic_vector
    self._stochastic_generation = copy.copy(other._stochastic_generation)

    self.atom_graph = nx.union(self.atom_graph, other_graph)
    self.atom_graph.add_edge(self_idx, other_idx, **bond_attr)
    self._open_half_bonds += other_open_half_bonds
    self._mol_weight += other._mol_weight

def add_static_sub_graph(self, source):
atom_key_to_gen_key = {}
gen_key_to_atom_key = {}

def add_node(node_idx):
transistion_edges = []
termination_edges = []
stochastic_edges = []

data = self.gen_node_attr_to_atom_attr(self.generating_graph.nodes[source])
self.atom_graph.add_node(self._atom_id, **data)
atom_key_to_gen_key[self._atom_id] = source
gen_key_to_atom_key[source] = self._atom_id

for u, v, d in self.generating_graph.out_edges(node_idx, data=True):
if not d["static"]:
for k in _NON_STATIC_ATTR:
if d[k] > 0:
stochastic_edges.append(
_HalfAtomBond(
self._atom_id,
self.gen_edge_attr_to_bond_attr(d),
v,
_STOCHASTIC_NAME,
)
)

data = self.gen_node_attr_to_atom_attr(self.generating_graph.nodes[node_idx])
self.atom_graph.add_node(self._atom_id, **(data | {"origin_idx": node_idx}))
atom_key_to_gen_key[self._atom_id] = node_idx
gen_key_to_atom_key[node_idx] = self._atom_id
half_bond = _HalfAtomBond(
self._atom_id, node_idx, self.generating_graph, self.gen_edge_attr_to_bond_attr
)

self._mol_weight += atomic_masses[data["atomic_num"]]
self._atom_id += 1

return half_bond

half_bond = add_node(source)
if half_bond.weight > 0 and half_bond.has_any_bonds():
self._open_half_bonds += [half_bond]

edges_data_map = {}

for u, v, k in nx.edge_dfs(self.static_graph, source=source):
if (u, v) not in edges_data_map and (v, u) not in edges_data_map:

edges_data_map[(u, v)] = self.gen_edge_attr_to_bond_attr(
for gen_atom_idx in (u, v):
if gen_atom_idx not in gen_key_to_atom_key:
half_bond = add_node(gen_atom_idx)
if half_bond.weight > 0 and half_bond.has_any_bonds():
self._open_half_bonds += [half_bond]

u_atom_idx = gen_key_to_atom_key[u]
v_atom_idx = gen_key_to_atom_key[v]

if (u_atom_idx, v_atom_idx) not in edges_data_map and (
v_atom_idx,
u_atom_idx,
) not in edges_data_map:
edges_data_map[(u_atom_idx, v_atom_idx)] = self.gen_edge_attr_to_bond_attr(
self.static_graph.get_edge_data(u, v, k)
)
print(edges_data_map)

for u_atom_idx, v_atom_idx in edges_data_map:
self.atom_graph.add_edge(
u_atom_idx, v_atom_idx, **edges_data_map[(u_atom_idx, v_atom_idx)]
)

def gen_node_attr_to_atom_attr(
self, attr: dict[str, bool | float | int], keys_to_copy: None | set[str] = None
Expand All @@ -95,6 +179,23 @@ def _copy_some_dict_attr(dictionary: dict[str, Any], keys_to_copy: set[str]) ->
new_dict[k] = dictionary[k]
return new_dict

@property
def stochastic_vector(self):
    """Return a copy of the stored stochastic vector so callers cannot mutate it."""
    vector_copy = self._stochastic_vector.copy()
    return vector_copy

def draw_mw(self, rng=None) -> None | float:
if self._stochastic_generation is not None:
return self._stochastic_generation.draw_mw(rng)
return -1.0

@property
def molw(self):
    """Return the molecular weight accumulated for this partial graph so far."""
    accumulated_weight = self._mol_weight
    return accumulated_weight


class _MolWeightTracker:
    """Placeholder class; it defines no attributes or methods yet."""


class AtomGraph:
def __init__(self, ml_graph):
Expand Down Expand Up @@ -139,7 +240,7 @@ def ml_graph(self):
def _get_random_start_node(self, rng):
    """Pick one starting node via ``rng.choice``, weighted by the stored start weights."""
    candidates = self._starting_node_idx
    probabilities = self._starting_node_weight
    return rng.choice(candidates, p=probabilities)

def sample_graph(self, source: str = None, rng=None):
def sample_mol_graph(self, source: str = None, rng=None):

if rng is None:
rng = get_global_rng()
Expand All @@ -155,4 +256,12 @@ def sample_graph(self, source: str = None, rng=None):
UnvalidatedGenerationSource(source, self._starting_node_idx, self.ml_graph)

Check failure on line 256 in src/gbigsmiles/atom_graph.py

View workflow job for this annotation

GitHub Actions / Trunk Check

cspell(error)

[new] Unknown word (Unvalidated) Suggestions: [Invalidate, Unanimated, Unalienated, Unevaluated, Unpaginated]
)

atom_graph = _PartialAtomGraph(self.ml_graph, self._static_graph, source)
partial_atom_graph = _PartialAtomGraph(self.ml_graph, self._static_graph, source)
while len(partial_atom_graph._open_half_bonds) > 0:
print(
partial_atom_graph.atom_graph,
partial_atom_graph.molw,
[str(ob) for ob in partial_atom_graph._open_half_bonds],
)

return partial_atom_graph.atom_graph
9 changes: 8 additions & 1 deletion src/gbigsmiles/big_smiles.py
Original file line number Diff line number Diff line change
Expand Up @@ -125,6 +125,11 @@ def _generate_partial_graph(self) -> _PartialGeneratingGraph:
] = self.total_molecular_weight
return partial_graph

@property
def num_mol_species(self):
    """Return the number of child elements, used as the count of molecular species."""
    # The former debug print of `self._children` was removed: reading a
    # property should not write to stdout.
    return len(self._children)


class DotGeneration(_AbstractIterativeGenerativeClass):
def __init__(self, children):
Expand All @@ -142,7 +147,9 @@ def __init__(self, children):

@property
def molecular_weight(self):
    """Return the molecular weight of the dot system size, or 0.0 if none is set."""
    # The system size may be absent; guard before dereferencing it. An
    # unconditional return ahead of this check would make the guard unreachable.
    if self._dot_system_size is not None:
        return self._dot_system_size.molecular_weight
    return 0.0


class DotSystemSize(BigSMILESbase, GenerationBase):
Expand Down
8 changes: 8 additions & 0 deletions src/gbigsmiles/chem_resource.py
Original file line number Diff line number Diff line change
Expand Up @@ -112,6 +112,14 @@
}

# Invert atom_name_mapping so that its values become the lookup keys.
atom_name_num = {name: number for number, name in atom_name_mapping.items()}
# Add aromatic versions: the lower-case symbol maps to the same entry as the
# upper-case one.
atom_name_num.update(
    {symbol.lower(): atom_name_num[symbol] for symbol in ("C", "B", "N", "O", "S", "P")}
)


atom_color_mapping = {
-1: "FFFFFF",
Expand Down
3 changes: 1 addition & 2 deletions src/gbigsmiles/data/g-bigsmiles.lark
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,6 @@ branched_atom: _atom_stand_in ring_bond* branch*
branch: "(" bond_symbol? (smiles | big_smiles_molecule)* ")"
atom_assembly: bond_symbol? branched_atom
smiles: branched_atom atom_assembly*
| smiles dot smiles


isotope: INT
Expand Down Expand Up @@ -110,7 +109,7 @@ atom_class: ":" INT


dot_system_size: "|" WS_INLINE* NUMBER WS_INLINE* "|"
dot_generation: dot dot_system_size
dot_generation: dot dot_system_size?
dot: "."

_number_list_repeat.-1: WS_INLINE+ NUMBER
Expand Down
59 changes: 57 additions & 2 deletions tests/test_smiles.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
def test_smiles_parsing(chembl_smi_list):
    """Check that every non-empty SMILES string survives a parse/regenerate round trip."""
    for smi in chembl_smi_list:
        if not smi:
            continue
        # Parse once with BigSmiles; an earlier Smiles.make() result was
        # overwritten immediately and never used.
        smiles_instance = gbigsmiles.BigSmiles.make(smi)
        assert smi == smiles_instance.generate_string(True)


Expand All @@ -36,7 +36,7 @@ def test_smiles_weight(n, chembl_smi_list):
assert abs(total_mw - big_smiles.total_molecular_weight) < 1e-6


def test_smiles_graph(chembl_smi_list):
def test_smiles_gen_graph(chembl_smi_list):

def node_match(gb_node, pysmi_node):
return_value = True
Expand Down Expand Up @@ -98,3 +98,58 @@ def edge_match(gb_edge, pysmi_edge):

print("\n", smi, pysmiles_graph, graph, "\n")
assert nx.is_isomorphic(graph, pysmiles_graph, node_match=node_match, edge_match=edge_match)


def test_smiles_atom_graph(chembl_smi_list):
    """Check that sampled atom graphs of single-species SMILES match pysmiles.

    For every SMILES string that yields exactly one molecular species, the
    graph from ``sample_mol_graph()`` must be isomorphic to the graph read by
    pysmiles, once hydrogens and zero-order bonds are stripped.
    """

    def node_match(gb_node, pysmi_node):
        # Atoms agree when element, charge and aromaticity are all the same.
        expected_num = gbigsmiles.chem_resource.atom_name_num[pysmi_node["element"]]
        return (
            gb_node["atomic_num"] == expected_num
            and gb_node["charge"] == pysmi_node["charge"]
            and gb_node["aromatic"] == pysmi_node["aromatic"]
        )

    def edge_match(gb_edge, pysmi_edge):
        # pysmiles encodes aromatic bonds as order 1.5.
        if gb_edge["aromatic"] and pysmi_edge["order"] == 1.5:
            return True
        return gb_edge["bond_type"] == pysmi_edge["order"]

    for smi in chembl_smi_list:
        big_smiles = gbigsmiles.BigSmiles.make(smi)

        if big_smiles.num_mol_species != 1:
            continue

        pysmiles_graph = pysmiles.read_smiles(smi, reinterpret_aromatic=False)

        # PYSMILES and us treat hydrogen differently, so drop hydrogen nodes.
        # Compare for equality: a substring test ("H" in element) would also
        # delete He, Hf, Hg, Ho and Hs, which are kept on the gbigsmiles side.
        for node_idx, node_data in list(pysmiles_graph.nodes(data=True)):
            if node_data["element"] == "H":
                pysmiles_graph.remove_node(node_idx)

        # Pysmiles adds 0 order bonds to the graph, but we do not.
        for u, v, edge_data in list(pysmiles_graph.edges(data=True)):
            if edge_data["order"] == 0:
                pysmiles_graph.remove_edge(u, v)

        gen_graph = big_smiles.get_generating_graph()
        atom_graph = gen_graph.get_atom_graph()
        graph = atom_graph.sample_mol_graph()

        # Remove hydrogens from bigsmiles graph for comparison:
        for node_idx, node_data in list(graph.nodes(data=True)):
            if node_data["atomic_num"] == 1:
                graph.remove_node(node_idx)

        print("\n", smi, pysmiles_graph, graph, "\n")
        assert nx.is_isomorphic(
            graph, pysmiles_graph, node_match=node_match, edge_match=edge_match
        )

0 comments on commit 8035bf0

Please sign in to comment.