Skip to content

Commit

Permalink
Merge pull request #849 from phanrahan/mux-infer-intv
Browse files Browse the repository at this point in the history
Improve mux inference
  • Loading branch information
leonardt authored Sep 22, 2020
2 parents b519419 + 05745ab commit 6f8820b
Show file tree
Hide file tree
Showing 44 changed files with 1,087 additions and 1,103 deletions.
54 changes: 40 additions & 14 deletions magma/primitives/mux.py
Original file line number Diff line number Diff line change
@@ -1,16 +1,17 @@
import hwtypes as ht
from hwtypes import BitVector
from hwtypes import BitVector, UIntVector, SIntVector
from magma.array import Array
from magma.bit import Bit
from magma.bits import Bits
from magma.bits import Bits, UInt, SInt
from magma.bitutils import clog2, seq2int
from magma.circuit import coreir_port_mapping
from magma.generator import Generator2
from magma.interface import IO
from magma.protocol_type import MagmaProtocol, magma_type
from magma.t import Type, In, Out
from magma.tuple import Product
from magma.t import Type, In, Out, Direction
from magma.tuple import Product, Tuple
from magma.conversions import tuple_
from magma.wireable import wireable


class CoreIRCommonLibMuxN(Generator2):
Expand Down Expand Up @@ -84,18 +85,43 @@ def _infer_mux_type(args):
Note that we do not infer from standard python int arguments because we
cannot, in general, determine the correct bit width (use BitVector instead)
"""
T = None
for arg in args:
if isinstance(arg, (Type, MagmaProtocol)):
return type(arg), args
if isinstance(arg, BitVector):
return Bits[len(arg)], args
if isinstance(arg, (ht.Bit, bool)):
return Bit, args
if isinstance(arg, tuple):
return type(tuple_(arg)), [tuple_(a) for a in args]
raise TypeError(
f"Could not infer mux type from {args}\n"
"Need at least one magma value, BitVector, bool or tuple")
next_T = type(arg).qualify(Direction.Undirected)
elif isinstance(arg, UIntVector):
next_T = UInt[len(arg)]
elif isinstance(arg, SIntVector):
next_T = SInt[len(arg)]
elif isinstance(arg, BitVector):
next_T = Bits[len(arg)]
elif isinstance(arg, (ht.Bit, bool)):
next_T = Bit
elif isinstance(arg, tuple):
next_T = type(tuple_(arg))
elif isinstance(arg, int):
# Cannot infer type without width, use wiring implicit coercion to
# handle (or raise type error there)
continue

if T is not None:
if issubclass(T, next_T):
# upcast
T = next_T
elif not wireable(next_T, T):
raise TypeError(
f"Found incompatible types {next_T} and {T} in mux"
" inference"
)
else:
T = next_T
if T is None:
raise TypeError(
f"Could not infer mux type from {args}\n"
"Need at least one magma value, BitVector, bool or tuple")
if issubclass(T, Tuple):
args = [tuple_(a) for a in args]
return T, args


def mux(I: list, S, **kwargs):
Expand Down
15 changes: 2 additions & 13 deletions magma/primitives/register.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
AsyncReset, AsyncResetN, Clock, get_reset_args)
from magma.clock_io import ClockIO
from magma.primitives.mux import Mux
from magma.wireable import wireable


class _CoreIRRegister(Generator2):
Expand Down Expand Up @@ -109,18 +110,6 @@ def _get_T_from_init(init):
raise ValueError("Could not infer register type from {init}")


def _can_wire_types(T1, T2):
if issubclass(T1, Tuple):
if not issubclass(T2, Tuple):
return False
return all(_can_wire_types(t1, t2) for t1, t2 in zip(T1, T2))
if issubclass(T1, Array):
if not issubclass(T2, Array):
return False
return _can_wire_types(T1.T, T2.T)
return issubclass(T1, T2) or issubclass(T1, T2)


def _check_init_T(init, T):
init_T = _get_T_from_init(init)
if isinstance(init, int) and issubclass(T, Bits):
Expand All @@ -134,7 +123,7 @@ def _check_init_T(init, T):
if len(init_T) > 1:
return False
return True
return _can_wire_types(init_T, T)
return wireable(init_T, T)


class Register(Generator2):
Expand Down
19 changes: 19 additions & 0 deletions magma/wireable.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@
from magma.array import Array
from magma.tuple import Tuple
from magma.protocol_type import magma_type


def wireable(T1, T2):
"""
Returns true if T1 can be wired to T2
"""
T1, T2 = magma_type(T1), magma_type(T2)
if issubclass(T1, Tuple):
if not issubclass(T2, Tuple):
return False
return all(wireable(t1, t2) for t1, t2 in zip(T1, T2))
if issubclass(T1, Array):
if not issubclass(T2, Array):
return False
return wireable(T1.T, T2.T)
return issubclass(T1, T2) or issubclass(T1, T2)
2 changes: 1 addition & 1 deletion tests/test_errors/test_mux_errors.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,5 +22,5 @@ class Foo(m.Circuit):
with pytest.raises(TypeError) as e:
m.mux([io.I0, io.I1], io.S)
assert str(e.value) == f"""\
mux arg I[1] (I1: Out(Bits[3])) does not match inferred input port type Out(Bits[2])\
mux arg I[1] (I1: Out(Bits[3])) does not match inferred input port type Bits[2]\
"""
10 changes: 5 additions & 5 deletions tests/test_issues/gold/test_708_inline_False.v
Original file line number Diff line number Diff line change
Expand Up @@ -64,7 +64,7 @@ coreir_mux #(
assign out = _join_out;
endmodule

module Mux2xTuplex_OutUInt8 (
module Mux2xTuplex_UInt8 (
input [7:0] I0_x,
input [7:0] I1_x,
output [7:0] O_x,
Expand All @@ -86,13 +86,13 @@ module Test_comb (
input c,
input [7:0] self_a_O_x
);
wire [7:0] Mux2xTuplex_OutUInt8_inst0_O_x;
wire [7:0] Mux2xTuplex_UInt8_inst0_O_x;
wire [7:0] const_1_8_out;
wire [7:0] magma_Bits_8_add_inst0_out;
Mux2xTuplex_OutUInt8 Mux2xTuplex_OutUInt8_inst0 (
Mux2xTuplex_UInt8 Mux2xTuplex_UInt8_inst0 (
.I0_x(self_a_O_x),
.I1_x(magma_Bits_8_add_inst0_out),
.O_x(Mux2xTuplex_OutUInt8_inst0_O_x),
.O_x(Mux2xTuplex_UInt8_inst0_O_x),
.S(c)
);
coreir_const #(
Expand All @@ -109,7 +109,7 @@ coreir_add #(
.out(magma_Bits_8_add_inst0_out)
);
assign O0_x = self_a_O_x;
assign O1_a_x = Mux2xTuplex_OutUInt8_inst0_O_x;
assign O1_a_x = Mux2xTuplex_UInt8_inst0_O_x;
endmodule

module Test (
Expand Down
4 changes: 2 additions & 2 deletions tests/test_issues/gold/test_708_inline_True.v
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ module coreir_reg #(
assign out = outReg;
endmodule

module Mux2xTuplex_OutUInt8 (
module Mux2xTuplex_UInt8 (
input [7:0] I0_x,
input [7:0] I1_x,
output [7:0] O_x,
Expand All @@ -41,7 +41,7 @@ module Test_comb (
input [7:0] self_a_O_x
);
wire [7:0] magma_Bits_8_add_inst0_out;
Mux2xTuplex_OutUInt8 Mux2xTuplex_OutUInt8_inst0 (
Mux2xTuplex_UInt8 Mux2xTuplex_UInt8_inst0 (
.I0_x(self_a_O_x),
.I1_x(magma_Bits_8_add_inst0_out),
.O_x(O1_a_x),
Expand Down
10 changes: 5 additions & 5 deletions tests/test_operators/gold/TestSlice.v
Original file line number Diff line number Diff line change
Expand Up @@ -90,7 +90,7 @@ coreir_slice #(
assign out = _join_out;
endmodule

module Mux4xOutBits6 (
module Mux4xBits6 (
input [5:0] I0,
input [5:0] I1,
input [5:0] I2,
Expand All @@ -117,15 +117,15 @@ module TestSlice (
input [1:0] x,
output [5:0] O
);
wire [5:0] Mux4xOutBits6_inst0_O;
Mux4xOutBits6 Mux4xOutBits6_inst0 (
wire [5:0] Mux4xBits6_inst0_O;
Mux4xBits6 Mux4xBits6_inst0 (
.I0(I[5:0]),
.I1(I[6:1]),
.I2(I[7:2]),
.I3(I[8:3]),
.S(x),
.O(Mux4xOutBits6_inst0_O)
.O(Mux4xBits6_inst0_O)
);
assign O = Mux4xOutBits6_inst0_O;
assign O = Mux4xBits6_inst0_O;
endmodule

10 changes: 5 additions & 5 deletions tests/test_primitives/gold/test_mux_operator.v
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ coreir_mux #(
assign out = _join_out;
endmodule

module Mux2xOutBit (
module Mux2xBit (
input I0,
input I1,
input S,
Expand All @@ -49,13 +49,13 @@ module test_mux_operator (
input S,
output O
);
wire Mux2xOutBit_inst0_O;
Mux2xOutBit Mux2xOutBit_inst0 (
wire Mux2xBit_inst0_O;
Mux2xBit Mux2xBit_inst0 (
.I0(I[0]),
.I1(I[1]),
.S(S),
.O(Mux2xOutBit_inst0_O)
.O(Mux2xBit_inst0_O)
);
assign O = Mux2xOutBit_inst0_O;
assign O = Mux2xBit_inst0_O;
endmodule

10 changes: 5 additions & 5 deletions tests/test_primitives/gold/test_mux_operator_int.v
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@ coreir_mux #(
assign out = _join_out;
endmodule

module Mux2xOutBit (
module Mux2xBit (
input I0,
input I1,
input S,
Expand All @@ -57,19 +57,19 @@ module test_mux_operator_int (
input S,
output O
);
wire Mux2xOutBit_inst0_O;
wire Mux2xBit_inst0_O;
wire bit_const_0_None_out;
Mux2xOutBit Mux2xOutBit_inst0 (
Mux2xBit Mux2xBit_inst0 (
.I0(bit_const_0_None_out),
.I1(I),
.S(S),
.O(Mux2xOutBit_inst0_O)
.O(Mux2xBit_inst0_O)
);
corebit_const #(
.value(1'b0)
) bit_const_0_None (
.out(bit_const_0_None_out)
);
assign O = Mux2xOutBit_inst0_O;
assign O = Mux2xBit_inst0_O;
endmodule

16 changes: 16 additions & 0 deletions tests/test_primitives/test_mux.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import os
import pytest
from hwtypes import BitVector
import hwtypes as ht

Expand Down Expand Up @@ -293,3 +294,18 @@ class test_mux_array_select_bits_1(m.Circuit):
tester.compile_and_run("verilator", skip_compile=True,
directory=os.path.join(os.path.dirname(__file__),
"build"))


@pytest.mark.parametrize("ht_T, m_T", [(ht.UIntVector, m.UInt),
(ht.SIntVector, m.SInt)])
def test_mux_intv(ht_T, m_T):
class Main(m.Circuit):
O = m.mux([ht_T[4](1), m_T[4](2)], m.Bit())
assert isinstance(O, m_T)


@pytest.mark.parametrize("ht_T", [ht.UIntVector, ht.SIntVector])
def test_mux_intv_bits(ht_T):
class Main(m.Circuit):
O = m.mux([ht_T[4](1), m.Bits[4](2)], m.Bit())
assert type(O) is m.Out(m.Bits[4])
14 changes: 7 additions & 7 deletions tests/test_syntax/gold/RdPtr.json
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
"namespaces":{
"global":{
"modules":{
"Mux2xOutUInt10":{
"Mux2xUInt10":{
"type":["Record",[
["I0",["Array",10,"BitIn"]],
["I1",["Array",10,"BitIn"]],
Expand Down Expand Up @@ -56,8 +56,8 @@
["O1",["Array",10,"Bit"]]
]],
"instances":{
"Mux2xOutUInt10_inst0":{
"modref":"global.Mux2xOutUInt10"
"Mux2xUInt10_inst0":{
"modref":"global.Mux2xUInt10"
},
"const_1_10":{
"genref":"coreir.const",
Expand All @@ -70,10 +70,10 @@
}
},
"connections":[
["self.self_rd_ptr_O","Mux2xOutUInt10_inst0.I0"],
["magma_Bits_10_add_inst0.out","Mux2xOutUInt10_inst0.I1"],
["self.O0","Mux2xOutUInt10_inst0.O"],
["self.read","Mux2xOutUInt10_inst0.S"],
["self.self_rd_ptr_O","Mux2xUInt10_inst0.I0"],
["magma_Bits_10_add_inst0.out","Mux2xUInt10_inst0.I1"],
["self.O0","Mux2xUInt10_inst0.O"],
["self.read","Mux2xUInt10_inst0.S"],
["magma_Bits_10_add_inst0.in1","const_1_10.out"],
["self.self_rd_ptr_O","magma_Bits_10_add_inst0.in0"],
["self.self_rd_ptr_O","self.O1"]
Expand Down
10 changes: 5 additions & 5 deletions tests/test_syntax/gold/RdPtr.v
Original file line number Diff line number Diff line change
Expand Up @@ -68,7 +68,7 @@ coreir_mux #(
assign out = _join_out;
endmodule

module Mux2xOutUInt10 (
module Mux2xUInt10 (
input [9:0] I0,
input [9:0] I1,
input S,
Expand All @@ -92,14 +92,14 @@ module RdPtr_comb (
output [9:0] O0,
output [9:0] O1
);
wire [9:0] Mux2xOutUInt10_inst0_O;
wire [9:0] Mux2xUInt10_inst0_O;
wire [9:0] const_1_10_out;
wire [9:0] magma_Bits_10_add_inst0_out;
Mux2xOutUInt10 Mux2xOutUInt10_inst0 (
Mux2xUInt10 Mux2xUInt10_inst0 (
.I0(self_rd_ptr_O),
.I1(magma_Bits_10_add_inst0_out),
.S(read),
.O(Mux2xOutUInt10_inst0_O)
.O(Mux2xUInt10_inst0_O)
);
coreir_const #(
.value(10'h001),
Expand All @@ -114,7 +114,7 @@ coreir_add #(
.in1(const_1_10_out),
.out(magma_Bits_10_add_inst0_out)
);
assign O0 = Mux2xOutUInt10_inst0_O;
assign O0 = Mux2xUInt10_inst0_O;
assign O1 = self_rd_ptr_O;
endmodule

Expand Down
Loading

0 comments on commit 6f8820b

Please sign in to comment.