Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add option to colour wires in diagrams with frames #177

Merged
merged 44 commits into from
Nov 22, 2024
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
Show all changes
44 commits
Select commit Hold shift + click to select a range
14b8f85
Color wires changes. wires are coloured used by default when the diag…
Ragunath1729 Nov 7, 2024
ba2230b
Merge remote-tracking branch 'refs/remotes/public-lambeq/main' into b…
Ragunath1729 Nov 7, 2024
0b3f275
review comments fixes, remove colors looks like similar
Ragunath1729 Nov 8, 2024
84b4456
rearrange colors to avoid similar shades comes together
Ragunath1729 Nov 8, 2024
8b963d7
lint error fixes, flake8 fixes
Ragunath1729 Nov 8, 2024
885c9cb
flake8 fixes
Ragunath1729 Nov 8, 2024
72b635f
inherit add_box to avoid side-effect in non-frame diagrams
Ragunath1729 Nov 11, 2024
9dbcc40
remove debug logs
Ragunath1729 Nov 11, 2024
3a8f66d
add color params to draw method in Tikzit backend
Ragunath1729 Nov 11, 2024
3ef64a5
lint error fixes
Ragunath1729 Nov 11, 2024
1a1f2a0
mypy suggestion fixes
Ragunath1729 Nov 11, 2024
94f6dbd
demo on simple-book notebook
Ragunath1729 Nov 11, 2024
d89ec0c
change signature to fix flake8 error
Ragunath1729 Nov 11, 2024
8093b95
add doc strings to new functions
Ragunath1729 Nov 11, 2024
69d9830
make noun_id non-static variable
Ragunath1729 Nov 12, 2024
3af9164
remove different colours and remove colours with similar shades
Ragunath1729 Nov 12, 2024
6bf61c8
remove yellow shades from colors list
Ragunath1729 Nov 13, 2024
222376a
remove optional for input_nouns
Ragunath1729 Nov 13, 2024
d1a7100
different linewidth for box and wires, get from user using params
Ragunath1729 Nov 13, 2024
8da64d7
remove invalid characters from stylenames which throw render errors
Ragunath1729 Nov 14, 2024
3c83b83
color wire changes in Tikz_backend
Ragunath1729 Nov 14, 2024
e8c04f3
remove extra {}
Ragunath1729 Nov 15, 2024
7caef28
static commands for tex file, new layer for labels
Ragunath1729 Nov 15, 2024
3d9be55
flake8 suggestion to remove quotes
Ragunath1729 Nov 15, 2024
57fe2d9
remove extra {}
Ragunath1729 Nov 18, 2024
edac6cd
delete sample file
Ragunath1729 Nov 18, 2024
353df2f
cosmetic readability fixes
Ragunath1729 Nov 19, 2024
8d43b79
refactor wire colour formatting like boxes
Ragunath1729 Nov 19, 2024
78fbfee
wires width changes for Tikz backend
Ragunath1729 Nov 19, 2024
ba8dc7f
fix for flake8 import order error
Ragunath1729 Nov 19, 2024
ae69857
update testcases for frame drawing with wire colouring
Ragunath1729 Nov 20, 2024
1f11612
Merge branch 'main' into br_color_wires
Ragunath1729 Nov 20, 2024
c829e4f
change label flag as hidden by default
Ragunath1729 Nov 20, 2024
99ed5e4
Fix code formatting
neiljdo Nov 21, 2024
5ee51b2
Update constants
neiljdo Nov 21, 2024
294cd72
Update tests for Tikz circuit
neiljdo Nov 21, 2024
02b35f3
Update tests for Tikz diagrams
neiljdo Nov 21, 2024
697a0ee
add box_linewidth to Tikz backend
Ragunath1729 Nov 21, 2024
5d18c22
update frame-drawing tests with linewidth changes
Ragunath1729 Nov 21, 2024
93a84bb
update test-drawing tests with linewidth changes
Ragunath1729 Nov 21, 2024
0f3c286
update circuit-drawing tests with linewidth changes
Ragunath1729 Nov 21, 2024
e328cd8
Merge remote-tracking branch 'ragu/br_color_wires' into br_color_wires
Ragunath1729 Nov 21, 2024
448107b
make mypy happy
Ragunath1729 Nov 21, 2024
9f3a9a8
Add `box_linewidth` as param for `MatBackend`
neiljdo Nov 22, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
92 changes: 81 additions & 11 deletions lambeq/backend/drawing/drawable.py
Original file line number Diff line number Diff line change
Expand Up @@ -95,7 +95,7 @@ class WireEndpoint:

x: float
y: float

color: int = 0 # New attribute for wire noun
parent: Optional['BoxNode'] = None

