Skip to content

Commit f165212

Browse files
davidsbatistalbux
andauthored
feat: Add support for custom (or offline) Mermaid.ink server and support all parameters (#8799)
* compress graph data to support pako endpoint * support mermaid.ink parameters and custom servers * dont try to resolve conflicts with the github web ui... * avoid double graph copy * fixing typing, improving docstrings and release notes * reverting type * nit - force type checker no cache * nit - force type checker no cache --------- Co-authored-by: Ulises M <[email protected]> Co-authored-by: Ulises M <[email protected]>
1 parent 503d275 commit f165212

File tree

4 files changed

+286
-27
lines changed

4 files changed

+286
-27
lines changed

Diff for: haystack/core/pipeline/base.py

+54-9
Original file line numberDiff line numberDiff line change
@@ -34,7 +34,7 @@
3434

3535
DEFAULT_MARSHALLER = YamlMarshaller()
3636

37-
# We use a generic type to annotate the return value of classmethods,
37+
# We use a generic type to annotate the return value of class methods,
3838
# so that static analyzers won't be confused when derived classes
3939
# use those methods.
4040
T = TypeVar("T", bound="PipelineBase")
@@ -619,31 +619,76 @@ def outputs(self, include_components_with_connected_outputs: bool = False) -> Di
619619
}
620620
return outputs
621621

622-
def show(self) -> None:
622+
def show(self, server_url: str = "https://mermaid.ink", params: Optional[dict] = None) -> None:
623623
"""
624-
If running in a Jupyter notebook, display an image representing this `Pipeline`.
624+
Display an image representing this `Pipeline` in a Jupyter notebook.
625625
626+
This function generates a diagram of the `Pipeline` using a Mermaid server and displays it directly in
627+
the notebook.
628+
629+
:param server_url:
630+
The base URL of the Mermaid server used for rendering (default: 'https://mermaid.ink').
631+
See https://github.com/jihchi/mermaid.ink and https://github.com/mermaid-js/mermaid-live-editor for more
632+
info on how to set up your own Mermaid server.
633+
634+
:param params:
635+
Dictionary of customization parameters to modify the output. Refer to Mermaid documentation for more details
636+
Supported keys:
637+
- format: Output format ('img', 'svg', or 'pdf'). Default: 'img'.
638+
- type: Image type for /img endpoint ('jpeg', 'png', 'webp'). Default: 'png'.
639+
- theme: Mermaid theme ('default', 'neutral', 'dark', 'forest'). Default: 'neutral'.
640+
- bgColor: Background color in hexadecimal (e.g., 'FFFFFF') or named format (e.g., '!white').
641+
- width: Width of the output image (integer).
642+
- height: Height of the output image (integer).
643+
- scale: Scaling factor (1–3). Only applicable if 'width' or 'height' is specified.
644+
- fit: Whether to fit the diagram size to the page (PDF only, boolean).
645+
- paper: Paper size for PDFs (e.g., 'a4', 'a3'). Ignored if 'fit' is true.
646+
- landscape: Landscape orientation for PDFs (boolean). Ignored if 'fit' is true.
647+
648+
:raises PipelineDrawingError:
649+
If the function is called outside of a Jupyter notebook or if there is an issue with rendering.
626650
"""
627651
if is_in_jupyter():
628652
from IPython.display import Image, display # type: ignore
629653

630-
image_data = _to_mermaid_image(self.graph)
631-
654+
image_data = _to_mermaid_image(self.graph, server_url=server_url, params=params)
632655
display(Image(image_data))
633656
else:
634657
msg = "This method is only supported in Jupyter notebooks. Use Pipeline.draw() to save an image locally."
635658
raise PipelineDrawingError(msg)
636659

