Skip to content

Commit

Permalink
Color wires changes. wires are coloured used by default when the diag…
Browse files Browse the repository at this point in the history
…ram has frames but can be turned off using flag `color_wires=False`
  • Loading branch information
Ragunath1729 committed Nov 7, 2024
1 parent 9e2aa45 commit 14b8f85
Show file tree
Hide file tree
Showing 4 changed files with 184 additions and 21 deletions.
92 changes: 81 additions & 11 deletions lambeq/backend/drawing/drawable.py
Original file line number Diff line number Diff line change
Expand Up @@ -95,7 +95,7 @@ class WireEndpoint:

x: float
y: float

color: int = 0 # New attribute for wire noun
parent: Optional['BoxNode'] = None

@property
Expand Down Expand Up @@ -345,7 +345,8 @@ def _add_box(self,
box: grammar.Box,
off: int,
x_pos: float,
y_pos: float) -> tuple[list[int], int]:
y_pos: float,
input_nouns: list[str]=None) -> tuple[list[int], int, list[str]]:
"""Add a box to the graph, creating necessary wire endpoints.
Returns
Expand All @@ -354,25 +355,63 @@ def _add_box(self,
The new scan of wire endpoints after adding the box
box_ind : int
The index of the newly added `BoxNode`
list[int]
The new order of input_nouns after adding the box
"""
node = BoxNode(box, x_pos, y_pos)

box_ind = self._add_boxnode(node)
num_input = len(box.dom)
for i in range(num_input):
if i < len(input_nouns):
pass
else:
# If we run out of input nouns, generate new ones
new_color = self.get_noun_id()
input_nouns.append(new_color)

# Create a node representing each element in the box's domain
for i, obj in enumerate(box.dom):
idx = off + i
nbr_idx = scan[off + i]
noun_id = input_nouns[idx] if input_nouns and idx < len(
input_nouns) else DrawableDiagramWithFrames.get_noun_id() # Default to black if no input color available

wire_end = WireEndpoint(WireEndpointType.DOM,
obj=obj,
x=self.wire_endpoints[nbr_idx].x,
y=y_pos + HALF_BOX_HEIGHT)
y=y_pos + HALF_BOX_HEIGHT,
color=noun_id)

wire_idx = self._add_wire_end(wire_end)
node.add_dom_wire(wire_idx)
self._add_wire(nbr_idx, wire_idx)

scan_insert = []

# if Swap, exchange the noun_ids
if isinstance(box, grammar.Swap):
if input_nouns and len(box.dom) > 1:
dom_idx_1 = off
dom_idx_2 = off + 1
input_nouns[dom_idx_1], input_nouns[dom_idx_2] = input_nouns[dom_idx_2], input_nouns[dom_idx_1]
# if Spider, expand or shrink the noun_ids based on type
elif isinstance(node.obj, grammar.Spider):
if len(box.dom) == 1 and len(box.cod) > 1:
dom_noun = input_nouns[off] if input_nouns and off < len(input_nouns) else DrawableDiagramWithFrames.get_noun_id()
expanded_colors = [dom_noun] * len(box.cod)
input_nouns = input_nouns[:off] + expanded_colors + input_nouns[off + len(box.dom):]
elif len(box.dom) > 1 and len(box.cod) == 1:
cod_noun = input_nouns[off] if input_nouns and off < len(input_nouns) else DrawableDiagramWithFrames.get_noun_id()
input_nouns = input_nouns[:off] + [cod_noun] + input_nouns[off + len(box.dom):]

num_output = off+len(box.cod)
for i in range(num_output):
if i < len(input_nouns):
pass
else:
# If we run out of input nouns, generate new ones
new_color = self.get_noun_id()
input_nouns.append(new_color)
# Create a node representing each element in the box's codomain
for i, obj in enumerate(box.cod):

Expand All @@ -383,18 +422,20 @@ def _add_box(self,
else:
x = x_pos + X_SPACING * (i - len(box.cod[1:]) / 2)
y = y_pos - HALF_BOX_HEIGHT

idx = off + i
noun_id = input_nouns[idx] if input_nouns and idx < len(input_nouns) else DrawableDiagramWithFrames.get_noun_id()
wire_end = WireEndpoint(WireEndpointType.COD,
obj=obj,
x=x,
y=y)
y=y,
color=noun_id)

wire_idx = self._add_wire_end(wire_end)
scan_insert.append(wire_idx)
node.add_cod_wire(wire_idx)

# Replace node's dom with its cod in scan
return scan[:off] + scan_insert + scan[off + len(box.dom):], box_ind
return scan[:off] + scan_insert + scan[off + len(box.dom):], box_ind, input_nouns

