Skip to content

Commit 27e212a

Browse files
authored
Merge pull request #47 from ksunden/introspection
Introspection graph visualization with improved node placement
2 parents bffe7f4 + 965c82b commit 27e212a

File tree

3 files changed

+195
-33
lines changed

3 files changed

+195
-33
lines changed

.github/workflows/docs.yml

+1-1
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@ on: [push, pull_request]
55

66
jobs:
77
build:
8-
runs-on: ubuntu-20.04
8+
runs-on: ubuntu-latest
99
steps:
1010
- uses: actions/checkout@v2
1111
- name: "Set up Python 3.10"

data_prototype/conversion_edge.py

+33-32
Original file line numberDiff line numberDiff line change
@@ -372,7 +372,20 @@ def edges(self):
372372
return SequenceEdge.from_edges("eval", out_edges, output)
373373

374374
def visualize(self, input: dict[str, Desc] | None = None):
375-
import networkx as nx
375+
if input is None:
376+
from .introspection import draw_graph
377+
378+
draw_graph(self)
379+
return
380+
381+
try:
382+
import networkx as nx
383+
except ImportError:
384+
from .introspection import draw_graph
385+
386+
draw_graph(self)
387+
return
388+
376389
import matplotlib.pyplot as plt
377390

378391
from pprint import pformat
@@ -382,38 +395,26 @@ def node_format(x):
382395

383396
G = nx.DiGraph()
384397

