diff --git a/datamol/viz/_viz.py b/datamol/viz/_viz.py index b340d1f5..1b6f2eb5 100644 --- a/datamol/viz/_viz.py +++ b/datamol/viz/_viz.py @@ -3,6 +3,7 @@ from typing import Tuple from typing import Optional from typing import Any +from IPython.core.getipython import get_ipython from rdkit.Chem import Draw @@ -120,18 +121,27 @@ def to_image( else: _kwargs[k] = v - image = Draw.MolsToGridImage( - mols, - legends=legends, - molsPerRow=n_cols, - useSVG=use_svg, - subImgSize=mol_size, - highlightAtomLists=_highlight_atom, - highlightBondLists=_highlight_bond, - drawOptions=draw_options, - maxMols=max_mols, + # Check if we are in a Jupyter notebook or IPython display context + in_notebook = get_ipython() is not None + + # Create a dictionary of arguments for the MolsToGridImage function + draw_args = { + "mols": mols, + "legends": legends, + "molsPerRow": n_cols, + "useSVG": use_svg, + "subImgSize": mol_size, + "highlightAtomLists": _highlight_atom, + "highlightBondLists": _highlight_bond, + "drawOptions": draw_options, **_kwargs, - ) + } + + # Conditionally add the maxMols argument if in a notebook + if in_notebook: + draw_args["maxMols"] = max_mols + + image = Draw.MolsToGridImage(**draw_args) if outfile is not None: image_to_file(image, outfile, as_svg=use_svg)