From 9e2aa4595b351bea780674f6aa6ea0b753a9e8b3 Mon Sep 17 00:00:00 2001 From: Neil Ortega Date: Thu, 7 Nov 2024 02:09:04 +0900 Subject: [PATCH] Implement coloring modes (#175) --- lambeq/backend/drawing/drawing.py | 34 ++++++++++++++++++----- lambeq/backend/drawing/drawing_backend.py | 22 +++++++++++++++ lambeq/backend/grammar.py | 13 +++++++++ tests/backend/test_grammar.py | 10 +++++++ 4 files changed, 72 insertions(+), 7 deletions(-) diff --git a/lambeq/backend/drawing/drawing.py b/lambeq/backend/drawing/drawing.py index 4dbd0f23..bf85b75b 100644 --- a/lambeq/backend/drawing/drawing.py +++ b/lambeq/backend/drawing/drawing.py @@ -36,10 +36,11 @@ DrawablePregroup, LEDGE, WireEndpointType) -from lambeq.backend.drawing.drawing_backend import (DEFAULT_ASPECT, +from lambeq.backend.drawing.drawing_backend import (ColoringMode, + DEFAULT_ASPECT, DEFAULT_MARGINS, DrawingBackend, - FRAME_COLORS_GENERATOR) + FRAME_COLORS) from lambeq.backend.drawing.helpers import drawn_as_spider, needs_asymmetry from lambeq.backend.drawing.mat_backend import MatBackend from lambeq.backend.drawing.text_printer import PregroupTextPrinter @@ -106,6 +107,9 @@ def draw(diagram: Diagram, **params) -> None: params['color_boxes'] = params.get( 'color_boxes', diagram.has_frames, ) + params['coloring_mode'] = params.get( + 'coloring_mode', ColoringMode.TYPE, + ) if drawable is None: drawable = drawable_cls.from_diagram(diagram, params.get('foliated', False)) @@ -422,11 +426,9 @@ def _draw_box(backend: DrawingBackend, else: points[2][0] += asymmetry - color = 'white' - if (params['color_boxes'] - and isinstance(drawable_diagram, DrawableDiagramWithFrames) - and hasattr(box, 'name') and box.name): - color = next(FRAME_COLORS_GENERATOR) + color = _get_box_color(box, + color_boxes=params['color_boxes'], + coloring_mode=params['coloring_mode']) backend.draw_polygon(*points, color=color) if params.get('draw_box_labels', True) and hasattr(box, 'name'): @@ -440,6 +442,24 @@ def _draw_box(backend: DrawingBackend, return backend +def _get_box_color(box: grammar.Diagrammable, + color_boxes: bool = False, + coloring_mode: ColoringMode = ColoringMode.TYPE): + color = 'white' + if color_boxes: + if hasattr(box, 'name'): + color = 'gray' + + if isinstance(box, grammar.Frame) and hasattr(box, 'name'): + frame_attr = getattr(box, f'frame_{coloring_mode}') + if coloring_mode == ColoringMode.TYPE: + frame_attr += (len(FRAME_COLORS) // 7) * (box.frame_order - 1) + + color = FRAME_COLORS[(frame_attr - 1) % len(FRAME_COLORS)] + + return color + + def _draw_pregroup_state(backend: DrawingBackend, drawable_box: BoxNode, **params) -> DrawingBackend: diff --git a/lambeq/backend/drawing/drawing_backend.py b/lambeq/backend/drawing/drawing_backend.py index f8acf750..c99b24f3 100644 --- a/lambeq/backend/drawing/drawing_backend.py +++ b/lambeq/backend/drawing/drawing_backend.py @@ -21,6 +21,7 @@ from __future__ import annotations from abc import ABC, abstractmethod +from enum import Enum import itertools from lambeq.backend.drawing.drawable import DrawableDiagram @@ -48,6 +49,7 @@ 'blue': '#776ff3', 'yellow': '#f7f700', 'black': '#000000', + 'gray': '#e0e0e0' } for color in FRAME_COLORS: COLORS[color] = color @@ -62,6 +64,26 @@ } +class ColoringMode(str, Enum): + """An enumeration for the coloring modes when coloring is used. + + Frames can be colored by: + + .. glossary:: + + TYPE + The number of holes in the frame + + ORDER + The level of nesting of the frame, increasing from + the inside going outward. + + """ + + TYPE = 'type' + ORDER = 'order' + + class DrawingBackend(ABC): """ Abstract drawing backend. """ diff --git a/lambeq/backend/grammar.py b/lambeq/backend/grammar.py index b621c31f..35a2aec5 100644 --- a/lambeq/backend/grammar.py +++ b/lambeq/backend/grammar.py @@ -2072,6 +2072,19 @@ def dagger(self) -> DaggeredFrame | Frame: def __hash__(self) -> int: return hash(repr(self)) + @property + def frame_type(self): + """The number of holes in the frame.""" + return len(self.components) + + @property + def frame_order(self): + """The level of nesting in the frame increasing from the inside + going outward.""" + component_frame_orders = [c.frame_order if isinstance(c, Frame) else 0 + for c in self.components] + return max(component_frame_orders) + 1 + @classmethod def from_json(cls, data: _JSONDictT | str) -> Self: """Decode a JSON object or string into a diff --git a/tests/backend/test_grammar.py b/tests/backend/test_grammar.py index c5f7902f..21520053 100644 --- a/tests/backend/test_grammar.py +++ b/tests/backend/test_grammar.py @@ -465,6 +465,16 @@ def test_frame(): ) assert f.name == 'f1' assert len(f.components) == 4 + assert f.frame_type == 4 + assert f.frame_order == 1 + + f2 = Frame('f2', n, n, components=[f, f]) + f3 = Frame('f3', n @ n, n @ n, components=[f2]) + + assert f2.frame_type == 2 + assert f2.frame_order == 2 + assert f3.frame_type == 1 + assert f3.frame_order == 3 def test_diagram_has_frame():