def _find_box_edges(self,
box: grammar.Box,
Expand Down Expand Up @@ -778,7 +819,8 @@ class DrawableDiagramWithFrames(DrawableDiagram):
frame, carrying all information necessary to render it.
"""

#add counter for Nouns
noun_id_counter = 0
def _make_space(self,
scan: list[int],
box: grammar.Box,
Expand Down Expand Up @@ -949,6 +991,15 @@ def calculate_bounds(self) -> tuple[float, float, float, float]:

return min(all_xs), min(all_ys), max(all_xs), max(all_ys)


@staticmethod
def get_noun_id() -> int:
"""Generate a new numerical ID for the noun box/wire."""
# Increment and return the next available ID
noun_id = DrawableDiagramWithFrames.noun_id_counter
DrawableDiagramWithFrames.noun_id_counter += 1
return noun_id

@classmethod
def from_diagram(cls,
diagram: grammar.Diagram,
Expand Down Expand Up @@ -976,12 +1027,19 @@ def from_diagram(cls,
drawable = cls()

scan = []
# Generate unique noun_ids for input wires
num_input = len(diagram.dom)
input_nouns = []
for i in range(num_input):
new_color = drawable.get_noun_id()
input_nouns.append(new_color)

for i, obj in enumerate(diagram.dom):
wire_end = WireEndpoint(WireEndpointType.INPUT,
obj=obj,
x=X_SPACING * i,
y=1)
y=1,
color=input_nouns[i])
wire_end_idx = drawable._add_wire_end(wire_end)
scan.append(wire_end_idx)

Expand All @@ -993,7 +1051,7 @@ def from_diagram(cls,
# TODO: Debug issues with y coord
x, y = drawable._make_space(scan, box, off, foliated=foliated)

scan, box_ind = drawable._add_box(scan, box, off, x, y)
scan, box_ind, input_nouns = drawable._add_box(scan, box, off, x, y, input_nouns)
box_height = BOX_HEIGHT
# Add drawables for the inside of the frame
if isinstance(box, grammar.Frame):
Expand All @@ -1004,12 +1062,24 @@ def from_diagram(cls,
max_box_half_height = max(max_box_half_height, (box_height / 2))
min_y = min(min_y, y)

num_output = len(diagram.cod)
output_nouns = []
# Match output nouns with input nouns as much as possible
for i in range(num_output):
if i < len(input_nouns):
pass
else:
# If we run out of input nouns, generate new ones
new_color = drawable.get_noun_id()
input_nouns.append(new_color)

for i, obj in enumerate(diagram.cod):
wire_end = WireEndpoint(
WireEndpointType.OUTPUT,
obj=obj,
x=drawable.wire_endpoints[scan[i]].x,
y=min_y - max_box_half_height - 1.5 * BOX_HEIGHT
y=min_y - max_box_half_height - 1.5 * BOX_HEIGHT,
color=input_nouns[i]
)
wire_end_idx = drawable._add_wire_end(wire_end)
drawable._add_wire(scan[i], wire_end_idx)
Expand Down
64 changes: 58 additions & 6 deletions lambeq/backend/drawing/drawing.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,8 @@
DEFAULT_ASPECT,
DEFAULT_MARGINS,
DrawingBackend,
FRAME_COLORS)
FRAME_COLORS,
WIRE_COLORS)
from lambeq.backend.drawing.helpers import drawn_as_spider, needs_asymmetry
from lambeq.backend.drawing.mat_backend import MatBackend
from lambeq.backend.drawing.text_printer import PregroupTextPrinter
Expand Down Expand Up @@ -110,6 +111,9 @@ def draw(diagram: Diagram, **params) -> None:
params['coloring_mode'] = params.get(
'coloring_mode', ColoringMode.TYPE,
)
params['color_wires'] = params.get(
'color_wires', diagram.has_frames,
)
if drawable is None:
drawable = drawable_cls.from_diagram(diagram,
params.get('foliated', False))
Expand Down Expand Up @@ -144,17 +148,56 @@ def draw(diagram: Diagram, **params) -> None:
backend = _draw_controlled_gate(backend, drawable, node, **params)
elif not drawn_as_spider(node.obj):
backend = _draw_box(backend, drawable, node, **params)

else:
wire_drawings = compute_spider_wires(node, drawable, **params)
backend.draw_spider(node,wire_drawings,**params)
# Draw boxes first since they are filled
backend = _draw_wires(backend, drawable, **params)
backend.draw_spiders(drawable, **params)
backend.output(
path=params.get('path', None),
baseline=0,
tikz_options=params.get('tikz_options', None),
show=params.get('show', True),
margins=params.get('margins', DEFAULT_MARGINS))

def compute_spider_wires(node, drawable, **params)-> list:
"""
Compute the wire drawing requirements for spider.
Computes all the wires (dom and cod) that need to be drawn for spider
and returns a list of wires.
"""
wire_color = 'black'
wire_drawings = []
# Compute wires for cod_wires (outgoing wires)
for wire in node.cod_wires:
start_coordinates = node.coordinates
end_coordinates = drawable.wire_endpoints[wire].coordinates
if params['color_wires']:
wire_color = WIRE_COLORS[(drawable.wire_endpoints[wire].color - 1) % len(WIRE_COLORS)]
wire_drawings.append({
'start': start_coordinates,
'end': end_coordinates,
'bend_out': True,
'is_leg': True,
'color': wire_color
})

# Compute wires for dom_wires (incoming wires)
for wire in node.dom_wires:
start_coordinates = drawable.wire_endpoints[wire].coordinates
end_coordinates = node.coordinates
if params['color_wires']:
wire_color = WIRE_COLORS[(drawable.wire_endpoints[wire].color - 1) % len(WIRE_COLORS)]
wire_drawings.append({
'start': start_coordinates,
'end': end_coordinates,
'bend_in': True,
'is_leg': True,
'color': wire_color
})

return wire_drawings

def draw_pregroup(diagram: Diagram, **params) -> None:
""" Draw a pregroup grammar diagram.
Expand Down Expand Up @@ -458,7 +501,9 @@ def _get_box_color(box: grammar.Diagrammable,
color = FRAME_COLORS[(frame_attr - 1) % len(FRAME_COLORS)]

return color

def _get_wire_color(wire_id):
wire_color = WIRE_COLORS[(wire_id - 1) % len(WIRE_COLORS)]
return wire_color

def _draw_pregroup_state(backend: DrawingBackend,
drawable_box: BoxNode,
Expand Down Expand Up @@ -524,9 +569,16 @@ def _draw_wires(backend: DrawingBackend,
for src_idx, tgt_idx in drawable_diagram.wires:
source = drawable_diagram.wire_endpoints[src_idx]
target = drawable_diagram.wire_endpoints[tgt_idx]

wire_color = 'black'
if params['color_wires']:
# Determine the color based on the type of the source
if source.kind in {WireEndpointType.INPUT}:
wire_color_id = source.color
else:
wire_color_id = target.color
wire_color = _get_wire_color(wire_color_id)
backend.draw_wire(
source.coordinates, target.coordinates)
source.coordinates, target.coordinates, color=wire_color)

if (params.get('draw_type_labels', True) and source.kind in
{WireEndpointType.INPUT, WireEndpointType.COD}):
Expand Down
28 changes: 27 additions & 1 deletion lambeq/backend/drawing/drawing_backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,15 @@
DEFAULT_MARGINS = (.05, .1)
DEFAULT_ASPECT = 'equal'

WIRE_COLORS = [
"#A22417", "#D17800", "#F4A940", "#49A141", "#007765",
"#398889", "#0252A1", "#3831A0","#A629B3", "#B60539",
"#73190E", "#BE660F", "#F29A3B", "#3C8331", "#01594C",
"#2F7173", "#013B73", "#292372", "#271296", "#9526AF",
"#960131", "#450D06", "#9C540E", "#F07E33", "#003C33",
"#205356", "#002B54", "#201B5B", "#1D0B66", "#751FA5"
]


FRAME_COLORS: list[str] = [
'#fbe8e7', '#fee1ba', '#fff9e5', '#e8f8ea', '#dcfbf5',
Expand Down Expand Up @@ -146,7 +155,8 @@ def draw_wire(self,
bend_out: bool = False,
bend_in: bool = False,
is_leg: bool = False,
style: str | None = None) -> None:
style: str | None = None,
color: str = 'black') -> None:
"""
Draws a wire from source to target, possibly with a curve
Expand Down Expand Up @@ -184,6 +194,22 @@ def draw_spiders(self, drawable: DrawableDiagram, **params) -> None:
"""

@abstractmethod
def draw_spider(self, node, wire_drawings, **params) -> None:
"""
Draws the spider node with the list of wires.
Parameters
----------
node: DrawableDiagram
the node to be drawn.
wire_drawings: list of wires
list of wires in the spider
params: any
Additional parameters.
"""

@abstractmethod
def output(self,
path: str | None = None,
Expand Down
21 changes: 18 additions & 3 deletions lambeq/backend/drawing/mat_backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,11 +71,12 @@ def draw_wire(self,
bend_out: bool = False,
bend_in: bool = False,
is_leg: bool = False,
style: str | None = None) -> None:
style: str | None = None,
color: str = 'black') -> None:
if style == '->':
self.axis.arrow(
*(source + (target[0] - source[0], target[1] - source[1])),
head_width=.02, color='black')
head_width=.02, color=color)
else:
if is_leg:
mid = (target[0], source[1])
Expand Down Expand Up @@ -107,7 +108,7 @@ def draw_wire(self,
])

self.axis.add_patch(PathPatch(
path, facecolor='none', linewidth=self.linewidth))
path, facecolor='none', linewidth=self.linewidth, edgecolor=color))

self.max_width = max(self.max_width, source[0], target[0])

Expand All @@ -128,6 +129,20 @@ def draw_spiders(self, drawable: DrawableDiagram, **params) -> None:
bend_in=True,
is_leg=True)

def draw_spider(self, node, wire_drawings, **params) -> None:
if isinstance(node.obj, Spider):
self.draw_node(*node.coordinates, **params)

for wire in wire_drawings:
self.draw_wire(
wire['start'],
wire['end'],
bend_out=wire.get('bend_out', False),
bend_in=wire.get('bend_in', False),
is_leg=wire['is_leg'],
color=wire['color']
)

def output(self,
path: str | None = None,
show: bool = True,
Expand Down

0 comments on commit 14b8f85

Please sign in to comment.