Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add option to colour wires in diagrams with frames #177

Merged
merged 44 commits into from
Nov 22, 2024
Merged
Show file tree
Hide file tree
Changes from 24 commits
Commits
Show all changes
44 commits
Select commit Hold shift + click to select a range
14b8f85
Color wires changes. wires are coloured used by default when the diag…
Ragunath1729 Nov 7, 2024
ba2230b
Merge remote-tracking branch 'refs/remotes/public-lambeq/main' into b…
Ragunath1729 Nov 7, 2024
0b3f275
review comments fixes, remove colors looks like similar
Ragunath1729 Nov 8, 2024
84b4456
rearrange colors to avoid similar shades comes together
Ragunath1729 Nov 8, 2024
8b963d7
lint error fixes, flake8 fixes
Ragunath1729 Nov 8, 2024
885c9cb
flake8 fixes
Ragunath1729 Nov 8, 2024
72b635f
inherit add_box to avoid side-effect in non-frame diagrams
Ragunath1729 Nov 11, 2024
9dbcc40
remove debug logs
Ragunath1729 Nov 11, 2024
3a8f66d
add color params to draw method in Tikzit backend
Ragunath1729 Nov 11, 2024
3ef64a5
lint error fixes
Ragunath1729 Nov 11, 2024
1a1f2a0
mypy suggestion fixes
Ragunath1729 Nov 11, 2024
94f6dbd
demo on simple-book notebook
Ragunath1729 Nov 11, 2024
d89ec0c
change signature to fix flake8 error
Ragunath1729 Nov 11, 2024
8093b95
add doc strings to new functions
Ragunath1729 Nov 11, 2024
69d9830
make noun_id non-static variable
Ragunath1729 Nov 12, 2024
3af9164
remove different colours and remove colours with similar shades
Ragunath1729 Nov 12, 2024
6bf61c8
remove yellow shades from colors list
Ragunath1729 Nov 13, 2024
222376a
remove optional for input_nouns
Ragunath1729 Nov 13, 2024
d1a7100
different linewidth for box and wires, get from user using params
Ragunath1729 Nov 13, 2024
8da64d7
remove invalid characters from stylenames which throw render errors
Ragunath1729 Nov 14, 2024
3c83b83
color wire changes in Tikz_backend
Ragunath1729 Nov 14, 2024
e8c04f3
remove extra {}
Ragunath1729 Nov 15, 2024
7caef28
static commands for tex file, new layer for labels
Ragunath1729 Nov 15, 2024
3d9be55
flake8 suggestion to remove quotes
Ragunath1729 Nov 15, 2024
57fe2d9
remove extra {}
Ragunath1729 Nov 18, 2024
edac6cd
delete sample file
Ragunath1729 Nov 18, 2024
353df2f
cosmetic readability fixes
Ragunath1729 Nov 19, 2024
8d43b79
refactor wire colour formatting like boxes
Ragunath1729 Nov 19, 2024
78fbfee
wires width changes for Tikz backend
Ragunath1729 Nov 19, 2024
ba8dc7f
fix for flake8 import order error
Ragunath1729 Nov 19, 2024
ae69857
update testcases for frame drawing with wire colouring
Ragunath1729 Nov 20, 2024
1f11612
Merge branch 'main' into br_color_wires
Ragunath1729 Nov 20, 2024
c829e4f
change label flag as hidden by default
Ragunath1729 Nov 20, 2024
99ed5e4
Fix code formatting
neiljdo Nov 21, 2024
5ee51b2
Update constants
neiljdo Nov 21, 2024
294cd72
Update tests for Tikz circuit
neiljdo Nov 21, 2024
02b35f3
Update tests for Tikz diagrams
neiljdo Nov 21, 2024
697a0ee
add box_linewidth to Tikz backend
Ragunath1729 Nov 21, 2024
5d18c22
update frame-drawing tests with linewidth changes
Ragunath1729 Nov 21, 2024
93a84bb
update test-drawing tests with linewidth changes
Ragunath1729 Nov 21, 2024
0f3c286
update circuit-drawing tests with linewidth changes
Ragunath1729 Nov 21, 2024
e328cd8
Merge remote-tracking branch 'ragu/br_color_wires' into br_color_wires
Ragunath1729 Nov 21, 2024
448107b
make mypy happy
Ragunath1729 Nov 21, 2024
9f3a9a8
Add `box_linewidth` as param for `MatBackend`
neiljdo Nov 22, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
158 changes: 154 additions & 4 deletions lambeq/backend/drawing/drawable.py
Original file line number Diff line number Diff line change
Expand Up @@ -95,7 +95,7 @@ class WireEndpoint:

