Skip to content

Commit 3a8f66d

Browse files
committed
add color params to draw method in Tikzit backend
1 parent 9dbcc40 commit 3a8f66d

File tree

1 file changed

+11
-3
lines changed

1 file changed

+11
-3
lines changed

lambeq/backend/drawing/tikz_backend.py

Lines changed: 11 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -122,7 +122,11 @@ def draw_wire(self,
122122
bend_out: bool = False,
123123
bend_in: bool = False,
124124
is_leg: bool = False,
125-
style: str | None = None) -> None:
125+
style: str | None = None,
126+
color_id: int = 0,
127+
**params) -> None:
128+
129+
color = self._get_wire_color(color_id, **params)
126130
out = (-90 if not bend_out or source[0] == target[0]
127131
else (180 if source[0] > target[0] else 0))
128132
inp = (90 if not bend_in or source[0] == target[0]
@@ -178,11 +182,15 @@ def draw_spiders(self, drawable: DrawableDiagram, **params) -> None:
178182
for wire in node.cod_wires:
179183
self.draw_wire(node.coordinates,
180184
drawable.wire_endpoints[wire].coordinates,
181-
bend_out=True)
185+
bend_out=True,
186+
color_id=drawable.wire_endpoints[wire].noun_id,
187+
**params)
182188
for wire in node.dom_wires:
183189
self.draw_wire(drawable.wire_endpoints[wire].coordinates,
184190
node.coordinates,
185-
bend_in=True)
191+
bend_in=True,
192+
color_id=drawable.wire_endpoints[wire].noun_id,
193+
**params)
186194

187195
def output(self, path=None, show=True, **params) -> None:
188196
baseline = params.get('baseline', 0)

0 commit comments

Comments
 (0)