55import base64
66import json
77import zlib
8+ from typing import Any , Dict , Optional
89
910import networkx # type:ignore
1011import requests
@@ -54,7 +55,7 @@ def _prepare_for_drawing(graph: networkx.MultiDiGraph) -> networkx.MultiDiGraph:
5455ARROWHEAD_MANDATORY = "-->"
5556ARROWHEAD_OPTIONAL = ".->"
5657MERMAID_STYLED_TEMPLATE = """
57- %%{{ init: {{'theme': 'neutral' } } }}%%
58+ %%{{ init: {params } }}%%
5859
5960graph TD;
6061
@@ -64,27 +65,133 @@ def _prepare_for_drawing(graph: networkx.MultiDiGraph) -> networkx.MultiDiGraph:
6465"""
6566
6667
67- def _to_mermaid_image ( graph : networkx . MultiDiGraph ) :
68+ def _validate_mermaid_params ( params : Dict [ str , Any ]) -> None :
6869 """
69- Renders a pipeline using Mermaid (hosted version at 'https://mermaid.ink'). Requires Internet access.
70+ Validates and sets default values for Mermaid parameters.
71+
72+ :param params:
73+ Dictionary of customization parameters to modify the output. Refer to Mermaid documentation for more details.
74+ Supported keys:
75+ - format: Output format ('img', 'svg', or 'pdf'). Default: 'img'.
76+ - type: Image type for /img endpoint ('jpeg', 'png', 'webp'). Default: 'png'.
77+ - theme: Mermaid theme ('default', 'neutral', 'dark', 'forest'). Default: 'neutral'.
78+ - bgColor: Background color in hexadecimal (e.g., 'FFFFFF') or named format (e.g., '!white').
79+ - width: Width of the output image (integer).
80+ - height: Height of the output image (integer).
81+ - scale: Scaling factor (1–3). Only applicable if 'width' or 'height' is specified.
82+ - fit: Whether to fit the diagram size to the page (PDF only, boolean).
83+ - paper: Paper size for PDFs (e.g., 'a4', 'a3'). Ignored if 'fit' is true.
84+ - landscape: Landscape orientation for PDFs (boolean). Ignored if 'fit' is true.
85+
86+ :raises ValueError:
87+ If any parameter is invalid or does not match the expected format.
88+ """
89+ valid_img_types = {"jpeg" , "png" , "webp" }
90+ valid_themes = {"default" , "neutral" , "dark" , "forest" }
91+ valid_formats = {"img" , "svg" , "pdf" }
92+
93+ params .setdefault ("format" , "img" )
94+ params .setdefault ("type" , "png" )
95+ params .setdefault ("theme" , "neutral" )
96+
97+ if params ["format" ] not in valid_formats :
98+ raise ValueError (f"Invalid image format: { params ['format' ]} . Valid options are: { valid_formats } ." )
99+
100+ if params ["format" ] == "img" and params ["type" ] not in valid_img_types :
101+ raise ValueError (f"Invalid image type: { params ['type' ]} . Valid options are: { valid_img_types } ." )
102+
103+ if params ["theme" ] not in valid_themes :
104+ raise ValueError (f"Invalid theme: { params ['theme' ]} . Valid options are: { valid_themes } ." )
105+
106+ if "width" in params and not isinstance (params ["width" ], int ):
107+ raise ValueError ("Width must be an integer." )
108+ if "height" in params and not isinstance (params ["height" ], int ):
109+ raise ValueError ("Height must be an integer." )
110+
111+ if "scale" in params and not 1 <= params ["scale" ] <= 3 :
112+ raise ValueError ("Scale must be a number between 1 and 3." )
113+ if "scale" in params and not ("width" in params or "height" in params ):
114+ raise ValueError ("Scale is only allowed when width or height is set." )
115+
116+ if "bgColor" in params and not isinstance (params ["bgColor" ], str ):
117+ raise ValueError ("Background color must be a string." )
118+
119+ # PDF specific parameters
120+ if params ["format" ] == "pdf" :
121+ if "fit" in params and not isinstance (params ["fit" ], bool ):
122+ raise ValueError ("Fit must be a boolean." )
123+ if "paper" in params and not isinstance (params ["paper" ], str ):
124+ raise ValueError ("Paper size must be a string (e.g., 'a4', 'a3')." )
125+ if "landscape" in params and not isinstance (params ["landscape" ], bool ):
126+ raise ValueError ("Landscape must be a boolean." )
127+ if "fit" in params and ("paper" in params or "landscape" in params ):
128+ logger .warning ("`fit` overrides `paper` and `landscape` for PDFs. Ignoring `paper` and `landscape`." )
129+
130+
131+ def _to_mermaid_image (
132+ graph : networkx .MultiDiGraph , server_url : str = "https://mermaid.ink" , params : Optional [dict ] = None
133+ ) -> bytes :
134+ """
135+ Renders a pipeline using a Mermaid server.
136+
137+ :param graph:
138+ The graph to render as a Mermaid pipeline.
139+ :param server_url:
140+ Base URL of the Mermaid server (default: 'https://mermaid.ink').
141+ :param params:
142+ Dictionary of customization parameters. See `validate_mermaid_params` for valid keys.
143+ :returns:
144+ The image, SVG, or PDF data returned by the Mermaid server as bytes.
145+ :raises ValueError:
146+ If any parameter is invalid or does not match the expected format.
147+ :raises PipelineDrawingError:
148+ If there is an issue connecting to the Mermaid server or the server returns an error.
70149 """
150+
151+ if params is None :
152+ params = {}
153+
154+ _validate_mermaid_params (params )
155+
156+ theme = params .get ("theme" )
157+ init_params = json .dumps ({"theme" : theme })
158+
71159 # Copy the graph to avoid modifying the original
72- graph_styled = _to_mermaid_text (graph .copy ())
160+ graph_styled = _to_mermaid_text (graph .copy (), init_params )
73161 json_string = json .dumps ({"code" : graph_styled })
74162
75- # Uses the DEFLATE algorithm at the highest level for smallest size
76- compressor = zlib .compressobj (level = 9 )
163+ # Compress the JSON string with zlib (RFC 1950)
164+ compressor = zlib .compressobj (level = 9 , wbits = 15 )
77165 compressed_data = compressor .compress (json_string .encode ("utf-8" )) + compressor .flush ()
78166 compressed_url_safe_base64 = base64 .urlsafe_b64encode (compressed_data ).decode ("utf-8" ).strip ()
79167
80- url = f"https://mermaid.ink/img/pako:{ compressed_url_safe_base64 } ?type=png"
168+ # Determine the correct endpoint
169+ endpoint_format = params .get ("format" , "img" ) # Default to /img endpoint
170+ if endpoint_format not in {"img" , "svg" , "pdf" }:
171+ raise ValueError (f"Invalid format: { endpoint_format } . Valid options are 'img', 'svg', or 'pdf'." )
172+
173+ # Construct the URL without query parameters
174+ url = f"{ server_url } /{ endpoint_format } /pako:{ compressed_url_safe_base64 } "
175+
176+ # Add query parameters adhering to mermaid.ink documentation
177+ query_params = []
178+ for key , value in params .items ():
179+ if key not in {"theme" , "format" }: # Exclude theme (handled in init_params) and format (endpoint-specific)
180+ if value is True :
181+ query_params .append (f"{ key } " )
182+ else :
183+ query_params .append (f"{ key } ={ value } " )
184+
185+ if query_params :
186+ url += "?" + "&" .join (query_params )
81187
82188 logger .debug ("Rendering graph at {url}" , url = url )
83189 try :
84190 resp = requests .get (url , timeout = 10 )
85191 if resp .status_code >= 400 :
86192 logger .warning (
87- "Failed to draw the pipeline: https://mermaid.ink/img/ returned status {status_code}" ,
193+ "Failed to draw the pipeline: {server_url} returned status {status_code}" ,
194+ server_url = server_url ,
88195 status_code = resp .status_code ,
89196 )
90197 logger .info ("Exact URL requested: {url}" , url = url )
@@ -93,18 +200,16 @@ def _to_mermaid_image(graph: networkx.MultiDiGraph):
93200
94201 except Exception as exc : # pylint: disable=broad-except
95202 logger .warning (
96- "Failed to draw the pipeline: could not connect to https://mermaid.ink/img/ ({error})" , error = exc
203+ "Failed to draw the pipeline: could not connect to {server_url} ({error})" , server_url = server_url , error = exc
97204 )
98205 logger .info ("Exact URL requested: {url}" , url = url )
99206 logger .warning ("No pipeline diagram will be saved." )
100- raise PipelineDrawingError (
101- "There was an issue with https://mermaid.ink/, see the stacktrace for details."
102- ) from exc
207+ raise PipelineDrawingError (f"There was an issue with { server_url } , see the stacktrace for details." ) from exc
103208
104209 return resp .content
105210
106211
107- def _to_mermaid_text (graph : networkx .MultiDiGraph ) -> str :
212+ def _to_mermaid_text (graph : networkx .MultiDiGraph , init_params : str ) -> str :
108213 """
109214 Converts a Networkx graph into Mermaid syntax.
110215
@@ -153,7 +258,7 @@ def _to_mermaid_text(graph: networkx.MultiDiGraph) -> str:
153258 ]
154259 connections = "\n " .join (connections_list + input_connections + output_connections )
155260
156- graph_styled = MERMAID_STYLED_TEMPLATE .format (connections = connections )
261+ graph_styled = MERMAID_STYLED_TEMPLATE .format (params = init_params , connections = connections )
157262 logger .debug ("Mermaid diagram:\n {diagram}" , diagram = graph_styled )
158263
159264 return graph_styled
0 commit comments