@property
Expand Down Expand Up @@ -345,7 +345,8 @@ def _add_box(self,
box: grammar.Box,
off: int,
x_pos: float,
y_pos: float) -> tuple[list[int], int]:
y_pos: float,
input_nouns: list[str]=None) -> tuple[list[int], int, list[str]]:
"""Add a box to the graph, creating necessary wire endpoints.

Returns
Expand All @@ -354,25 +355,63 @@ def _add_box(self,
The new scan of wire endpoints after adding the box
box_ind : int
The index of the newly added `BoxNode`
list[int]
The new order of input_nouns after adding the box
"""
node = BoxNode(box, x_pos, y_pos)

box_ind = self._add_boxnode(node)
num_input = len(box.dom)
for i in range(num_input):
if i < len(input_nouns):
pass
else:
# If we run out of input nouns, generate new ones
new_color = self.get_noun_id()
input_nouns.append(new_color)

# Create a node representing each element in the box's domain
for i, obj in enumerate(box.dom):
idx = off + i
nbr_idx = scan[off + i]
noun_id = input_nouns[idx] if input_nouns and idx < len(
input_nouns) else DrawableDiagramWithFrames.get_noun_id() # Default to black if no input color available

wire_end = WireEndpoint(WireEndpointType.DOM,
obj=obj,
x=self.wire_endpoints[nbr_idx].x,
y=y_pos + HALF_BOX_HEIGHT)
y=y_pos + HALF_BOX_HEIGHT,
color=noun_id)

wire_idx = self._add_wire_end(wire_end)
node.add_dom_wire(wire_idx)
self._add_wire(nbr_idx, wire_idx)

scan_insert = []

# if Swap, exchange the noun_ids
if isinstance(box, grammar.Swap):
if input_nouns and len(box.dom) > 1:
dom_idx_1 = off
dom_idx_2 = off + 1
input_nouns[dom_idx_1], input_nouns[dom_idx_2] = input_nouns[dom_idx_2], input_nouns[dom_idx_1]
# if Spider, expand or shrink the noun_ids based on type
elif isinstance(node.obj, grammar.Spider):
if len(box.dom) == 1 and len(box.cod) > 1:
dom_noun = input_nouns[off] if input_nouns and off < len(input_nouns) else DrawableDiagramWithFrames.get_noun_id()
expanded_colors = [dom_noun] * len(box.cod)
input_nouns = input_nouns[:off] + expanded_colors + input_nouns[off + len(box.dom):]
elif len(box.dom) > 1 and len(box.cod) == 1:
cod_noun = input_nouns[off] if input_nouns and off < len(input_nouns) else DrawableDiagramWithFrames.get_noun_id()
input_nouns = input_nouns[:off] + [cod_noun] + input_nouns[off + len(box.dom):]

num_output = off+len(box.cod)
for i in range(num_output):
if i < len(input_nouns):
pass
else:
# If we run out of input nouns, generate new ones
new_color = self.get_noun_id()
input_nouns.append(new_color)
# Create a node representing each element in the box's codomain
for i, obj in enumerate(box.cod):

Expand All @@ -383,18 +422,20 @@ def _add_box(self,
else:
x = x_pos + X_SPACING * (i - len(box.cod[1:]) / 2)
y = y_pos - HALF_BOX_HEIGHT

idx = off + i
noun_id = input_nouns[idx] if input_nouns and idx < len(input_nouns) else DrawableDiagramWithFrames.get_noun_id()
wire_end = WireEndpoint(WireEndpointType.COD,
obj=obj,
x=x,
y=y)
y=y,
color=noun_id)

wire_idx = self._add_wire_end(wire_end)
scan_insert.append(wire_idx)
node.add_cod_wire(wire_idx)

# Replace node's dom with its cod in scan
return scan[:off] + scan_insert + scan[off + len(box.dom):], box_ind
return scan[:off] + scan_insert + scan[off + len(box.dom):], box_ind, input_nouns

