Skip to content

Commit d5f8dbe

Browse files
authored
Eliminate double inverts in pyrtl.optimize (#462)
1 parent 2f71715 commit d5f8dbe

File tree

2 files changed

+232
-0
lines changed

2 files changed

+232
-0
lines changed

pyrtl/passes.py

Lines changed: 151 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -52,11 +52,162 @@ def optimize(update_working_block=True, block=None, skip_sanity_check=False):
5252
constant_propagation(block, True)
5353
_remove_unlistened_nets(block)
5454
common_subexp_elimination(block)
55+
_optimize_inverter_chains(block, skip_sanity_check)
5556
if (not skip_sanity_check) or _get_debug_mode():
5657
block.sanity_check()
5758
return block
5859

5960

61+
def _get_inverter_chains(wire_creator, wire_users):
62+
"""Returns all inverter chains in the block.
63+
64+
The function returns a list of inverter chains in the block.
65+
Each inverter chain is represented as a list of the WireVectors
66+
in the chain.
67+
68+
Consider the following circuit, for example:
69+
A -~-> B -~-> C -w-> X
70+
D -~-> E -w-> Y
71+
If the function is called on this circuit, it will return
72+
[[A, B, C], [D, E]].
73+
"""
74+
75+
# Build a list of inverter chains. Each inverter chain is a list of WireVectors,
76+
# from source to destination.
77+
inverter_chains = []
78+
for current_dest, current_creator in wire_creator.items():
79+
if current_creator.op != "~":
80+
# Skip non-inverters.
81+
continue
82+
83+
# The current inverter connects current_arg (a WireVector) to current_dest (also
84+
# a WireVector).
85+
current_arg = current_creator.args[0]
86+
# current_users is the number of LogicNets that use current_dest.
87+
current_users = len(wire_users[current_dest])
88+
89+
# Add the current inverter to the end of this inverter chain.
90+
append_to = None
91+
# Add the current inverter to the beginning of this inverter chain.
92+
prepend_to = None
93+
next_inverter_chains = []
94+
for inverter_chain in inverter_chains:
95+
chain_arg = inverter_chain[0]
96+
chain_dest = inverter_chain[-1]
97+
chain_users = len(wire_users[chain_dest])
98+
99+
if chain_dest is current_arg and chain_users == 1:
100+
# This chain's only destination is the current inverter. Append the
101+
# current inverter to the chain.
102+
append_to = inverter_chain
103+
elif chain_arg is current_dest and current_users == 1:
104+
# This chain's only argument is the current inverter. Add the current
105+
# inverter to the beginning of the chain.
106+
prepend_to = inverter_chain
107+
else:
108+
# The current inverter is not connected to the inverter chain, so we
109+
# pass the inverter chain through to next_inverter_chains
110+
next_inverter_chains.append(inverter_chain)
111+
112+
if append_to and prepend_to:
113+
# The current inverter joins two existing inverter chains.
114+
next_inverter_chains.append(append_to + prepend_to)
115+
elif append_to:
116+
# Add the current inverter after 'append_to'.
117+
next_inverter_chains.append(append_to + [current_dest])
118+
elif prepend_to:
119+
# Add the current inverter before 'prepend_to'.
120+
next_inverter_chains.append([current_arg] + prepend_to)
121+
else:
122+
# The current inverter is not connected to any inverter chain, so
123+
# we start a new inverter chain with it
124+
next_inverter_chains.append([current_arg, current_dest])
125+
126+
inverter_chains = next_inverter_chains
127+
return inverter_chains
128+
129+
130+
def _optimize_inverter_chains(block, skip_sanity_check=False):
131+
""" Optimizes inverter chains in the block.
132+
133+
An inverter chain means two or more inverters directly connected
134+
to each other. Inverter chains are redundant and can be removed.
135+
For example, A -~-> B -~-> C -w-> X can be reduced to A -w-> X.
136+
137+
After optimization, a chain of an even number of inverters will
138+
be reduced a direct connection, and a chain of an odd number of
139+
inverters will be reduced to one inverter.
140+
141+
If an inverter chain has intermediate users it won't be removed.
142+
For example, the inverter chain in the following circuit won't be removed:
143+
A -~-> B -~-> C -w-> X
144+
B -w-> Y
145+
"""
146+
147+
# wire_creator maps from WireVector to the LogicNet that defines its value.
148+
# wire_users maps from WireVector to a list of LogicNets that use its value.
149+
wire_creator, wire_users = block.net_connections()
150+
151+
new_logic = set()
152+
net_removal_set = set()
153+
wire_removal_set = set()
154+
155+
# This ProducerList maps the end wire of an inverter chain to its beginning wire.
156+
# We need this because when removing an inverter chain its end wire gets removed,
157+
# so we need to replace the source of LogicNets using the end wire of the inverter
158+
# chain with the chain's beginning wire.
159+
#
160+
# We need a ProducerList, rather than a simple dict, because if an inverter chain
161+
# of more than two inverters has intermediate users, we may have to query the dict
162+
# multiple times to get the replacement for the inverter chain's last wire.
163+
# Consider the following circuit, for example:
164+
# A -~-> B -~-> C -w-> X
165+
# C -~-> D -~-> E -w-> Y
166+
# This is the optimized version of the circuit:
167+
# A -w-> X
168+
# A -w-> Y
169+
# The inverter chains found will be A-B-C and C-D-E (two separate chains will be
170+
# found instead of A-B-C-D-E because C has an intermediate user). In the dict,
171+
# C will be mapped to A and E will be mapped to C. Hence, when finding the
172+
# replacement of E, we have to first query the dict to get C, and then query
173+
# the dict again on C to get A.
174+
wire_src_dict = _ProducerList()
175+
176+
for inverter_chain in _get_inverter_chains(wire_creator, wire_users):
177+
# If len(inverter_chain) = n, there are n-1 inverters in the chain.
178+
# We only remove inverters if there are at least two inverters in a chain.
179+
if len(inverter_chain) > 2:
180+
if len(inverter_chain) % 2 == 1: # There is an even number of inverters in a chain.
181+
start_idx = 1
182+
else: # There is an odd number of inverters in a chain.
183+
start_idx = 2
184+
# Remove wires used in the inverter chain.
185+
wires_to_remove = inverter_chain[start_idx:]
186+
wire_removal_set.update(wires_to_remove)
187+
# Remove inverters used in the chain.
188+
inverters_to_remove = {wire_creator[wire] for wire in wires_to_remove}
189+
net_removal_set.update(inverters_to_remove)
190+
# Map the end wire of the inverter chain to the beginning wire.
191+
wire_src_dict[inverter_chain[-1]] = inverter_chain[start_idx - 1]
192+
193+
# This loop recreates the block with inverter chains removed. It adds each
194+
# LogicNet in the original block to the new block if it is not marked for
195+
# removal, and replaces the source of the LogicNet if its source was the end wire
196+
# of a removed inverter chain.
197+
for net in block.logic:
198+
if net not in net_removal_set:
199+
new_logic.add(LogicNet(net.op, net.op_param,
200+
args=tuple(wire_src_dict.find_producer(x) for x in net.args),
201+
dests=net.dests))
202+
203+
block.logic = new_logic
204+
for dead_wirevector in wire_removal_set:
205+
block.remove_wirevector(dead_wirevector)
206+
207+
if (not skip_sanity_check) or _get_debug_mode():
208+
block.sanity_check()
209+
210+
60211
class _ProducerList(object):
61212
""" Maps from wire to its immediate producer and finds ultimate producers. """
62213
def __init__(self):

tests/test_passes.py

Lines changed: 81 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -316,6 +316,87 @@ def test_slice_net_removal_4(self):
316316
self.num_net_of_type('s', 1, block)
317317
self.num_net_of_type('w', 2, block)
318318

319+
def test_remove_double_inverts_1_invert(self):
320+
inwire = pyrtl.Input(bitwidth=1)
321+
outwire = pyrtl.Output(bitwidth=1)
322+
outwire <<= ~inwire
323+
pyrtl.optimize()
324+
block = pyrtl.working_block()
325+
self.assert_num_net(2, block)
326+
self.assert_num_wires(3, block)
327+
328+
def test_remove_double_inverts_3_inverts(self):
329+
inwire = pyrtl.Input(bitwidth=1)
330+
outwire = pyrtl.Output(bitwidth=1)
331+
outwire <<= ~(~(~inwire))
332+
pyrtl.optimize()
333+
block = pyrtl.working_block()
334+
self.assert_num_net(2, block)
335+
self.assert_num_wires(3, block)
336+
337+
def test_remove_double_inverts_5_inverts(self):
338+
inwire = pyrtl.Input(bitwidth=1)
339+
outwire = pyrtl.Output(bitwidth=1)
340+
outwire <<= ~(~(~(~(~inwire))))
341+
pyrtl.optimize()
342+
block = pyrtl.working_block()
343+
self.assert_num_net(2, block)
344+
self.assert_num_wires(3, block)
345+
346+
def test_remove_double_inverts_2_inverts(self):
347+
inwire = pyrtl.Input(bitwidth=1)
348+
outwire = pyrtl.Output(bitwidth=1)
349+
outwire <<= ~(~inwire)
350+
pyrtl.optimize()
351+
block = pyrtl.working_block()
352+
self.assert_num_net(1, block)
353+
self.assert_num_wires(2, block)
354+
355+
def test_remove_double_inverts_4_inverts(self):
356+
inwire = pyrtl.Input(bitwidth=1)
357+
outwire = pyrtl.Output(bitwidth=1)
358+
outwire <<= ~(~(~(~inwire)))
359+
pyrtl.optimize()
360+
block = pyrtl.working_block()
361+
self.assert_num_net(1, block)
362+
self.assert_num_wires(2, block)
363+
364+
def test_remove_double_inverts_6_inverts(self):
365+
inwire = pyrtl.Input(bitwidth=1)
366+
outwire = pyrtl.Output(bitwidth=1)
367+
outwire <<= ~(~(~(~(~(~inwire)))))
368+
pyrtl.optimize()
369+
block = pyrtl.working_block()
370+
self.assert_num_net(1, block)
371+
self.assert_num_wires(2, block)
372+
373+
def test_dont_remove_double_inverts_another_user(self):
374+
inwire = pyrtl.Input(bitwidth=1)
375+
outwire = pyrtl.Output(bitwidth=1)
376+
outwire2 = pyrtl.Output(bitwidth=1)
377+
tempwire = pyrtl.WireVector()
378+
tempwire <<= ~inwire
379+
outwire <<= ~tempwire
380+
outwire2 <<= tempwire
381+
pyrtl.optimize()
382+
block = pyrtl.working_block()
383+
self.assert_num_net(4, block)
384+
self.assert_num_wires(5, block)
385+
386+
def test_multiple_double_invert_chains(self):
387+
# _remove_double_inverts removes double inverts by chains,
388+
# so it is useful to make sure it can remove
389+
# double inverts from multiple chains
390+
inwire = pyrtl.Input(bitwidth=1)
391+
outwire = pyrtl.Output(bitwidth=1)
392+
outwire2 = pyrtl.Output(bitwidth=1)
393+
outwire <<= ~(~inwire)
394+
outwire2 <<= ~(~(~(~(inwire))))
395+
pyrtl.optimize()
396+
block = pyrtl.working_block()
397+
self.assert_num_net(2, block)
398+
self.assert_num_wires(3, block)
399+
319400

320401
class TestConstFolding(NetWireNumTestCases):
321402
def setUp(self):

0 commit comments

Comments
 (0)