x: float
y: float

noun_id: int = 0 # New attribute for wire noun
parent: Optional['BoxNode'] = None

@property
Expand Down Expand Up @@ -778,6 +778,8 @@ class DrawableDiagramWithFrames(DrawableDiagram):
frame, carrying all information necessary to render it.

"""
# add counter for Nouns
noun_id_counter = 1

def _make_space(self,
scan: list[int],
Expand Down Expand Up @@ -841,6 +843,120 @@ def _make_space(self,

return x, y

def _add_box_with_nouns(
self,
scan: list[int],
box: grammar.Box,
off: int,
x_pos: float,
y_pos: float,
input_nouns: list[int]
) -> tuple[list[int], int, list[int]]:
"""Add a box to the graph, creating necessary wire endpoints.

Returns
-------
list : int
The new scan of wire endpoints after adding the box
box_ind : int
The index of the newly added `BoxNode`
input_nouns : list[int]
The new order of input_nouns after adding the box
"""
node = BoxNode(box, x_pos, y_pos)

box_ind = self._add_boxnode(node)
num_input = len(box.dom)
input_nouns = input_nouns or []
for i in range(num_input):
if i < len(input_nouns):
pass
else:
# If we run out of input nouns, generate new ones
new_color = self.get_noun_id()
input_nouns.append(new_color)

# Create a node representing each element in the box's domain
for i, obj in enumerate(box.dom):
idx = off + i
nbr_idx = scan[off + i]
noun_id = (input_nouns[idx] if input_nouns
and idx < len(input_nouns)
else self.get_noun_id()
) # generate new noun id if needed

wire_end = WireEndpoint(WireEndpointType.DOM,
obj=obj,
x=self.wire_endpoints[nbr_idx].x,
y=y_pos + HALF_BOX_HEIGHT,
noun_id=noun_id)

wire_idx = self._add_wire_end(wire_end)
node.add_dom_wire(wire_idx)
self._add_wire(nbr_idx, wire_idx)

scan_insert = []
# if Swap, exchange the noun_ids
if isinstance(box, grammar.Swap):
if input_nouns and len(box.dom) > 1:
dom_idx_1 = off
dom_idx_2 = off + 1
input_nouns[dom_idx_1], input_nouns[dom_idx_2] = (
input_nouns[dom_idx_2], input_nouns[dom_idx_1])
# if Spider, expand or shrink the noun_ids based on type
elif isinstance(node.obj, grammar.Spider):
if len(box.dom) == 1 and len(box.cod) > 1:
dom_noun = (input_nouns[off] if input_nouns
and off < len(input_nouns)
else self.get_noun_id()
)
expanded_colors = [dom_noun] * len(box.cod)
input_nouns = (input_nouns[:off] + expanded_colors
+ input_nouns[off + len(box.dom):])
elif len(box.dom) > 1 and len(box.cod) == 1:
cod_noun = (input_nouns[off] if input_nouns
and off < len(input_nouns)
else self.get_noun_id()
)
input_nouns = (input_nouns[:off] + [cod_noun]
+ input_nouns[off + len(box.dom):])

num_output = off+len(box.cod)
for i in range(num_output):
if i < len(input_nouns):
pass
else:
# If we run out of input nouns, generate new ones
new_color = self.get_noun_id()
input_nouns.append(new_color)
# Create a node representing each element in the box's codomain
for i, obj in enumerate(box.cod):

# If the box is a quantum gate, retain x coordinate of wires
if box.category == quantum and len(box.dom) == len(box.cod):
nbr_idx = scan[off + i]
x = self.wire_endpoints[nbr_idx].x
else:
x = x_pos + X_SPACING * (i - len(box.cod[1:]) / 2)
y = y_pos - HALF_BOX_HEIGHT
idx = off + i
noun_id = (input_nouns[idx] if input_nouns
and idx < len(input_nouns)
else self.get_noun_id())
wire_end = WireEndpoint(WireEndpointType.COD,
obj=obj,
x=x,
y=y,
noun_id=noun_id)

wire_idx = self._add_wire_end(wire_end)
scan_insert.append(wire_idx)
node.add_cod_wire(wire_idx)

# Replace node's dom with its cod in scan
return (scan[:off] + scan_insert + scan[off + len(box.dom):],
box_ind, input_nouns)

def _make_space_for_frame(self,
scan: list[int],
off: int,
Expand Down Expand Up @@ -949,6 +1065,20 @@ def calculate_bounds(self) -> tuple[float, float, float, float]:

return min(all_xs), min(all_ys), max(all_xs), max(all_ys)

def get_noun_id(self) -> int:
"""Generate a new numerical ID for the noun box/wire.