def _find_box_edges(self,
box: grammar.Box,
Expand Down Expand Up @@ -778,7 +819,8 @@ class DrawableDiagramWithFrames(DrawableDiagram):
frame, carrying all information necessary to render it.

"""

#add counter for Nouns
noun_id_counter = 0
def _make_space(self,
scan: list[int],
box: grammar.Box,
Expand Down Expand Up @@ -949,6 +991,15 @@ def calculate_bounds(self) -> tuple[float, float, float, float]:

return min(all_xs), min(all_ys), max(all_xs), max(all_ys)


@staticmethod
def get_noun_id() -> int:
"""Generate a new numerical ID for the noun box/wire."""
# Increment and return the next available ID
noun_id = DrawableDiagramWithFrames.noun_id_counter
DrawableDiagramWithFrames.noun_id_counter += 1
return noun_id

@classmethod
def from_diagram(cls,
diagram: grammar.Diagram,
Expand Down Expand Up @@ -976,12 +1027,19 @@ def from_diagram(cls,
drawable = cls()

scan = []
# Generate unique noun_ids for input wires
num_input = len(diagram.dom)
input_nouns = []
for i in range(num_input):
new_color = drawable.get_noun_id()
input_nouns.append(new_color)

for i, obj in enumerate(diagram.dom):
wire_end = WireEndpoint(WireEndpointType.INPUT,
obj=obj,
x=X_SPACING * i,
y=1)
y=1,
color=input_nouns[i])
wire_end_idx = drawable._add_wire_end(wire_end)
scan.append(wire_end_idx)

Expand All @@ -993,7 +1051,7 @@ def from_diagram(cls,
# TODO: Debug issues with y coord
x, y = drawable._make_space(scan, box, off, foliated=foliated)

scan, box_ind = drawable._add_box(scan, box, off, x, y)
scan, box_ind, input_nouns = drawable._add_box(scan, box, off, x, y, input_nouns)
box_height = BOX_HEIGHT
# Add drawables for the inside of the frame
if isinstance(box, grammar.Frame):
Expand All @@ -1004,12 +1062,24 @@ def from_diagram(cls,
max_box_half_height = max(max_box_half_height, (box_height / 2))
min_y = min(min_y, y)

num_output = len(diagram.cod)
output_nouns = []
# Match output nouns with input nouns as much as possible
for i in range(num_output):
if i < len(input_nouns):
pass
else:
# If we run out of input nouns, generate new ones
new_color = drawable.get_noun_id()
input_nouns.append(new_color)

for i, obj in enumerate(diagram.cod):
wire_end = WireEndpoint(
WireEndpointType.OUTPUT,
obj=obj,
x=drawable.wire_endpoints[scan[i]].x,
y=min_y - max_box_half_height - 1.5 * BOX_HEIGHT
y=min_y - max_box_half_height - 1.5 * BOX_HEIGHT,
color=input_nouns[i]
)
wire_end_idx = drawable._add_wire_end(wire_end)
drawable._add_wire(scan[i], wire_end_idx)
Expand Down
64 changes: 58 additions & 6 deletions lambeq/backend/drawing/drawing.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,8 @@
DEFAULT_ASPECT,
DEFAULT_MARGINS,
DrawingBackend,
FRAME_COLORS)
FRAME_COLORS,
WIRE_COLORS)
from lambeq.backend.drawing.helpers import drawn_as_spider, needs_asymmetry
from lambeq.backend.drawing.mat_backend import MatBackend
from lambeq.backend.drawing.text_printer import PregroupTextPrinter
Expand Down Expand Up @@ -110,6 +111,9 @@ def draw(diagram: Diagram, **params) -> None:
params['coloring_mode'] = params.get(
'coloring_mode', ColoringMode.TYPE,
)
params['color_wires'] = params.get(
'color_wires', diagram.has_frames,
)
if drawable is None:
drawable = drawable_cls.from_diagram(diagram,
params.get('foliated', False))
Expand Down Expand Up @@ -144,17 +148,56 @@ def draw(diagram: Diagram, **params) -> None:
backend = _draw_controlled_gate(backend, drawable, node, **params)
elif not drawn_as_spider(node.obj):
backend = _draw_box(backend, drawable, node, **params)

else:
wire_drawings = compute_spider_wires(node, drawable, **params)
backend.draw_spider(node,wire_drawings,**params)
# Draw boxes first since they are filled
backend = _draw_wires(backend, drawable, **params)
backend.draw_spiders(drawable, **params)
backend.output(
path=params.get('path', None),
baseline=0,
tikz_options=params.get('tikz_options', None),
show=params.get('show', True),
margins=params.get('margins', DEFAULT_MARGINS))

def compute_spider_wires(node, drawable, **params)-> list:
"""
Compute the wire drawing requirements for spider.

Computes all the wires (dom and cod) that need to be drawn for spider
and returns a list of wires.
"""
wire_color = 'black'
wire_drawings = []
# Compute wires for cod_wires (outgoing wires)
for wire in node.cod_wires:
start_coordinates = node.coordinates
end_coordinates = drawable.wire_endpoints[wire].coordinates
if params['color_wires']:
wire_color = WIRE_COLORS[(drawable.wire_endpoints[wire].color - 1) % len(WIRE_COLORS)]
wire_drawings.append({
'start': start_coordinates,
'end': end_coordinates,
'bend_out': True,
'is_leg': True,
'color': wire_color
})

# Compute wires for dom_wires (incoming wires)
for wire in node.dom_wires:
start_coordinates = drawable.wire_endpoints[wire].coordinates
end_coordinates = node.coordinates
if params['color_wires']:
wire_color = WIRE_COLORS[(drawable.wire_endpoints[wire].color - 1) % len(WIRE_COLORS)]
wire_drawings.append({
'start': start_coordinates,
'end': end_coordinates,
'bend_in': True,
'is_leg': True,
'color': wire_color
})

return wire_drawings

def draw_pregroup(diagram: Diagram, **params) -> None:
""" Draw a pregroup grammar diagram.
Expand Down Expand Up @@ -458,7 +501,9 @@ def _get_box_color(box: grammar.Diagrammable,
color = FRAME_COLORS[(frame_attr - 1) % len(FRAME_COLORS)]

return color

def _get_wire_color(wire_id):
wire_color = WIRE_COLORS[(wire_id - 1) % len(WIRE_COLORS)]
return wire_color

def _draw_pregroup_state(backend: DrawingBackend,
drawable_box: BoxNode,
Expand Down Expand Up @@ -524,9 +569,16 @@ def _draw_wires(backend: DrawingBackend,
for src_idx, tgt_idx in drawable_diagram.wires:
source = drawable_diagram.wire_endpoints[src_idx]
target = drawable_diagram.wire_endpoints[tgt_idx]

wire_color = 'black'
if params['color_wires']:
# Determine the color based on the type of the source
if source.kind in {WireEndpointType.INPUT}:
wire_color_id = source.color
else:
wire_color_id = target.color
wire_color = _get_wire_color(wire_color_id)
backend.draw_wire(
source.coordinates, target.coordinates)
source.coordinates, target.coordinates, color=wire_color)

if (params.get('draw_type_labels', True) and source.kind in
{WireEndpointType.INPUT, WireEndpointType.COD}):
Expand Down
28 changes: 27 additions & 1 deletion lambeq/backend/drawing/drawing_backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,15 @@
DEFAULT_MARGINS = (.05, .1)
DEFAULT_ASPECT = 'equal'

WIRE_COLORS = [
"#A22417", "#D17800", "#F4A940", "#49A141", "#007765",
"#398889", "#0252A1", "#3831A0","#A629B3", "#B60539",
"#73190E", "#BE660F", "#F29A3B", "#3C8331", "#01594C",
"#2F7173", "#013B73", "#292372", "#271296", "#9526AF",
"#960131", "#450D06", "#9C540E", "#F07E33", "#003C33",
"#205356", "#002B54", "#201B5B", "#1D0B66", "#751FA5"
]


FRAME_COLORS: list[str] = [
'#fbe8e7', '#fee1ba', '#fff9e5', '#e8f8ea', '#dcfbf5',
Expand Down Expand Up @@ -146,7 +155,8 @@ def draw_wire(self,
bend_out: bool = False,
bend_in: bool = False,
is_leg: bool = False,
style: str | None = None) -> None:
style: str | None = None,
color: str = 'black') -> None:
"""
Draws a wire from source to target, possibly with a curve

Expand Down Expand Up @@ -184,6 +194,22 @@ def draw_spiders(self, drawable: DrawableDiagram, **params) -> None:

"""

@abstractmethod
def draw_spider(self, node, wire_drawings, **params) -> None:
"""
Draws the spider node with the list of wires.

Parameters
----------
node: DrawableDiagram
the node to be drawn.
wire_drawings: list of wires
list of wires in the spider
params: any
Additional parameters.

"""

@abstractmethod
def output(self,
path: str | None = None,
Expand Down
21 changes: 18 additions & 3 deletions lambeq/backend/drawing/mat_backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,11 +71,12 @@ def draw_wire(self,
bend_out: bool = False,
bend_in: bool = False,
is_leg: bool = False,
style: str | None = None) -> None:
style: str | None = None,
color: str = 'black') -> None:
if style == '->':
self.axis.arrow(
*(source + (target[0] - source[0], target[1] - source[1])),
head_width=.02, color='black')
head_width=.02, color=color)
else:
if is_leg:
mid = (target[0], source[1])
Expand Down Expand Up @@ -107,7 +108,7 @@ def draw_wire(self,
])

self.axis.add_patch(PathPatch(
path, facecolor='none', linewidth=self.linewidth))
path, facecolor='none', linewidth=self.linewidth, edgecolor=color))

self.max_width = max(self.max_width, source[0], target[0])

Expand All @@ -128,6 +129,20 @@ def draw_spiders(self, drawable: DrawableDiagram, **params) -> None:
bend_in=True,
is_leg=True)

def draw_spider(self, node, wire_drawings, **params) -> None:
if isinstance(node.obj, Spider):
self.draw_node(*node.coordinates, **params)

for wire in wire_drawings:
self.draw_wire(
wire['start'],
wire['end'],
bend_out=wire.get('bend_out', False),
bend_in=wire.get('bend_in', False),
is_leg=wire['is_leg'],
color=wire['color']
)

def output(self,
path: str | None = None,
show: bool = True,
Expand Down
Loading