Skip to content

Commit df5b4bb

Browse files
leonardtrsetaluri
andauthored
[MLIR] Add support for sync_read_mem (#1203)
Co-authored-by: rsetaluri <[email protected]>
1 parent e093868 commit df5b4bb

File tree

6 files changed

+158
-16
lines changed

6 files changed

+158
-16
lines changed

magma/backend/mlir/hardware_module.py

Lines changed: 39 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,12 @@
2626
visit_value_or_value_wrapper_by_direction as
2727
visit_magma_value_or_value_wrapper_by_direction,
2828
)
29+
from magma.backend.mlir.mem_utils import (
30+
make_mem_reg,
31+
make_mem_read,
32+
emit_conditional_assign,
33+
make_index_op,
34+
)
2935
from magma.backend.mlir.mlir import (
3036
MlirType, MlirValue, MlirSymbol, MlirAttribute, MlirBlock, push_block,
3137
push_location
@@ -366,29 +372,35 @@ def make_concat(self, operands, result):
366372
def visit_coreir_mem(self, module: ModuleWrapper) -> bool:
367373
inst = module.module
368374
defn = type(inst)
369-
assert defn.coreir_name == "mem"
375+
assert (
376+
defn.coreir_name == "mem"
377+
or defn.coreir_name == "sync_read_mem"
378+
)
370379
# TODO(rsetaluri): Add support for initialization.
371380
if defn.coreir_genargs["has_init"]:
372381
raise NotImplementedError("coreir.mem init not supported")
373382
width = defn.coreir_genargs["width"]
374383
depth = defn.coreir_genargs["depth"]
375-
raddr, waddr, wdata, clk, wen = module.operands
384+
is_sync_read_mem = (defn.coreir_name == "sync_read_mem")
385+
raddr, waddr, wdata, clk, wen = module.operands[:5]
376386
rdata = module.results[0]
377-
elt_type = hw.InOutType(builtin.IntegerType(width))
378-
reg_type = hw.InOutType(hw.ArrayType((depth,), elt_type.T))
379-
reg = self.ctx.new_value(reg_type)
380-
sv.RegOp(name=inst.name, results=[reg])
387+
mem = make_mem_reg(
388+
self._ctx, inst.name, depth, builtin.IntegerType(width)
389+
)
381390
# Register read logic.
382-
read = self.ctx.new_value(elt_type)
383-
sv.ArrayIndexInOutOp(operands=[reg, raddr], results=[read])
384-
sv.ReadInOutOp(operands=[read], results=[rdata])
391+
read = make_index_op(self._ctx, mem, raddr)
392+
read_reg, read_temp = make_mem_read(
393+
self._ctx, read, rdata, is_sync_read_mem
394+
)
385395
# Register write logic.
386-
write = self.ctx.new_value(elt_type)
387-
sv.ArrayIndexInOutOp(operands=[reg, waddr], results=[write])
396+
write = make_index_op(self._ctx, mem, waddr)
397+
# Always logic.
388398
always = sv.AlwaysFFOp(operands=[clk], clock_edge="posedge").body_block
389399
with push_block(always):
390-
with push_block(sv.IfOp(operands=[wen]).then_block):
391-
sv.PAssignOp(operands=[write, wdata])
400+
emit_conditional_assign(write, wdata, wen)
401+
if is_sync_read_mem:
402+
ren = module.operands[-1]
403+
emit_conditional_assign(read_reg, read_temp, ren)
392404
return True
393405

394406
@wrap_with_not_implemented_error
@@ -498,8 +510,15 @@ def _visit(value, counter):
498510
def visit_coreir_primitive(self, module: ModuleWrapper) -> bool:
499511
inst = module.module
500512
defn = type(inst)
501-
assert (defn.coreir_lib == "coreir" or defn.coreir_lib == "corebit")
502-
if defn.coreir_name == "mem":
513+
assert (
514+
defn.coreir_lib == "coreir"
515+
or defn.coreir_lib == "corebit"
516+
or defn.coreir_lib == "memory"
517+
)
518+
if (
519+
defn.coreir_name == "mem"
520+
or defn.coreir_name == "sync_read_mem"
521+
):
503522
return self.visit_coreir_mem(module)
504523
if defn.coreir_name == "not":
505524
return self.visit_coreir_not(module)
@@ -660,7 +679,11 @@ def visit_primitive(self, module: ModuleWrapper) -> bool:
660679
inst = module.module
661680
defn = type(inst)
662681
assert isprimitive(defn)
663-
if defn.coreir_lib == "coreir" or defn.coreir_lib == "corebit":
682+
if (
683+
defn.coreir_lib == "coreir"
684+
or defn.coreir_lib == "corebit"
685+
or defn.coreir_lib == "memory"
686+
):
664687
return self.visit_coreir_primitive(module)
665688
if defn.coreir_lib == "commonlib":
666689
return self.visit_commonlib_primitive(module)

magma/backend/mlir/mem_utils.py

Lines changed: 33 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,33 @@
1+
from magma.backend.mlir.hw import hw
2+
from magma.backend.mlir.mlir import push_block
3+
from magma.backend.mlir.sv import sv
4+
5+
6+
def make_mem_reg(ctx, name, N, T):
7+
mem_type = hw.InOutType(hw.ArrayType((N,), T))
8+
mem = ctx.new_value(mem_type)
9+
sv.RegOp(name=name, results=[mem])
10+
return mem
11+
12+
13+
def make_mem_read(ctx, target, value, has_enable, name="read_reg"):
14+
if not has_enable:
15+
sv.ReadInOutOp(operands=[target], results=[value])
16+
return None, None
17+
reg_out = ctx.new_value(target.type.T)
18+
sv.ReadInOutOp(operands=[target], results=[reg_out])
19+
reg = ctx.new_value(target.type)
20+
sv.RegOp(name=name, results=[reg])
21+
sv.ReadInOutOp(operands=[reg], results=[value])
22+
return reg, reg_out
23+
24+
25+
def emit_conditional_assign(target, value, en):
26+
with push_block(sv.IfOp(operands=[en]).then_block):
27+
sv.PAssignOp(operands=[target, value])
28+
29+
30+
def make_index_op(ctx, value, idx):
31+
result = ctx.new_value(hw.InOutType(value.type.T.T))
32+
sv.ArrayIndexInOutOp(operands=[value, idx], results=[result])
33+
return result

tests/test_backend/test_mlir/examples.py

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -511,6 +511,28 @@ class simple_memory_wrapper(m.Circuit):
511511
)
512512