637-
def draw(self, path: Path) -> None:
660+
def draw(self, path: Path, server_url: str = "https://mermaid.ink", params: Optional[dict] = None) -> None:
638661
"""
639-
Save an image representing this `Pipeline` to `path`.
662+
Save an image representing this `Pipeline` to the specified file path.
663+
664+
This function generates a diagram of the `Pipeline` using the Mermaid server and saves it to the provided path.
640665
641666
:param path:
642-
The path to save the image to.
667+
The file path where the generated image will be saved.
668+
:param server_url:
669+
The base URL of the Mermaid server used for rendering (default: 'https://mermaid.ink').
670+
See https://github.com/jihchi/mermaid.ink and https://github.com/mermaid-js/mermaid-live-editor for more
671+
info on how to set up your own Mermaid server.
672+
:param params:
673+
Dictionary of customization parameters to modify the output. Refer to Mermaid documentation for more details
674+
Supported keys:
675+
- format: Output format ('img', 'svg', or 'pdf'). Default: 'img'.
676+
- type: Image type for /img endpoint ('jpeg', 'png', 'webp'). Default: 'png'.
677+
- theme: Mermaid theme ('default', 'neutral', 'dark', 'forest'). Default: 'neutral'.
678+
- bgColor: Background color in hexadecimal (e.g., 'FFFFFF') or named format (e.g., '!white').
679+
- width: Width of the output image (integer).
680+
- height: Height of the output image (integer).
681+
- scale: Scaling factor (1–3). Only applicable if 'width' or 'height' is specified.
682+
- fit: Whether to fit the diagram size to the page (PDF only, boolean).
683+
- paper: Paper size for PDFs (e.g., 'a4', 'a3'). Ignored if 'fit' is true.
684+
- landscape: Landscape orientation for PDFs (boolean). Ignored if 'fit' is true.
685+
686+
:raises PipelineDrawingError:
687+
If there is an issue with rendering or saving the image.
643688
"""
644689
# Before drawing we edit a bit the graph, to avoid modifying the original that is
645690
# used for running the pipeline we copy it.
646-
image_data = _to_mermaid_image(self.graph)
691+
image_data = _to_mermaid_image(self.graph, server_url=server_url, params=params)
647692
Path(path).write_bytes(image_data)
648693

649694
def walk(self) -> Iterator[Tuple[str, Component]]:

Diff for: haystack/core/pipeline/draw.py

+119-14
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
import base64
66
import json
77
import zlib
8+
from typing import Any, Dict, Optional
89

910
import networkx # type:ignore
1011
import requests
@@ -54,7 +55,7 @@ def _prepare_for_drawing(graph: networkx.MultiDiGraph) -> networkx.MultiDiGraph:
5455
ARROWHEAD_MANDATORY = "-->"
5556
ARROWHEAD_OPTIONAL = ".->"
5657
MERMAID_STYLED_TEMPLATE = """
57-
%%{{ init: {{'theme': 'neutral' }} }}%%
58+
%%{{ init: {params} }}%%
5859
5960
graph TD;
6061
@@ -64,27 +65,133 @@ def _prepare_for_drawing(graph: networkx.MultiDiGraph) -> networkx.MultiDiGraph:
6465
"""
6566

6667