385-
if input is not None:
386-
for _, edges in self._subgraphs:
387-
q: list[dict[str, Desc]] = [input]
388-
explored: set[tuple[tuple[str, str], ...]] = set()
389-
explored.add(
390-
tuple(sorted(((k, v.coordinates) for k, v in q[0].items())))
391-
)
392-
G.add_node(node_format(q[0]))
393-
while q:
394-
n = q.pop()
395-
for e in edges:
396-
if Desc.compatible(n, e.input):
397-
w = n | e.output
398-
if node_format(w) not in G:
399-
G.add_node(node_format(w))
400-
explored.add(
401-
tuple(
402-
sorted(
403-
((k, v.coordinates) for k, v in w.items())
404-
)
405-
)
398+
for _, edges in self._subgraphs:
399+
q: list[dict[str, Desc]] = [input]
400+
explored: set[tuple[tuple[str, str], ...]] = set()
401+
explored.add(tuple(sorted(((k, v.coordinates) for k, v in q[0].items()))))
402+
G.add_node(node_format(q[0]))
403+
while q:
404+
n = q.pop()
405+
for e in edges:
406+
if Desc.compatible(n, e.input):
407+
w = n | e.output
408+
if node_format(w) not in G:
409+
G.add_node(node_format(w))
410+
explored.add(
411+
tuple(
412+
sorted(((k, v.coordinates) for k, v in w.items()))
406413
)
407-
q.append(w)
408-
if node_format(w) != node_format(n):
409-
G.add_edge(node_format(n), node_format(w), name=e.name)
410-
else:
411-
# don't bother separating subgraphs,as the end result is exactly the same here
412-
for edge in self._edges:
413-
G.add_edge(
414-
node_format(edge.input), node_format(edge.output), name=edge.name
415-
)
416-
414+
)
415+
q.append(w)
416+
if node_format(w) != node_format(n):
417+
G.add_edge(node_format(n), node_format(w), name=e.name)
417418
try:
418419
pos = nx.shell_layout(G)
419420
except Exception:

data_prototype/introspection.py

+161
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,161 @@
1+
from __future__ import annotations
2+
3+
from dataclasses import dataclass, field
4+
import graphlib
5+
from pprint import pformat
6+
7+
import matplotlib.pyplot as plt
8+
9+
from .conversion_edge import Edge, Graph
10+
from .description import Desc
11+
12+
13+
@dataclass
14+
class VisNode:
15+
keys: list[str]
16+
coordinates: list[str]
17+
parents: list[VisNode] = field(default_factory=list)
18+
children: list[VisNode] = field(default_factory=list)
19+
x: int = 0
20+
y: int = 0
21+
22+
def __eq__(self, other):
23+
return self.keys == other.keys and self.coordinates == other.coordinates
24+
25+
def format(self):
26+
return pformat({k: v for k, v in zip(self.keys, self.coordinates)}, width=20)
27+
28+
29+
@dataclass
30+
class VisEdge:
31+
name: str
32+
parent: VisNode
33+
child: VisNode
34+
35+
36+
def _position_subgraph(
37+
subgraph: tuple(set[str], list[Edge]),
38+
) -> tuple[list[VisNode], list[VisEdge]]:
39+
# Build graph
40+
nodes: list[VisNode] = []
41+
edges: list[VisEdge] = []
42+
43+
q: list[dict[str, Desc]] = [e.input for e in subgraph[1]]
44+
explored: set[tuple[tuple[str, str], ...]] = set()
45+
explored.add(tuple(sorted(((k, v.coordinates) for k, v in q[0].items()))))
46+
47+
for e in subgraph[1]:
48+
nodes.append(
49+
VisNode(list(e.input.keys()), [x.coordinates for x in e.input.values()])
50+
)
51+
52+
while q:
53+
n = q.pop()
54+
vn = VisNode(list(n.keys()), [x.coordinates for x in n.values()])
55+
for nn in nodes:
56+
if vn == nn:
57+
vn = nn
58+
59+
for e in subgraph[1]:
60+
# Shortcut default edges appearing all over the place
61+
if e.input == {} and vn.keys != []:
62+
continue
63+
if Desc.compatible(n, e.input):
64+
w = e.output
65+
vw = VisNode(list(w.keys()), [x.coordinates for x in w.values()])
66+
for nn in nodes:
67+
if vw == nn:
68+
vw = nn
69+
70+
if vw not in nodes:
71+
nodes.append(vw)
72+
explored.add(
73+
tuple(sorted(((k, v.coordinates) for k, v in w.items())))
74+
)
75+
q.append(w)
76+
if vw != vn:
77+
edges.append(VisEdge(e.name, vn, vw))
78+
vw.parents.append(vn)
79+
vn.children.append(vw)
80+
81+
# adapt graph for total ording
82+
def hash_node(n):
83+
return (tuple(n.keys), tuple(n.coordinates))
84+
85+
to_graph = {hash_node(n): set() for n in nodes}
86+
for e in edges:
87+
to_graph[hash_node(e.child)] |= {hash_node(e.parent)}
88+
89+
# evaluate total ordering
90+
topological_sorter = graphlib.TopologicalSorter(to_graph)
91+
92+
# position horizontally by 1+ highest parent, vertically by 1+ highest sibling
93+
def get_node(n):
94+
for node in nodes:
95+
if n[0] == tuple(node.keys) and n[1] == tuple(node.coordinates):
96+
return node
97+
98+
static_order = list(topological_sorter.static_order())
99+
100+
for n in static_order:
101+
node = get_node(n)
102+
if node.parents != []:
103+
node.y = max(p.y for p in node.parents) + 1
104+
x_pos = {}
105+
for n in static_order:
106+
node = get_node(n)
107+
if node.y in x_pos:
108+
node.x = x_pos[node.y]
109+
x_pos[node.y] += 1.25
110+
else:
111+
x_pos[node.y] = 1.25
112+
113+
return nodes, edges
114+
115+
116+
def draw_graph(graph: Graph, ax=None, *, adjust_axes=None):
117+
if ax is None:
118+
fig, ax = plt.subplots()
119+
if adjust_axes is None:
120+
adjust_axes = True
121+
122+
inverted = adjust_axes or ax.yaxis.get_inverted()
123+
124+
origin_y = 0
125+
xmax = 0
126+
127+
for sg in graph._subgraphs:
128+
nodes, edges = _position_subgraph(sg)
129+
annotations = {}
130+
# Draw nodes
131+
for node in nodes:
132+
annotations[node.format()] = ax.annotate(
133+
node.format(),
134+
(node.x, node.y + origin_y),
135+
ha="center",
136+
va="center",
137+
bbox={"boxstyle": "round", "facecolor": "none"},
138+
)
139+
140+
# Draw edges
141+
for edge in edges:
142+
arr = ax.annotate(
143+
"",
144+
(0.5, 1.05 if inverted else -0.05),
145+
(0.5, -0.05 if inverted else 1.05),
146+
xycoords=annotations[edge.child.format()],
147+
textcoords=annotations[edge.parent.format()],
148+
arrowprops={"arrowstyle": "->"},
149+
annotation_clip=True,
150+
)
151+
ax.annotate(edge.name, (0.5, 0.5), xytext=(0.5, 0.5), textcoords=arr)
152+
153+
origin_y += max(node.y for node in nodes) + 1
154+
xmax = max(xmax, max(node.x for node in nodes))
155+
156+
if adjust_axes:
157+
ax.set_ylim(origin_y, -1)
158+
ax.set_xlim(-1, xmax + 1)
159+
ax.spines[:].set_visible(False)
160+
ax.set_xticks([])
161+
ax.set_yticks([])

0 commit comments

Comments
 (0)