Skip to content

Commit 9e2aa45

Browse files
authored
Implement coloring modes (#175)
1 parent 55dac06 commit 9e2aa45

File tree

4 files changed

+72
-7
lines changed

4 files changed

+72
-7
lines changed

lambeq/backend/drawing/drawing.py

Lines changed: 27 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -36,10 +36,11 @@
3636
DrawablePregroup,
3737
LEDGE,
3838
WireEndpointType)
39-
from lambeq.backend.drawing.drawing_backend import (DEFAULT_ASPECT,
39+
from lambeq.backend.drawing.drawing_backend import (ColoringMode,
40+
DEFAULT_ASPECT,
4041
DEFAULT_MARGINS,
4142
DrawingBackend,
42-
FRAME_COLORS_GENERATOR)
43+
FRAME_COLORS)
4344
from lambeq.backend.drawing.helpers import drawn_as_spider, needs_asymmetry
4445
from lambeq.backend.drawing.mat_backend import MatBackend
4546
from lambeq.backend.drawing.text_printer import PregroupTextPrinter
@@ -106,6 +107,9 @@ def draw(diagram: Diagram, **params) -> None:
106107
params['color_boxes'] = params.get(
107108
'color_boxes', diagram.has_frames,
108109
)
110+
params['coloring_mode'] = params.get(
111+
'coloring_mode', ColoringMode.TYPE,
112+
)
109113
if drawable is None:
110114
drawable = drawable_cls.from_diagram(diagram,
111115
params.get('foliated', False))
@@ -422,11 +426,9 @@ def _draw_box(backend: DrawingBackend,
422426
else:
423427
points[2][0] += asymmetry
424428

425-
color = 'white'
426-
if (params['color_boxes']
427-
and isinstance(drawable_diagram, DrawableDiagramWithFrames)
428-
and hasattr(box, 'name') and box.name):
429-
color = next(FRAME_COLORS_GENERATOR)
429+
color = _get_box_color(box,
430+
color_boxes=params['color_boxes'],
431+
coloring_mode=params['coloring_mode'])
430432
backend.draw_polygon(*points, color=color)
431433

432434
if params.get('draw_box_labels', True) and hasattr(box, 'name'):
@@ -440,6 +442,24 @@ def _draw_box(backend: DrawingBackend,
440442
return backend
441443

442444

445+
def _get_box_color(box: grammar.Diagrammable,
446+
color_boxes: bool = False,
447+
coloring_mode: ColoringMode = ColoringMode.TYPE):
448+
color = 'white'
449+
if color_boxes:
450+
if hasattr(box, 'name'):
451+
color = 'gray'
452+
453+
if isinstance(box, grammar.Frame) and hasattr(box, 'name'):
454+
frame_attr = getattr(box, f'frame_{coloring_mode}')
455+
if coloring_mode == ColoringMode.TYPE:
456+
frame_attr += (len(FRAME_COLORS) // 7) * (box.frame_order - 1)
457+
458+
color = FRAME_COLORS[(frame_attr - 1) % len(FRAME_COLORS)]
459+
460+
return color
461+
462+
443463
def _draw_pregroup_state(backend: DrawingBackend,
444464
drawable_box: BoxNode,
445465
**params) -> DrawingBackend:

lambeq/backend/drawing/drawing_backend.py

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@
2121
from __future__ import annotations
2222

2323
from abc import ABC, abstractmethod
24+
from enum import Enum
2425
import itertools
2526

2627
from lambeq.backend.drawing.drawable import DrawableDiagram
@@ -48,6 +49,7 @@
4849
'blue': '#776ff3',
4950
'yellow': '#f7f700',
5051
'black': '#000000',
52+
'gray': '#e0e0e0'
5153
}
5254
for color in FRAME_COLORS:
5355
COLORS[color] = color
@@ -62,6 +64,26 @@
6264
}
6365

6466

67+
class ColoringMode(str, Enum):
68+
"""An enumeration for the coloring modes when coloring is used.
69+
70+
Frames can be colored by:
71+
72+
.. glossary::
73+
74+
TYPE
75+
The number of holes in the frame
76+
77+
ORDER
78+
The level of nesting of the frame, increasing from
79+
the inside going outward.
80+
81+
"""
82+
83+
TYPE = 'type'
84+
ORDER = 'order'
85+
86+
6587
class DrawingBackend(ABC):
6688
""" Abstract drawing backend. """
6789

lambeq/backend/grammar.py

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2072,6 +2072,19 @@ def dagger(self) -> DaggeredFrame | Frame:
20722072
def __hash__(self) -> int:
20732073
return hash(repr(self))
20742074

2075+
@property
2076+
def frame_type(self):
2077+
"""The number of holes in the frame."""
2078+
return len(self.components)
2079+
2080+
@property
2081+
def frame_order(self):
2082+
"""The level of nesting in the frame increasing from the inside
2083+
going outward."""
2084+
component_frame_orders = [c.frame_order if isinstance(c, Frame) else 0
2085+
for c in self.components]
2086+
return max(component_frame_orders) + 1
2087+
20752088
@classmethod
20762089
def from_json(cls, data: _JSONDictT | str) -> Self:
20772090
"""Decode a JSON object or string into a

tests/backend/test_grammar.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -465,6 +465,16 @@ def test_frame():
465465
)
466466
assert f.name == 'f1'
467467
assert len(f.components) == 4
468+
assert f.frame_type == 4
469+
assert f.frame_order == 1
470+
471+
f2 = Frame('f2', n, n, components=[f, f])
472+
f3 = Frame('f3', n @ n, n @ n, components=[f2])
473+
474+
assert f2.frame_type == 2
475+
assert f2.frame_order == 2
476+
assert f3.frame_type == 1
477+
assert f3.frame_order == 3
468478

469479

470480
def test_diagram_has_frame():

0 commit comments

Comments
 (0)