67-
def _to_mermaid_image(graph: networkx.MultiDiGraph):
68+
def _validate_mermaid_params(params: Dict[str, Any]) -> None:
6869
"""
69-
Renders a pipeline using Mermaid (hosted version at 'https://mermaid.ink'). Requires Internet access.
70+
Validates and sets default values for Mermaid parameters.
71+
72+
:param params:
73+
Dictionary of customization parameters to modify the output. Refer to Mermaid documentation for more details.
74+
Supported keys:
75+
- format: Output format ('img', 'svg', or 'pdf'). Default: 'img'.
76+
- type: Image type for /img endpoint ('jpeg', 'png', 'webp'). Default: 'png'.
77+
- theme: Mermaid theme ('default', 'neutral', 'dark', 'forest'). Default: 'neutral'.
78+
- bgColor: Background color in hexadecimal (e.g., 'FFFFFF') or named format (e.g., '!white').
79+
- width: Width of the output image (integer).
80+
- height: Height of the output image (integer).
81+
- scale: Scaling factor (1–3). Only applicable if 'width' or 'height' is specified.
82+
- fit: Whether to fit the diagram size to the page (PDF only, boolean).
83+
- paper: Paper size for PDFs (e.g., 'a4', 'a3'). Ignored if 'fit' is true.
84+
- landscape: Landscape orientation for PDFs (boolean). Ignored if 'fit' is true.
85+
86+
:raises ValueError:
87+
If any parameter is invalid or does not match the expected format.
88+
"""
89+
valid_img_types = {"jpeg", "png", "webp"}
90+
valid_themes = {"default", "neutral", "dark", "forest"}
91+
valid_formats = {"img", "svg", "pdf"}
92+
93+
params.setdefault("format", "img")
94+
params.setdefault("type", "png")
95+
params.setdefault("theme", "neutral")
96+
97+
if params["format"] not in valid_formats:
98+
raise ValueError(f"Invalid image format: {params['format']}. Valid options are: {valid_formats}.")
99+
100+
if params["format"] == "img" and params["type"] not in valid_img_types:
101+
raise ValueError(f"Invalid image type: {params['type']}. Valid options are: {valid_img_types}.")
102+
103+
if params["theme"] not in valid_themes:
104+
raise ValueError(f"Invalid theme: {params['theme']}. Valid options are: {valid_themes}.")
105+
106+
if "width" in params and not isinstance(params["width"], int):
107+
raise ValueError("Width must be an integer.")
108+
if "height" in params and not isinstance(params["height"], int):
109+
raise ValueError("Height must be an integer.")
110+
111+
if "scale" in params and not 1 <= params["scale"] <= 3:
112+
raise ValueError("Scale must be a number between 1 and 3.")
113+
if "scale" in params and not ("width" in params or "height" in params):
114+
raise ValueError("Scale is only allowed when width or height is set.")
115+
116+
if "bgColor" in params and not isinstance(params["bgColor"], str):
117+
raise ValueError("Background color must be a string.")
118+
119+
# PDF specific parameters
120+
if params["format"] == "pdf":
121+
if "fit" in params and not isinstance(params["fit"], bool):
122+
raise ValueError("Fit must be a boolean.")
123+
if "paper" in params and not isinstance(params["paper"], str):
124+
raise ValueError("Paper size must be a string (e.g., 'a4', 'a3').")
125+
if "landscape" in params and not isinstance(params["landscape"], bool):
126+
raise ValueError("Landscape must be a boolean.")
127+
if "fit" in params and ("paper" in params or "landscape" in params):
128+
logger.warning("`fit` overrides `paper` and `landscape` for PDFs. Ignoring `paper` and `landscape`.")
129+
130+
131+
def _to_mermaid_image(
132+
graph: networkx.MultiDiGraph, server_url: str = "https://mermaid.ink", params: Optional[dict] = None
133+
) -> bytes:
134+
"""
135+
Renders a pipeline using a Mermaid server.
136+
137+
:param graph:
138+
The graph to render as a Mermaid pipeline.
139+
:param server_url:
140+
Base URL of the Mermaid server (default: 'https://mermaid.ink').
141+
:param params:
142+
Dictionary of customization parameters. See `validate_mermaid_params` for valid keys.
143+
:returns:
144+
The image, SVG, or PDF data returned by the Mermaid server as bytes.
145+
:raises ValueError:
146+
If any parameter is invalid or does not match the expected format.
147+
:raises PipelineDrawingError:
148+
If there is an issue connecting to the Mermaid server or the server returns an error.
70149
"""
150+
151+
if params is None:
152+
params = {}
153+
154+
_validate_mermaid_params(params)
155+
156+
theme = params.get("theme")
157+
init_params = json.dumps({"theme": theme})
158+
71159
# Copy the graph to avoid modifying the original
72-
graph_styled = _to_mermaid_text(graph.copy())
160+
graph_styled = _to_mermaid_text(graph.copy(), init_params)
73161
json_string = json.dumps({"code": graph_styled})
74162

