From b4446328f2ff4ee9d43335a40e931f426e244231 Mon Sep 17 00:00:00 2001 From: maclandrol Date: Wed, 17 Jan 2024 13:05:35 -0500 Subject: [PATCH 1/6] add support for input atom/bond color in addition to lasso --- datamol/viz/_lasso_highlight.py | 51 ++++++++++++++++++++++++++++++--- datamol/viz/utils.py | 4 +++ 2 files changed, 51 insertions(+), 4 deletions(-) diff --git a/datamol/viz/_lasso_highlight.py b/datamol/viz/_lasso_highlight.py index 29b87d10..854607df 100644 --- a/datamol/viz/_lasso_highlight.py +++ b/datamol/viz/_lasso_highlight.py @@ -6,7 +6,7 @@ # - possibility to do this for multiple target molecules at once # - have the option to write to a file like to_image -from typing import List, Iterator, Tuple, Union, Optional, Any, cast +from typing import List, Dict, Iterator, Tuple, Union, Optional, Any, cast from collections import defaultdict from collections import namedtuple @@ -400,6 +400,10 @@ def lasso_highlight_image( line_width: int = 2, scale_padding: float = 1.0, verbose: bool = False, + highlight_atoms: Optional[List[List[int]]] = None, + highlight_bonds: Optional[List[List[int]]] = None, + highlight_atom_colors: Optional[List[Dict[int, DatamolColor]]] = None, + highlight_bond_colors: Optional[List[Dict[int, DatamolColor]]] = None, **kwargs: Any, ): """Create an image of a list of molecules with substructure matches using lasso-based highlighting. @@ -421,6 +425,10 @@ def lasso_highlight_image( line_width: width of drawn lines. scale_padding: Padding around the molecule when drawing to scale. verbose: Whether to print the verbose information. + highlight_atoms: The atoms to highlight, a list for each molecule. + highlight_bonds: The bonds to highlight, a list for each molecule. + highlight_atom_colors: The colors to use for highlighting atoms, a list of dict mapping atom index to color for each molecule. + highlight_bond_colors: The colors to use for highlighting bonds, a list of dict mapping bond index to color for each molecule. 
**kwargs: Additional arguments to pass to the drawing function. See RDKit documentation related to `MolDrawOptions` for more details at https://www.rdkit.org/docs/source/rdkit.Chem.Draw.rdMolDraw2D.html. @@ -551,9 +559,38 @@ def lasso_highlight_image( # EN: the following is edge-case free after trying 6 different logics, but may break if RDKit changes the way it draws molecules scaling_val = Point2D(scale_padding, scale_padding) + if isinstance(highlight_atoms, list) and isinstance(highlight_atoms[0], int): + highlight_atoms = [highlight_atoms] * len(target_molecules) + if isinstance(highlight_bonds, list) and isinstance(highlight_bonds[0], int): + highlight_bonds = [highlight_bonds] * len(target_molecules) + if isinstance(highlight_atom_colors, dict): + highlight_atom_colors = [highlight_atom_colors] * len(target_molecules) + if isinstance(highlight_bond_colors, dict): + highlight_bond_colors = [highlight_bond_colors] * len(target_molecules) + + # make sure we are using rdkit colors + print(highlight_atom_colors[0][4], to_rdkit_color(highlight_atom_colors[0][4])) + highlight_atom_colors = [ + {k: to_rdkit_color(v) for k, v in _.items()} for _ in highlight_atom_colors + ] + highlight_bond_colors = [ + {k: to_rdkit_color(v) for k, v in _.items()} for _ in highlight_bond_colors + ] + + kwargs["highlightAtoms"] = highlight_atoms + kwargs["highlightBonds"] = highlight_bonds + kwargs["highlightAtomColors"] = highlight_atom_colors + kwargs["highlightBondColors"] = highlight_bond_colors + + print(kwargs) try: - drawer.DrawMolecules(mols_to_draw, legends=legends, **kwargs) - except Exception: + drawer.DrawMolecules( + mols_to_draw, + legends=legends, + **kwargs, + ) + except Exception as e: + logger.error(e) raise ValueError( "Failed to draw molecules. Some arguments neither match expected MolDrawOptions, nor DrawMolecule inputs. Please check the input arguments." 
) @@ -567,8 +604,14 @@ def lasso_highlight_image( h_pos, w_pos = np.unravel_index(ind, (n_rows, n_cols)) offset_x = int(w_pos * mol_size[0]) offset_y = int(h_pos * mol_size[1]) + + ind_kwargs = kwargs.copy() + ind_kwargs["highlightAtoms"] = kwargs["highlightAtoms"][ind] + ind_kwargs["highlightAtomColors"] = kwargs["highlightAtomColors"][ind] + ind_kwargs["highlightBonds"] = kwargs["highlightBonds"][ind] + ind_kwargs["highlightBondColors"] = kwargs["highlightBondColors"][ind] drawer.SetOffset(offset_x, offset_y) - drawer.DrawMolecule(mol, legend=legends[ind], **kwargs) + drawer.DrawMolecule(mol, legend=legends[ind], **ind_kwargs) offset = None if draw_mols_same_scale: offset = drawer.Offset() diff --git a/datamol/viz/utils.py b/datamol/viz/utils.py index c2d8cdb2..67522234 100644 --- a/datamol/viz/utils.py +++ b/datamol/viz/utils.py @@ -143,4 +143,8 @@ def to_rdkit_color(color: Optional[DatamolColor]) -> Optional[RDKitColor]: """ if isinstance(color, str): return mcolors.to_rgba(color) # type: ignore + + if len(color) in [3, 4] and any(x > 1 for x in color): + return tuple(x / 255 if i < 3 else x for i, x in enumerate(color)) + return color From 03a247369c347afca2ad7af6417eedb30404c171 Mon Sep 17 00:00:00 2001 From: maclandrol Date: Wed, 17 Jan 2024 13:06:19 -0500 Subject: [PATCH 2/6] fix test --- tests/test_viz_lasso_highlight.py | 24 ++++++++++++++++++++++++ 1 file changed, 24 insertions(+) diff --git a/tests/test_viz_lasso_highlight.py b/tests/test_viz_lasso_highlight.py index 542c9464..a56ba6be 100644 --- a/tests/test_viz_lasso_highlight.py +++ b/tests/test_viz_lasso_highlight.py @@ -17,6 +17,30 @@ def test_from_mol(): assert dm.lasso_highlight_image(mol, smarts_list) +def test_with_highlight(): + smi = "CO[C@@H](O)C1=C(O[C@H](F)Cl)C(C#N)=C1ONNC[NH3+]" + mol = dm.to_mol(smi) + smarts_list = "CONN" + highlight_atoms = [4, 5, 6] + highlight_bonds = [1, 2, 3, 4] + highlight_atom_colors = {4: (230, 230, 250), 5: (230, 230, 250), 6: (230, 230, 250)} + 
highlight_bond_colors = { + 1: (230, 230, 250), + 2: (230, 230, 250), + 3: (230, 230, 250), + 4: (230, 230, 250), + } + assert dm.lasso_highlight_image( + mol, + smarts_list, + highlight_atoms=highlight_atoms, + highlight_bonds=highlight_bonds, + highlight_atom_colors=highlight_atom_colors, + highlight_bond_colors=highlight_bond_colors, + continuousHighlight=False, + ) + + def test_original_working_solution_list_single_str(): smi = "CO[C@@H](O)C1=C(O[C@H](F)Cl)C(C#N)=C1ONNC[NH3+]" smarts_list = ["CONN"] From f6f56254479f654502c54f19900fd15b2f481033 Mon Sep 17 00:00:00 2001 From: maclandrol Date: Wed, 17 Jan 2024 13:11:09 -0500 Subject: [PATCH 3/6] remove prints --- datamol/viz/_lasso_highlight.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/datamol/viz/_lasso_highlight.py b/datamol/viz/_lasso_highlight.py index 854607df..5206f634 100644 --- a/datamol/viz/_lasso_highlight.py +++ b/datamol/viz/_lasso_highlight.py @@ -569,7 +569,6 @@ def lasso_highlight_image( highlight_bond_colors = [highlight_bond_colors] * len(target_molecules) # make sure we are using rdkit colors - print(highlight_atom_colors[0][4], to_rdkit_color(highlight_atom_colors[0][4])) highlight_atom_colors = [ {k: to_rdkit_color(v) for k, v in _.items()} for _ in highlight_atom_colors ] @@ -582,7 +581,6 @@ def lasso_highlight_image( kwargs["highlightAtomColors"] = highlight_atom_colors kwargs["highlightBondColors"] = highlight_bond_colors - print(kwargs) try: drawer.DrawMolecules( mols_to_draw, From 854643ab39efd9273957e3afcd2c9bdf3c8d126b Mon Sep 17 00:00:00 2001 From: maclandrol Date: Wed, 17 Jan 2024 13:37:03 -0500 Subject: [PATCH 4/6] edge case --- datamol/viz/_lasso_highlight.py | 26 ++++++++++++++++---------- 1 file changed, 16 insertions(+), 10 deletions(-) diff --git a/datamol/viz/_lasso_highlight.py b/datamol/viz/_lasso_highlight.py index 5206f634..05474acf 100644 --- a/datamol/viz/_lasso_highlight.py +++ b/datamol/viz/_lasso_highlight.py @@ -569,12 +569,14 @@ def lasso_highlight_image( 
highlight_bond_colors = [highlight_bond_colors] * len(target_molecules) # make sure we are using rdkit colors - highlight_atom_colors = [ - {k: to_rdkit_color(v) for k, v in _.items()} for _ in highlight_atom_colors - ] - highlight_bond_colors = [ - {k: to_rdkit_color(v) for k, v in _.items()} for _ in highlight_bond_colors - ] + if highlight_atom_colors is not None: + highlight_atom_colors = [ + {k: to_rdkit_color(v) for k, v in _.items()} for _ in highlight_atom_colors + ] + if highlight_bond_colors is not None: + highlight_bond_colors = [ + {k: to_rdkit_color(v) for k, v in _.items()} for _ in highlight_bond_colors + ] kwargs["highlightAtoms"] = highlight_atoms kwargs["highlightBonds"] = highlight_bonds @@ -604,10 +606,14 @@ def lasso_highlight_image( offset_y = int(h_pos * mol_size[1]) ind_kwargs = kwargs.copy() - ind_kwargs["highlightAtoms"] = kwargs["highlightAtoms"][ind] - ind_kwargs["highlightAtomColors"] = kwargs["highlightAtomColors"][ind] - ind_kwargs["highlightBonds"] = kwargs["highlightBonds"][ind] - ind_kwargs["highlightBondColors"] = kwargs["highlightBondColors"][ind] + if isinstance(ind_kwargs["highlightAtoms"], list): + ind_kwargs["highlightAtoms"] = ind_kwargs["highlightAtoms"][ind] + if isinstance(ind_kwargs["highlightAtomColors"], list): + ind_kwargs["highlightAtomColors"] = ind_kwargs["highlightAtomColors"][ind] + if isinstance(ind_kwargs["highlightBonds"], list): + ind_kwargs["highlightBonds"] = ind_kwargs["highlightBonds"][ind] + if isinstance(ind_kwargs["highlightBondColors"], list): + ind_kwargs["highlightBondColors"] = ind_kwargs["highlightBondColors"][ind] drawer.SetOffset(offset_x, offset_y) drawer.DrawMolecule(mol, legend=legends[ind], **ind_kwargs) offset = None From 4f94acf09f3f87207aefef37917ebc9a911276ad Mon Sep 17 00:00:00 2001 From: maclandrol Date: Wed, 17 Jan 2024 13:47:37 -0500 Subject: [PATCH 5/6] color edge cases --- datamol/viz/utils.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git 
a/datamol/viz/utils.py b/datamol/viz/utils.py index 67522234..c0e6ad75 100644 --- a/datamol/viz/utils.py +++ b/datamol/viz/utils.py @@ -141,10 +141,12 @@ def to_rdkit_color(color: Optional[DatamolColor]) -> Optional[RDKitColor]: Args: color: A datamol color: hex, rgb, rgba or None. """ + if color is None: + return None + if isinstance(color, str): return mcolors.to_rgba(color) # type: ignore - - if len(color) in [3, 4] and any(x > 1 for x in color): + if isinstance(color, (tuple, list)) and len(color) in [3, 4] and any(x > 1 for x in color): return tuple(x / 255 if i < 3 else x for i, x in enumerate(color)) return color From e56d08362f7e283eaa5a17421f43f597bba79066 Mon Sep 17 00:00:00 2001 From: maclandrol Date: Fri, 19 Jan 2024 09:31:11 -0500 Subject: [PATCH 6/6] improve docs --- datamol/viz/_lasso_highlight.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/datamol/viz/_lasso_highlight.py b/datamol/viz/_lasso_highlight.py index 05474acf..7298d54f 100644 --- a/datamol/viz/_lasso_highlight.py +++ b/datamol/viz/_lasso_highlight.py @@ -412,7 +412,7 @@ def lasso_highlight_image( Args: target_molecules: One or a list of molecules to be highlighted. search_molecules: The substructure to be highlighted. - atom_indices: Atom indices to be highlighted substructure. + atom_indices: Atom indices to be highlighted as substructure using the lasso visualization. legends: A string or a list of string as legend for every molecules. n_cols: Number of molecules per column. mol_size: The size of the image to be returned @@ -425,8 +425,8 @@ def lasso_highlight_image( line_width: width of drawn lines. scale_padding: Padding around the molecule when drawing to scale. verbose: Whether to print the verbose information. - highlight_atoms: The atoms to highlight, a list for each molecule. - highlight_bonds: The bonds to highlight, a list for each molecule. + highlight_atoms: The atoms to highlight, a list for each molecule. 
Forwarded to the RDKit drawer as `highlightAtoms`; a single flat list of atom indices is also accepted and is applied to every molecule. + highlight_bonds: The bonds to highlight, a list for each molecule. Forwarded to the RDKit drawer as `highlightBonds`; a single flat list of bond indices is also accepted and is applied to every molecule. highlight_atom_colors: The colors to use for highlighting atoms, a list of dict mapping atom index to color for each molecule. highlight_bond_colors: The colors to use for highlighting bonds, a list of dict mapping bond index to color for each molecule. **kwargs: Additional arguments to pass to the drawing function. See RDKit