Skip to content

Commit

Permalink
Implement coloring modes (#175)
Browse files Browse the repository at this point in the history
  • Loading branch information
neiljdo authored Nov 6, 2024
1 parent 55dac06 commit 9e2aa45
Show file tree
Hide file tree
Showing 4 changed files with 72 additions and 7 deletions.
34 changes: 27 additions & 7 deletions lambeq/backend/drawing/drawing.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,10 +36,11 @@
DrawablePregroup,
LEDGE,
WireEndpointType)
from lambeq.backend.drawing.drawing_backend import (DEFAULT_ASPECT,
from lambeq.backend.drawing.drawing_backend import (ColoringMode,
DEFAULT_ASPECT,
DEFAULT_MARGINS,
DrawingBackend,
FRAME_COLORS_GENERATOR)
FRAME_COLORS)
from lambeq.backend.drawing.helpers import drawn_as_spider, needs_asymmetry
from lambeq.backend.drawing.mat_backend import MatBackend
from lambeq.backend.drawing.text_printer import PregroupTextPrinter
Expand Down Expand Up @@ -106,6 +107,9 @@ def draw(diagram: Diagram, **params) -> None:
params['color_boxes'] = params.get(
'color_boxes', diagram.has_frames,
)
params['coloring_mode'] = params.get(
'coloring_mode', ColoringMode.TYPE,
)
if drawable is None:
drawable = drawable_cls.from_diagram(diagram,
params.get('foliated', False))
Expand Down Expand Up @@ -422,11 +426,9 @@ def _draw_box(backend: DrawingBackend,
else:
points[2][0] += asymmetry

color = 'white'
if (params['color_boxes']
and isinstance(drawable_diagram, DrawableDiagramWithFrames)
and hasattr(box, 'name') and box.name):
color = next(FRAME_COLORS_GENERATOR)
color = _get_box_color(box,
color_boxes=params['color_boxes'],
coloring_mode=params['coloring_mode'])
backend.draw_polygon(*points, color=color)

if params.get('draw_box_labels', True) and hasattr(box, 'name'):
Expand All @@ -440,6 +442,24 @@ def _draw_box(backend: DrawingBackend,
return backend


def _get_box_color(box: grammar.Diagrammable,
color_boxes: bool = False,
coloring_mode: ColoringMode = ColoringMode.TYPE):
color = 'white'
if color_boxes:
if hasattr(box, 'name'):
color = 'gray'

if isinstance(box, grammar.Frame) and hasattr(box, 'name'):
frame_attr = getattr(box, f'frame_{coloring_mode}')
if coloring_mode == ColoringMode.TYPE:
frame_attr += (len(FRAME_COLORS) // 7) * (box.frame_order - 1)

color = FRAME_COLORS[(frame_attr - 1) % len(FRAME_COLORS)]

return color


def _draw_pregroup_state(backend: DrawingBackend,
drawable_box: BoxNode,
**params) -> DrawingBackend:
Expand Down
22 changes: 22 additions & 0 deletions lambeq/backend/drawing/drawing_backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
from __future__ import annotations

from abc import ABC, abstractmethod
from enum import Enum
import itertools

from lambeq.backend.drawing.drawable import DrawableDiagram
Expand Down Expand Up @@ -48,6 +49,7 @@
'blue': '#776ff3',
'yellow': '#f7f700',
'black': '#000000',
'gray': '#e0e0e0'
}
for color in FRAME_COLORS:
COLORS[color] = color
Expand All @@ -62,6 +64,26 @@
}


class ColoringMode(str, Enum):
"""An enumeration for the coloring modes when coloring is used.
Frames can be colored by:
.. glossary::
TYPE
The number of holes in the frame
ORDER
The level of nesting of the frame, increasing from
the inside going outward.
"""

TYPE = 'type'
ORDER = 'order'


class DrawingBackend(ABC):
""" Abstract drawing backend. """

Expand Down
13 changes: 13 additions & 0 deletions lambeq/backend/grammar.py
Original file line number Diff line number Diff line change
Expand Up @@ -2072,6 +2072,19 @@ def dagger(self) -> DaggeredFrame | Frame:
def __hash__(self) -> int:
return hash(repr(self))

@property
def frame_type(self):
"""The number of holes in the frame."""
return len(self.components)

@property
def frame_order(self):
"""The level of nesting in the frame increasing from the inside
going outward."""
component_frame_orders = [c.frame_order if isinstance(c, Frame) else 0
for c in self.components]
return max(component_frame_orders) + 1

@classmethod
def from_json(cls, data: _JSONDictT | str) -> Self:
"""Decode a JSON object or string into a
Expand Down
10 changes: 10 additions & 0 deletions tests/backend/test_grammar.py
Original file line number Diff line number Diff line change
Expand Up @@ -465,6 +465,16 @@ def test_frame():
)
assert f.name == 'f1'
assert len(f.components) == 4
assert f.frame_type == 4
assert f.frame_order == 1

f2 = Frame('f2', n, n, components=[f, f])
f3 = Frame('f3', n @ n, n @ n, components=[f2])

assert f2.frame_type == 2
assert f2.frame_order == 2
assert f3.frame_type == 1
assert f3.frame_order == 3


def test_diagram_has_frame():
Expand Down

0 comments on commit 9e2aa45

Please sign in to comment.