75-
# Uses the DEFLATE algorithm at the highest level for smallest size
76-
compressor = zlib.compressobj(level=9)
163+
# Compress the JSON string with zlib (RFC 1950)
164+
compressor = zlib.compressobj(level=9, wbits=15)
77165
compressed_data = compressor.compress(json_string.encode("utf-8")) + compressor.flush()
78166
compressed_url_safe_base64 = base64.urlsafe_b64encode(compressed_data).decode("utf-8").strip()
79167

80-
url = f"https://mermaid.ink/img/pako:{compressed_url_safe_base64}?type=png"
168+
# Determine the correct endpoint
169+
endpoint_format = params.get("format", "img") # Default to /img endpoint
170+
if endpoint_format not in {"img", "svg", "pdf"}:
171+
raise ValueError(f"Invalid format: {endpoint_format}. Valid options are 'img', 'svg', or 'pdf'.")
172+
173+
# Construct the URL without query parameters
174+
url = f"{server_url}/{endpoint_format}/pako:{compressed_url_safe_base64}"
175+
176+
# Add query parameters adhering to mermaid.ink documentation
177+
query_params = []
178+
for key, value in params.items():
179+
if key not in {"theme", "format"}: # Exclude theme (handled in init_params) and format (endpoint-specific)
180+
if value is True:
181+
query_params.append(f"{key}")
182+
else:
183+
query_params.append(f"{key}={value}")
184+
185+
if query_params:
186+
url += "?" + "&".join(query_params)
81187

82188
logger.debug("Rendering graph at {url}", url=url)
83189
try:
84190
resp = requests.get(url, timeout=10)
85191
if resp.status_code >= 400:
86192
logger.warning(
87-
"Failed to draw the pipeline: https://mermaid.ink/img/ returned status {status_code}",
193+
"Failed to draw the pipeline: {server_url} returned status {status_code}",
194+
server_url=server_url,
88195
status_code=resp.status_code,
89196
)
90197
logger.info("Exact URL requested: {url}", url=url)
@@ -93,18 +200,16 @@ def _to_mermaid_image(graph: networkx.MultiDiGraph):
93200

94201
except Exception as exc: # pylint: disable=broad-except
95202
logger.warning(
96-
"Failed to draw the pipeline: could not connect to https://mermaid.ink/img/ ({error})", error=exc
203+
"Failed to draw the pipeline: could not connect to {server_url} ({error})", server_url=server_url, error=exc
97204
)
98205
logger.info("Exact URL requested: {url}", url=url)
99206
logger.warning("No pipeline diagram will be saved.")
100-
raise PipelineDrawingError(
101-
"There was an issue with https://mermaid.ink/, see the stacktrace for details."
102-
) from exc
207+
raise PipelineDrawingError(f"There was an issue with {server_url}, see the stacktrace for details.") from exc
103208

104209
return resp.content
105210

106211

107-
def _to_mermaid_text(graph: networkx.MultiDiGraph) -> str:
212+
def _to_mermaid_text(graph: networkx.MultiDiGraph, init_params: str) -> str:
108213
"""
109214
Converts a Networkx graph into Mermaid syntax.
110215
@@ -153,7 +258,7 @@ def _to_mermaid_text(graph: networkx.MultiDiGraph) -> str:
153258
]
154259
connections = "\n".join(connections_list + input_connections + output_connections)
155260

156-
graph_styled = MERMAID_STYLED_TEMPLATE.format(connections=connections)
261+
graph_styled = MERMAID_STYLED_TEMPLATE.format(params=init_params, connections=connections)
157262
logger.debug("Mermaid diagram:\n{diagram}", diagram=graph_styled)
158263

159264
return graph_styled
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,5 @@
1+
---
2+
3+
features:
4+
- |
5+
Drawing pipelines, i.e.: calls to draw() or show(), can now be done using a custom Mermaid server and additional parameters. This allows for more flexibility in how pipelines are rendered. See Mermaid.ink's [documentation](https://github.com/jihchi/mermaid.ink) for more information on how to set up a custom server.

0 commit comments

Comments
 (0)