513513

514+
class sync_memory_wrapper(m.Circuit):
515+
T = m.Bits[12]
516+
height = 128
517+
addr_type = m.Bits[m.bitutils.clog2(height)]
518+
io = m.IO(
519+
RADDR=m.In(addr_type),
520+
RDATA=m.Out(T),
521+
CLK=m.In(m.Clock),
522+
WADDR=m.In(addr_type),
523+
WDATA=m.In(T),
524+
WE=m.In(m.Enable),
525+
)
526+
mem = m.Memory(height=height, T=T, read_latency=1)()
527+
io.RDATA @= mem(
528+
RADDR=io.RADDR,
529+
RDATA=io.RDATA,
530+
WADDR=io.WADDR,
531+
WDATA=io.WDATA,
532+
WE=io.WE,
533+
)
534+
535+
514536
m.passes.clock.WireClockPass(simple_memory_wrapper).run()
515537

516538

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,24 @@
1+
module attributes {circt.loweringOptions = "locationInfoStyle=none"} {
2+
hw.module @Memory(%RADDR: i7, %CLK: i1, %WADDR: i7, %WDATA: i12, %WE: i1) -> (RDATA: i12) {
3+
%0 = hw.constant 1 : i1
4+
%2 = sv.reg name "coreir_mem128x12_inst0" : !hw.inout<!hw.array<128xi12>>
5+
%3 = sv.array_index_inout %2[%RADDR] : !hw.inout<!hw.array<128xi12>>, i7
6+
%4 = sv.read_inout %3 : !hw.inout<i12>
7+
%5 = sv.reg name "read_reg" : !hw.inout<i12>
8+
%1 = sv.read_inout %5 : !hw.inout<i12>
9+
%6 = sv.array_index_inout %2[%WADDR] : !hw.inout<!hw.array<128xi12>>, i7
10+
sv.alwaysff(posedge %CLK) {
11+
sv.if %WE {
12+
sv.passign %6, %WDATA : i12
13+
}
14+
sv.if %0 {
15+
sv.passign %5, %4 : i12
16+
}
17+
}
18+
hw.output %1 : i12
19+
}
20+
hw.module @sync_memory_wrapper(%RADDR: i7, %CLK: i1, %WADDR: i7, %WDATA: i12, %WE: i1) -> (RDATA: i12) {
21+
%0 = hw.instance "Memory_inst0" @Memory(RADDR: %RADDR: i7, CLK: %CLK: i1, WADDR: %WADDR: i7, WDATA: %WDATA: i12, WE: %WE: i1) -> (RDATA: i12)
22+
hw.output %0 : i12
23+
}
24+
}
Lines changed: 39 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,39 @@
1+
// Generated by CIRCT circtorg-0.0.0-1773-g7abbc4313
2+
module Memory(
3+
input [6:0] RADDR,
4+
input CLK,
5+
input [6:0] WADDR,
6+
input [11:0] WDATA,
7+
input WE,
8+
output [11:0] RDATA
9+
);
10+
11+
reg [127:0][11:0] coreir_mem128x12_inst0;
12+
reg [11:0] read_reg;
13+
always_ff @(posedge CLK) begin
14+
if (WE)
15+
coreir_mem128x12_inst0[WADDR] <= WDATA;
16+
read_reg <= coreir_mem128x12_inst0[RADDR];
17+
end // always_ff @(posedge)
18+
assign RDATA = read_reg;
19+
endmodule
20+
21+
module sync_memory_wrapper(
22+
input [6:0] RADDR,
23+
input CLK,
24+
input [6:0] WADDR,
25+
input [11:0] WDATA,
26+
input WE,
27+
output [11:0] RDATA
28+
);
29+
30+
Memory Memory_inst0 (
31+
.RADDR (RADDR),
32+
.CLK (CLK),
33+
.WADDR (WADDR),
34+
.WDATA (WDATA),
35+
.WE (WE),
36+
.RDATA (RDATA)
37+
);
38+
endmodule
39+

tests/test_backend/test_mlir/test_utils.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -189,6 +189,7 @@ def get_local_examples() -> List[DefineCircuitKind]:
189189
examples.simple_undriven,
190190
examples.complex_undriven,
191191
examples.simple_memory_wrapper,
192+
examples.sync_memory_wrapper,
192193
examples.simple_undriven_instances,
193194
examples.simple_neg,
194195
examples.simple_array_slice,

0 commit comments

Comments
 (0)