Skip to content

Commit 7b8d86a

Browse files
authored
Add helper function for binary to 1-hot (#466)
* Add helper function for binary to 1-hot
1 parent e25b20c commit 7b8d86a

File tree

4 files changed

+85
-1
lines changed

4 files changed

+85
-1
lines changed

docs/helpers.rst

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -111,4 +111,5 @@ Encoders and Decoders
111111
---------------------
112112

113113
.. autofunction:: pyrtl.helperfuncs.one_hot_to_binary
114+
.. autofunction:: pyrtl.helperfuncs.binary_to_one_hot
114115

pyrtl/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -41,6 +41,7 @@
4141
from .helperfuncs import wire_struct
4242
from .helperfuncs import wire_matrix
4343
from .helperfuncs import one_hot_to_binary
44+
from .helperfuncs import binary_to_one_hot
4445

4546
from .corecircuits import and_all_bits
4647
from .corecircuits import or_all_bits

pyrtl/helperfuncs.py

Lines changed: 41 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,15 @@
1313
from .core import working_block, _NameIndexer, _get_debug_mode, Block
1414
from .pyrtlexceptions import PyrtlError, PyrtlInternalError
1515
from .wire import WireVector, Input, Output, Const, Register, WrappedWireVector
16-
from .corecircuits import as_wires, rtl_all, rtl_any, concat, concat_list, select
16+
from .corecircuits import (
17+
as_wires,
18+
rtl_all,
19+
rtl_any,
20+
concat,
21+
concat_list,
22+
select,
23+
shift_left_logical
24+
)
1725

1826
# -----------------------------------------------------------------
1927
# ___ __ ___ __ __
@@ -1715,3 +1723,35 @@ def one_hot_to_binary(w) -> WireVector:
17151723
already_found = already_found | w[i]
17161724

17171725
return pos
1726+
1727+
1728+
def binary_to_one_hot(bit_position, max_bitwidth: int = None) -> WireVector:
1729+
'''Takes an input representing a bit position and returns a WireVector
1730+
with that bit position set to 1 and the others to 0.
1731+
1732+
:param bit_position: WireVector, WireVector-like object, or something that can be converted
1733+
into a :py:class:`.Const` (in accordance with the :py:func:`.as_wires()`
1734+
required input). Example inputs: ``0b10``, ``0b1000``, ``4``.
1735+
:param max_bitwidth: Optional integer maximum bitwidth for the resulting one-hot WireVector.
1736+
:return: WireVector with the bit position given by the input set to 1 and all other bits
1737+
set to 0 (bit position 0 being the least significant bit).
1738+
1739+
If the max_bitwidth provided is not sufficient for the given bit_position to be set to 1,
1740+
a ``0`` WireVector of size max_bitwidth will be returned.
1741+
1742+
Examples::
1743+
1744+
binary_to_onehot(0) # returns 0b01
1745+
binary_to_onehot(3) # returns 0b1000
1746+
binary_to_onehot(0b100) # returns 0b10000
1747+
'''
1748+
1749+
bit_position = as_wires(bit_position)
1750+
1751+
if max_bitwidth is not None:
1752+
bitwidth = max_bitwidth
1753+
else:
1754+
bitwidth = 2 ** len(bit_position)
1755+
1756+
# Need to dynamically set the appropriate bit position since bit_position may not be a Const
1757+
return shift_left_logical(Const(1, bitwidth=bitwidth), bit_position)

tests/test_helperfuncs.py

Lines changed: 42 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1814,5 +1814,47 @@ def test_no_ones(self):
18141814
self.assertEqual(sim.inspect('o'), 0)
18151815

18161816

1817+
class TestBinaryToOneHot(unittest.TestCase):
1818+
def setUp(self):
1819+
pyrtl.reset_working_block()
1820+
1821+
def test_simple_binary_to_one_hot(self):
1822+
bit_position = pyrtl.Input(bitwidth=8, name='bit_position')
1823+
one_hot = pyrtl.Output(name='one_hot')
1824+
one_hot <<= pyrtl.binary_to_one_hot(bit_position)
1825+
1826+
self.assertEqual(one_hot.bitwidth, 256)
1827+
1828+
sim = pyrtl.Simulation()
1829+
sim.step({bit_position: 0})
1830+
self.assertEqual(sim.inspect('one_hot'), 0b01)
1831+
sim.step({bit_position: 2})
1832+
self.assertEqual(sim.inspect('one_hot'), 0b0100)
1833+
sim.step({bit_position: 5})
1834+
self.assertEqual(sim.inspect('one_hot'), 0b00100000)
1835+
sim.step({bit_position: 12})
1836+
self.assertEqual(sim.inspect('one_hot'), 0b0001000000000000)
1837+
sim.step({bit_position: 15})
1838+
self.assertEqual(sim.inspect('one_hot'), 0b1000000000000000)
1839+
1840+
# Tests with the max_bitwidth set
1841+
def test_with_max_bitwidth(self):
1842+
bit_position = pyrtl.Input(bitwidth=8, name='bit_position')
1843+
one_hot = pyrtl.Output(name='one_hot')
1844+
one_hot <<= pyrtl.binary_to_one_hot(bit_position, max_bitwidth=4)
1845+
1846+
self.assertEqual(one_hot.bitwidth, 4)
1847+
1848+
sim = pyrtl.Simulation()
1849+
sim.step({bit_position: 0})
1850+
self.assertEqual(sim.inspect('one_hot'), 0b0001)
1851+
sim.step({bit_position: 3})
1852+
self.assertEqual(sim.inspect('one_hot'), 0b1000)
1853+
1854+
# The max_bitwidth set is not enough for a bit position of 4
1855+
sim.step({bit_position: 4})
1856+
self.assertEqual(sim.inspect('one_hot'), 0b0000)
1857+
1858+
18171859
if __name__ == "__main__":
18181860
unittest.main()

0 commit comments

Comments
 (0)