Returns
-------
noun_id : int
returns a unique identifier for the noun wire/box
"""

# Increment and return the next available ID
noun_id = self.noun_id_counter
self.noun_id_counter += 1
return noun_id

@classmethod
def from_diagram(cls,
diagram: grammar.Diagram,
Expand Down Expand Up @@ -976,12 +1106,19 @@ def from_diagram(cls,
drawable = cls()

scan = []
# Generate unique noun_ids for input wires
num_input = len(diagram.dom)
input_nouns = []
for _ in range(num_input):
new_color = drawable.get_noun_id()
input_nouns.append(new_color)

for i, obj in enumerate(diagram.dom):
wire_end = WireEndpoint(WireEndpointType.INPUT,
obj=obj,
x=X_SPACING * i,
y=1)
y=1,
noun_id=input_nouns[i])
wire_end_idx = drawable._add_wire_end(wire_end)
scan.append(wire_end_idx)

Expand All @@ -993,7 +1130,8 @@ def from_diagram(cls,
# TODO: Debug issues with y coord
x, y = drawable._make_space(scan, box, off, foliated=foliated)

scan, box_ind = drawable._add_box(scan, box, off, x, y)
scan, box_ind, input_nouns = drawable._add_box_with_nouns(
scan, box, off, x, y, input_nouns)
box_height = BOX_HEIGHT
# Add drawables for the inside of the frame
if isinstance(box, grammar.Frame):
Expand All @@ -1004,12 +1142,23 @@ def from_diagram(cls,
max_box_half_height = max(max_box_half_height, (box_height / 2))
min_y = min(min_y, y)

num_output = len(diagram.cod)
# Match output nouns with input nouns as much as possible
for i in range(num_output):
if i < len(input_nouns):
pass
else:
# If we run out of input nouns, generate new ones
new_color = drawable.get_noun_id()
input_nouns.append(new_color)

for i, obj in enumerate(diagram.cod):
wire_end = WireEndpoint(
WireEndpointType.OUTPUT,
obj=obj,
x=drawable.wire_endpoints[scan[i]].x,
y=min_y - max_box_half_height - 1.5 * BOX_HEIGHT
y=min_y - max_box_half_height - 1.5 * BOX_HEIGHT,
noun_id=input_nouns[i]
)
wire_end_idx = drawable._add_wire_end(wire_end)
drawable._add_wire(scan[i], wire_end_idx)
Expand Down Expand Up @@ -1384,6 +1533,7 @@ def _merge_with(self, drawable: 'DrawableDiagramWithFrames') -> None:
last_wire_endpoint = len(self.wire_endpoints)

for wire_endpoint in drawable.wire_endpoints:
wire_endpoint.noun_id = 0
self.wire_endpoints.append(wire_endpoint)

for box in drawable.boxes:
Expand Down
29 changes: 24 additions & 5 deletions lambeq/backend/drawing/drawing.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,8 @@
DEFAULT_ASPECT,
DEFAULT_MARGINS,
DrawingBackend,
FRAME_COLORS)
FRAME_COLORS,
WIRE_COLORS)
from lambeq.backend.drawing.helpers import drawn_as_spider, needs_asymmetry
from lambeq.backend.drawing.mat_backend import MatBackend
from lambeq.backend.drawing.text_printer import PregroupTextPrinter
Expand Down Expand Up @@ -110,6 +111,9 @@ def draw(diagram: Diagram, **params) -> None:
params['coloring_mode'] = params.get(
'coloring_mode', ColoringMode.TYPE.value,
)
params['color_wires'] = params.get(
'color_wires', diagram.has_frames,
)
if drawable is None:
drawable = drawable_cls.from_diagram(diagram,
params.get('foliated', False))
Expand All @@ -124,7 +128,8 @@ def draw(diagram: Diagram, **params) -> None:
backend = TikzBackend(
use_tikzstyles=params.get('use_tikzstyles', None))
else:
backend = MatBackend(figsize=params.get('figsize', None))
backend = MatBackend(figsize=params.get('figsize', None),
wires_linewidth=params.get('wires_width', 1.25))

min_size = 0.01
max_v = max([v for point in ([point.coordinates for point in
Expand Down Expand Up @@ -460,6 +465,14 @@ def _get_box_color(box: grammar.Diagrammable,
return color


def _get_wire_color(wire_id):
if wire_id == 0:
return '#000000'
else:
wire_color = WIRE_COLORS[(wire_id - 1) % len(WIRE_COLORS)]
return wire_color


def _draw_pregroup_state(backend: DrawingBackend,
drawable_box: BoxNode,
**params) -> DrawingBackend:
Expand Down Expand Up @@ -524,9 +537,15 @@ def _draw_wires(backend: DrawingBackend,
for src_idx, tgt_idx in drawable_diagram.wires:
source = drawable_diagram.wire_endpoints[src_idx]
target = drawable_diagram.wire_endpoints[tgt_idx]

backend.draw_wire(
source.coordinates, target.coordinates)
wire_color_id = 0
if params.get('color_wires'):
# Determine the color based on the type of the source
if source.kind in {WireEndpointType.INPUT}:
wire_color_id = source.noun_id
else:
wire_color_id = target.noun_id
backend.draw_wire(source.coordinates, target.coordinates,
color_id=wire_color_id, **params)

if (params.get('draw_type_labels', True) and source.kind in
{WireEndpointType.INPUT, WireEndpointType.COD}):
Expand Down
35 changes: 34 additions & 1 deletion lambeq/backend/drawing/drawing_backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,14 @@
DEFAULT_MARGINS = (.05, .1)
DEFAULT_ASPECT = 'equal'

WIRE_COLORS = [
'#9c540e', '#f4a940', '#066ee2', '#d03b2d', '#7fd68b',
'#574cfa', '#49a141', '#a629b3', '#271296', '#ff6347',
'#adff2f', '#7446f2', '#007765', '#b60539', '#ff00ff',
'#c330b9', '#73b8fd', '#ff1493', '#00bfff', '#ffb6c1',
'#740127', '#e2074c', '#0252a1', '#fea431', '#205356',
'#450d06', '#d17800', '#3831a0', '#ff4500', '#d8bfd8'
]

FRAME_COLORS: list[str] = [
'#fbe8e7', '#fee1ba', '#fff9e5', '#e8f8ea', '#dcfbf5',
Expand Down Expand Up @@ -146,7 +154,9 @@ def draw_wire(self,
bend_out: bool = False,
bend_in: bool = False,
is_leg: bool = False,
style: str | None = None) -> None:
style: str | None = None,
color_id: int = 0,
**params) -> None:
"""
Draws a wire from source to target, possibly with a curve

Expand Down Expand Up @@ -184,6 +194,29 @@ def draw_spiders(self, drawable: DrawableDiagram, **params) -> None:

"""

def _get_wire_color(self, wire_id : int, **params) -> str:
"""
Retrieves a color that uniquely represent a given wire ID.

Parameters
----------
wire_id : int
The noun identifier of the wire for which the color is
being retrieved.
**params:
Additional parameters.

Returns:
wire_color : str
The Hex color of the wire, represented as a string.

"""
if not params.get('color_wires') or wire_id == 0:
return '#000000'
else:
wire_color = WIRE_COLORS[(wire_id - 1) % len(WIRE_COLORS)]
return wire_color

@abstractmethod
def output(self,
path: str | None = None,
Expand Down
Loading
Loading