3636 DrawablePregroup ,
3737 LEDGE ,
3838 WireEndpointType )
39- from lambeq .backend .drawing .drawing_backend import (DEFAULT_ASPECT ,
39+ from lambeq .backend .drawing .drawing_backend import (ColoringMode ,
40+ DEFAULT_ASPECT ,
4041 DEFAULT_MARGINS ,
4142 DrawingBackend ,
42- FRAME_COLORS_GENERATOR )
43+ FRAME_COLORS )
4344from lambeq .backend .drawing .helpers import drawn_as_spider , needs_asymmetry
4445from lambeq .backend .drawing .mat_backend import MatBackend
4546from lambeq .backend .drawing .text_printer import PregroupTextPrinter
@@ -106,6 +107,9 @@ def draw(diagram: Diagram, **params) -> None:
106107 params ['color_boxes' ] = params .get (
107108 'color_boxes' , diagram .has_frames ,
108109 )
110+ params ['coloring_mode' ] = params .get (
111+ 'coloring_mode' , ColoringMode .TYPE ,
112+ )
109113 if drawable is None :
110114 drawable = drawable_cls .from_diagram (diagram ,
111115 params .get ('foliated' , False ))
@@ -422,11 +426,9 @@ def _draw_box(backend: DrawingBackend,
422426 else :
423427 points [2 ][0 ] += asymmetry
424428
425- color = 'white'
426- if (params ['color_boxes' ]
427- and isinstance (drawable_diagram , DrawableDiagramWithFrames )
428- and hasattr (box , 'name' ) and box .name ):
429- color = next (FRAME_COLORS_GENERATOR )
429+ color = _get_box_color (box ,
430+ color_boxes = params ['color_boxes' ],
431+ coloring_mode = params ['coloring_mode' ])
430432 backend .draw_polygon (* points , color = color )
431433
432434 if params .get ('draw_box_labels' , True ) and hasattr (box , 'name' ):
@@ -440,6 +442,24 @@ def _draw_box(backend: DrawingBackend,
440442 return backend
441443
442444
445+ def _get_box_color (box : grammar .Diagrammable ,
446+ color_boxes : bool = False ,
447+ coloring_mode : ColoringMode = ColoringMode .TYPE ):
448+ color = 'white'
449+ if color_boxes :
450+ if hasattr (box , 'name' ):
451+ color = 'gray'
452+
453+ if isinstance (box , grammar .Frame ) and hasattr (box , 'name' ):
454+ frame_attr = getattr (box , f'frame_{ coloring_mode } ' )
455+ if coloring_mode == ColoringMode .TYPE :
456+ frame_attr += (len (FRAME_COLORS ) // 7 ) * (box .frame_order - 1 )
457+
458+ color = FRAME_COLORS [(frame_attr - 1 ) % len (FRAME_COLORS )]
459+
460+ return color
461+
462+
443463def _draw_pregroup_state (backend : DrawingBackend ,
444464 drawable_box : BoxNode ,
445465 ** params ) -> DrawingBackend :
0 commit comments