Skip to content

Commit

Permalink
[MLIR] Add support for sync_read_mem (#1203)
Browse files Browse the repository at this point in the history
Co-authored-by: rsetaluri <[email protected]>
  • Loading branch information
leonardt and rsetaluri authored Mar 31, 2023
1 parent e093868 commit df5b4bb
Show file tree
Hide file tree
Showing 6 changed files with 158 additions and 16 deletions.
55 changes: 39 additions & 16 deletions magma/backend/mlir/hardware_module.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,12 @@
visit_value_or_value_wrapper_by_direction as
visit_magma_value_or_value_wrapper_by_direction,
)
from magma.backend.mlir.mem_utils import (
make_mem_reg,
make_mem_read,
emit_conditional_assign,
make_index_op,
)
from magma.backend.mlir.mlir import (
MlirType, MlirValue, MlirSymbol, MlirAttribute, MlirBlock, push_block,
push_location
Expand Down Expand Up @@ -366,29 +372,35 @@ def make_concat(self, operands, result):
def visit_coreir_mem(self, module: ModuleWrapper) -> bool:
inst = module.module
defn = type(inst)
assert defn.coreir_name == "mem"
assert (
defn.coreir_name == "mem"
or defn.coreir_name == "sync_read_mem"
)
# TODO(rsetaluri): Add support for initialization.
if defn.coreir_genargs["has_init"]:
raise NotImplementedError("coreir.mem init not supported")
width = defn.coreir_genargs["width"]
depth = defn.coreir_genargs["depth"]
raddr, waddr, wdata, clk, wen = module.operands
is_sync_read_mem = (defn.coreir_name == "sync_read_mem")
raddr, waddr, wdata, clk, wen = module.operands[:5]
rdata = module.results[0]
elt_type = hw.InOutType(builtin.IntegerType(width))
reg_type = hw.InOutType(hw.ArrayType((depth,), elt_type.T))
reg = self.ctx.new_value(reg_type)
sv.RegOp(name=inst.name, results=[reg])
mem = make_mem_reg(
self._ctx, inst.name, depth, builtin.IntegerType(width)
)
# Register read logic.
read = self.ctx.new_value(elt_type)
sv.ArrayIndexInOutOp(operands=[reg, raddr], results=[read])
sv.ReadInOutOp(operands=[read], results=[rdata])
read = make_index_op(self._ctx, mem, raddr)
read_reg, read_temp = make_mem_read(
self._ctx, read, rdata, is_sync_read_mem
)
# Register write logic.
write = self.ctx.new_value(elt_type)
sv.ArrayIndexInOutOp(operands=[reg, waddr], results=[write])
write = make_index_op(self._ctx, mem, waddr)
# Always logic.
always = sv.AlwaysFFOp(operands=[clk], clock_edge="posedge").body_block
with push_block(always):
with push_block(sv.IfOp(operands=[wen]).then_block):
sv.PAssignOp(operands=[write, wdata])
emit_conditional_assign(write, wdata, wen)
if is_sync_read_mem:
ren = module.operands[-1]
emit_conditional_assign(read_reg, read_temp, ren)
return True

@wrap_with_not_implemented_error
Expand Down Expand Up @@ -498,8 +510,15 @@ def _visit(value, counter):
def visit_coreir_primitive(self, module: ModuleWrapper) -> bool:
inst = module.module
defn = type(inst)
assert (defn.coreir_lib == "coreir" or defn.coreir_lib == "corebit")
if defn.coreir_name == "mem":
assert (
defn.coreir_lib == "coreir"
or defn.coreir_lib == "corebit"
or defn.coreir_lib == "memory"
)
if (
defn.coreir_name == "mem"
or defn.coreir_name == "sync_read_mem"
):
return self.visit_coreir_mem(module)
if defn.coreir_name == "not":
return self.visit_coreir_not(module)
Expand Down Expand Up @@ -660,7 +679,11 @@ def visit_primitive(self, module: ModuleWrapper) -> bool:
inst = module.module
defn = type(inst)
assert isprimitive(defn)
if defn.coreir_lib == "coreir" or defn.coreir_lib == "corebit":
if (
defn.coreir_lib == "coreir"
or defn.coreir_lib == "corebit"
or defn.coreir_lib == "memory"
):
return self.visit_coreir_primitive(module)
if defn.coreir_lib == "commonlib":
return self.visit_commonlib_primitive(module)
Expand Down
33 changes: 33 additions & 0 deletions magma/backend/mlir/mem_utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,33 @@
from magma.backend.mlir.hw import hw
from magma.backend.mlir.mlir import push_block
from magma.backend.mlir.sv import sv


def make_mem_reg(ctx, name, N, T):
mem_type = hw.InOutType(hw.ArrayType((N,), T))
mem = ctx.new_value(mem_type)
sv.RegOp(name=name, results=[mem])
return mem


def make_mem_read(ctx, target, value, has_enable, name="read_reg"):
if not has_enable:
sv.ReadInOutOp(operands=[target], results=[value])
return None, None
reg_out = ctx.new_value(target.type.T)
sv.ReadInOutOp(operands=[target], results=[reg_out])
reg = ctx.new_value(target.type)
sv.RegOp(name=name, results=[reg])
sv.ReadInOutOp(operands=[reg], results=[value])
return reg, reg_out


def emit_conditional_assign(target, value, en):
with push_block(sv.IfOp(operands=[en]).then_block):
sv.PAssignOp(operands=[target, value])


def make_index_op(ctx, value, idx):
result = ctx.new_value(hw.InOutType(value.type.T.T))
sv.ArrayIndexInOutOp(operands=[value, idx], results=[result])
return result
22 changes: 22 additions & 0 deletions tests/test_backend/test_mlir/examples.py
Original file line number Diff line number Diff line change
Expand Up @@ -511,6 +511,28 @@ class simple_memory_wrapper(m.Circuit):
)


class sync_memory_wrapper(m.Circuit):
T = m.Bits[12]
height = 128
addr_type = m.Bits[m.bitutils.clog2(height)]
io = m.IO(
RADDR=m.In(addr_type),
RDATA=m.Out(T),
CLK=m.In(m.Clock),
WADDR=m.In(addr_type),
WDATA=m.In(T),
WE=m.In(m.Enable),
)
mem = m.Memory(height=height, T=T, read_latency=1)()
io.RDATA @= mem(
RADDR=io.RADDR,
RDATA=io.RDATA,
WADDR=io.WADDR,
WDATA=io.WDATA,
WE=io.WE,
)


m.passes.clock.WireClockPass(simple_memory_wrapper).run()


Expand Down
24 changes: 24 additions & 0 deletions tests/test_backend/test_mlir/golds/sync_memory_wrapper.mlir
Original file line number Diff line number Diff line change
@@ -0,0 +1,24 @@
module attributes {circt.loweringOptions = "locationInfoStyle=none"} {
hw.module @Memory(%RADDR: i7, %CLK: i1, %WADDR: i7, %WDATA: i12, %WE: i1) -> (RDATA: i12) {
%0 = hw.constant 1 : i1
%2 = sv.reg name "coreir_mem128x12_inst0" : !hw.inout<!hw.array<128xi12>>
%3 = sv.array_index_inout %2[%RADDR] : !hw.inout<!hw.array<128xi12>>, i7
%4 = sv.read_inout %3 : !hw.inout<i12>
%5 = sv.reg name "read_reg" : !hw.inout<i12>
%1 = sv.read_inout %5 : !hw.inout<i12>
%6 = sv.array_index_inout %2[%WADDR] : !hw.inout<!hw.array<128xi12>>, i7
sv.alwaysff(posedge %CLK) {
sv.if %WE {
sv.passign %6, %WDATA : i12
}
sv.if %0 {
sv.passign %5, %4 : i12
}
}
hw.output %1 : i12
}
hw.module @sync_memory_wrapper(%RADDR: i7, %CLK: i1, %WADDR: i7, %WDATA: i12, %WE: i1) -> (RDATA: i12) {
%0 = hw.instance "Memory_inst0" @Memory(RADDR: %RADDR: i7, CLK: %CLK: i1, WADDR: %WADDR: i7, WDATA: %WDATA: i12, WE: %WE: i1) -> (RDATA: i12)
hw.output %0 : i12
}
}
39 changes: 39 additions & 0 deletions tests/test_backend/test_mlir/golds/sync_memory_wrapper.v
Original file line number Diff line number Diff line change
@@ -0,0 +1,39 @@
// Generated by CIRCT circtorg-0.0.0-1773-g7abbc4313
module Memory(
input [6:0] RADDR,
input CLK,
input [6:0] WADDR,
input [11:0] WDATA,
input WE,
output [11:0] RDATA
);

reg [127:0][11:0] coreir_mem128x12_inst0;
reg [11:0] read_reg;
always_ff @(posedge CLK) begin
if (WE)
coreir_mem128x12_inst0[WADDR] <= WDATA;
read_reg <= coreir_mem128x12_inst0[RADDR];
end // always_ff @(posedge)
assign RDATA = read_reg;
endmodule

module sync_memory_wrapper(
input [6:0] RADDR,
input CLK,
input [6:0] WADDR,
input [11:0] WDATA,
input WE,
output [11:0] RDATA
);

Memory Memory_inst0 (
.RADDR (RADDR),
.CLK (CLK),
.WADDR (WADDR),
.WDATA (WDATA),
.WE (WE),
.RDATA (RDATA)
);
endmodule

1 change: 1 addition & 0 deletions tests/test_backend/test_mlir/test_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -189,6 +189,7 @@ def get_local_examples() -> List[DefineCircuitKind]:
examples.simple_undriven,
examples.complex_undriven,
examples.simple_memory_wrapper,
examples.sync_memory_wrapper,
examples.simple_undriven_instances,
examples.simple_neg,
examples.simple_array_slice,
Expand Down

0 comments on commit df5b4bb

Please sign in to comment.