Skip to content

Commit a92e231

Browse files
committed
move ext scratchpad sram instantiation outside
1 parent 5a10e96 commit a92e231

File tree

1 file changed

+35
-300
lines changed

1 file changed

+35
-300
lines changed

src/main/scala/gemmini/Controller.scala

Lines changed: 35 additions & 300 deletions
Original file line numberDiff line numberDiff line change
@@ -35,30 +35,9 @@ class Gemmini[T <: Data : Arithmetic, U <: Data, V <: Data](val config: GemminiA
3535
val xLen = p(XLen)
3636
val spad = LazyModule(new Scratchpad(config))
3737

38-
val create_tl_mem = config.use_shared_ext_mem && config.use_tl_ext_mem
39-
38+
val use_ext_tl_mem = config.use_shared_ext_mem && config.use_tl_ext_mem
4039
val num_ids = 32 // TODO (richard): move to config
4140
val spad_base = config.tl_ext_mem_base
42-
43-
val unified_mem_read_node = TLIdentityNode()
44-
val spad_read_nodes = if (create_tl_mem) TLClientNode(Seq.tabulate(config.sp_banks) {i =>
45-
TLMasterPortParameters.v1(Seq(TLMasterParameters.v1(name = s"spad_read_node_$i", sourceId = IdRange(0, num_ids))))
46-
}) else TLIdentityNode()
47-
// val acc_read_nodes = if (create_tl_mem) TLClientNode(Seq.tabulate(config.acc_banks) { i =>
48-
// TLMasterPortParameters.v1(Seq(TLMasterParameters.v1(name = s"acc_read_node_$i", sourceId = IdRange(0, numIDs))))
49-
// }) else TLIdentityNode()
50-
51-
val unified_mem_write_node = TLIdentityNode()
52-
val spad_write_nodes = if (create_tl_mem) TLClientNode(Seq.tabulate(config.sp_banks) { i =>
53-
TLMasterPortParameters.v1(Seq(TLMasterParameters.v1(name = s"spad_write_node_$i", sourceId = IdRange(0, num_ids))))
54-
}) else TLIdentityNode()
55-
56-
// val spad_dma_write_node = TLClientNode(Seq(
57-
// TLMasterPortParameters.v1(Seq(TLMasterParameters.v1(name = s"spad_dma_write_node", sourceId = IdRange(0, num_ids))))))
58-
// val acc_write_nodes = if (create_tl_mem) TLClientNode(Seq.tabulate(config.acc_banks) { i =>
59-
// TLMasterPortParameters.v1(Seq(TLMasterParameters.v1(name = s"acc_write_node_$i", sourceId = IdRange(0, numIDs))))
60-
// }) else TLIdentityNode()
61-
6241
val spad_data_len = config.sp_width / 8
6342
val acc_data_len = config.sp_width / config.inputType.getWidth * config.accType.getWidth / 8
6443
val max_data_len = spad_data_len // max acc_data_len
@@ -68,127 +47,41 @@ class Gemmini[T <: Data : Arithmetic, U <: Data, V <: Data](val config: GemminiA
6847
require(mem_depth * mem_width * config.sp_banks == 1 << 14, f"memory size is ${mem_depth}, ${mem_width}")
6948
println(f"unified shared memory size: ${mem_depth}x${mem_width}x${config.sp_banks}")
7049

71-
// this node accepts both read and write requests,
72-
// splits & arbitrates them into one client node per type of operation
73-
val unified_mem_node = TLNexusNode(
74-
clientFn = { seq =>
75-
val in_mapping = TLXbar.mapInputIds(seq)
76-
val read_src_range = IdRange(in_mapping.map(_.start).min, in_mapping.map(_.end).max)
77-
assert((read_src_range.start == 0) && isPow2(read_src_range.end))
78-
val write_src_range = read_src_range.shift(read_src_range.size)
79-
80-
seq(0).v1copy(
81-
echoFields = BundleField.union(seq.flatMap(_.echoFields)),
82-
requestFields = BundleField.union(seq.flatMap(_.requestFields)),
83-
responseKeys = seq.flatMap(_.responseKeys).distinct,
84-
minLatency = seq.map(_.minLatency).min,
85-
clients = Seq(
86-
TLMasterParameters.v1(
87-
name = "unified_mem_read_client",
88-
sourceId = read_src_range,
89-
supportsProbe = TransferSizes.mincover(seq.map(_.anyEmitClaims.get)),
90-
supportsGet = TransferSizes.mincover(seq.map(_.anyEmitClaims.get)),
91-
supportsPutFull = TransferSizes.none,
92-
supportsPutPartial = TransferSizes.none
93-
),
94-
TLMasterParameters.v1(
95-
name = "unified_mem_write_client",
96-
sourceId = write_src_range,
97-
supportsProbe = TransferSizes.mincover(
98-
seq.map(_.anyEmitClaims.putFull) ++seq.map(_.anyEmitClaims.putPartial)),
99-
supportsGet = TransferSizes.none,
100-
supportsPutFull = TransferSizes.mincover(seq.map(_.anyEmitClaims.putFull)),
101-
supportsPutPartial = TransferSizes.mincover(seq.map(_.anyEmitClaims.putPartial))
102-
)
103-
)
104-
)
105-
},
106-
managerFn = { seq =>
107-
// val fifoIdFactory = TLXbar.relabeler()
108-
seq(0).v1copy(
109-
responseFields = BundleField.union(seq.flatMap(_.responseFields)),
110-
requestKeys = seq.flatMap(_.requestKeys).distinct,
111-
minLatency = seq.map(_.minLatency).min,
112-
endSinkId = TLXbar.mapOutputIds(seq).map(_.end).max,
113-
managers = Seq(TLSlaveParameters.v2(
114-
name = Some(f"unified_mem_manager"),
115-
address = Seq(AddressSet(spad_base, mem_depth * mem_width * config.sp_banks - 1)),
116-
supports = TLMasterToSlaveTransferSizes(
117-
get = TransferSizes(1, mem_width),
118-
putFull = TransferSizes(1, mem_width),
119-
putPartial = TransferSizes(1, mem_width)),
120-
fifoId = Some(0)
121-
))
122-
)
123-
}
124-
)
125-
126-
unified_mem_read_node := TLWidthWidget(spad_data_len) := unified_mem_node
127-
unified_mem_write_node := TLWidthWidget(spad_data_len) := unified_mem_node
128-
129-
val spad_tl_ram : Seq[Seq[TLManagerNode]] = if (config.use_shared_ext_mem && config.use_tl_ext_mem) {
130-
unified_mem_read_node :=* TLWidthWidget(spad_data_len) :=* spad_read_nodes
131-
// unified_mem_read_node :=* TLWidthWidget(acc_data_len) :=* acc_read_nodes
132-
unified_mem_write_node :=* TLWidthWidget(spad_data_len) :=* spad_write_nodes
133-
// unified_mem_write_node :=* TLWidthWidget(acc_data_len) :=* acc_write_nodes
134-
135-
val stride_by_word = false // TODO (richard): move to config
136-
137-
require(isPow2(config.sp_banks))
138-
val banks : Seq[Seq[TLManagerNode]] =
139-
if (stride_by_word) {
140-
assert(false, "TODO under construction")
141-
assert((config.sp_capacity match { case CapacityInKilobytes(kb) => kb * 1024}) ==
142-
config.sp_bank_entries * spad_data_len / max_data_len * config.sp_banks * max_data_len)
143-
(0 until config.sp_banks).map { bank =>
144-
LazyModule(new TLRAM(
145-
address = AddressSet(max_data_len * bank,
146-
((config.sp_bank_entries * spad_data_len / max_data_len - 1) * config.sp_banks + bank)
147-
* max_data_len + (max_data_len - 1)),
148-
beatBytes = max_data_len
149-
))
150-
}.map(x => Seq(x.node))
151-
} else {
152-
(0 until config.sp_banks).map { bank =>
153-
Seq(TLManagerNode(Seq(TLSlavePortParameters.v1(
154-
managers = Seq(TLSlaveParameters.v2(
155-
name = Some(f"sp_bank${bank}_read_mgr"),
156-
address = Seq(AddressSet(spad_base + (mem_depth * mem_width * bank),
157-
mem_depth * mem_width - 1)),
158-
supports = TLMasterToSlaveTransferSizes(
159-
get = TransferSizes(1, mem_width)),
160-
fifoId = Some(0)
161-
)),
162-
beatBytes = mem_width
163-
))),
164-
TLManagerNode(Seq(TLSlavePortParameters.v1(
165-
managers = Seq(TLSlaveParameters.v2(
166-
name = Some(f"sp_bank${bank}_write_mgr"),
167-
address = Seq(AddressSet(spad_base + (mem_depth * mem_width * bank),
168-
mem_depth * mem_width - 1)),
169-
supports = TLMasterToSlaveTransferSizes(
170-
putFull = TransferSizes(1, mem_width),
171-
putPartial = TransferSizes(1, mem_width)),
172-
fifoId = Some(0)
173-
)),
174-
beatBytes = mem_width
175-
))))
176-
}
177-
}
50+
// make scratchpad read and write clients, per bank
51+
// _____ ________ _______ ___ ___
52+
// / __/ |/_/_ __/ / __/ _ \/ _ | / _ \
53+
// / _/_> < / / _\ \/ ___/ __ |/ // /
54+
// /___/_/|_| /_/ /___/_/ /_/ |_/____/
55+
// ***************************************
56+
// HOW TO USE EXTERNAL SCRATCHPAD:
57+
// the scratchpad MUST BE INSTANTIATED ELSEWHERE if use_ext_tl_mem is enabled,
58+
// otherwise elaboration will fail. the scratchpad must be dual ported
59+
// and must be able to serve an entire scratchpad row (config.sp_width bits) in a single cycle.
60+
// three nodes must be hooked up correctly: spad_read_nodes, spad_write_nodes, and spad.spad_writer.node
61+
// to avoid deadlock, reads and writes must not share a single channel anywhere on the path to the SRAMs.
62+
// see RadianceCluster.scala for an example
63+
val spad_read_nodes = if (use_ext_tl_mem) TLClientNode(Seq.tabulate(config.sp_banks) {i =>
64+
TLMasterPortParameters.v1(Seq(TLMasterParameters.v1(
65+
name = s"spad_read_node_$i",
66+
sourceId = IdRange(0, num_ids),
67+
visibility = Seq(AddressSet(spad_base + i * mem_width * mem_depth, mem_width * mem_depth - 1))
68+
)))
69+
}) else TLIdentityNode()
17870

179-
require(!config.sp_singleported, "external scratchpad must be dual ported")
180-
val r_xbar = TLXbar()
181-
val w_xbar = TLXbar()
182-
r_xbar :=* unified_mem_read_node
183-
w_xbar :=* unified_mem_write_node
184-
banks.foreach { mem =>
185-
require(mem.length == 2)
186-
mem.head := r_xbar
187-
mem.last := TLFragmenter(spad_data_len, spad.maxBytes) := w_xbar
188-
}
71+
val spad_write_nodes = if (use_ext_tl_mem) TLClientNode(Seq.tabulate(config.sp_banks) { i =>
72+
TLMasterPortParameters.v1(Seq(TLMasterParameters.v1(
73+
name = s"spad_write_node_$i",
74+
sourceId = IdRange(0, num_ids),
75+
visibility = Seq(AddressSet(spad_base + i * mem_width * mem_depth, mem_width * mem_depth - 1))
76+
)))
77+
}) else TLIdentityNode()
18978

190-
banks
191-
} else Seq()
79+
// val acc_read_nodes = if (use_ext_tl_mem) TLClientNode(Seq.tabulate(config.acc_banks) { i =>
80+
// TLMasterPortParameters.v1(Seq(TLMasterParameters.v1(name = s"acc_read_node_$i", sourceId = IdRange(0, num_ids))))
81+
// }) else TLIdentityNode()
82+
// val acc_write_nodes = if (use_ext_tl_mem) TLClientNode(Seq.tabulate(config.acc_banks) { i =>
83+
// TLMasterPortParameters.v1(Seq(TLMasterParameters.v1(name = s"acc_write_node_$i", sourceId = IdRange(0, num_ids))))
84+
// }) else TLIdentityNode()
19285

19386
override lazy val module = new GemminiModule(this)
19487
override val tlNode = if (config.use_dedicated_tl_port) spad.id_node else TLIdentityNode()
@@ -204,9 +97,6 @@ class Gemmini[T <: Data : Arithmetic, U <: Data, V <: Data](val config: GemminiA
20497
concurrency = 1)
20598

20699
regNode := TLFragmenter(8, 64) := stlNode
207-
208-
unified_mem_write_node := spad.spad_writer.node
209-
210100
}
211101

212102
class GemminiModule[T <: Data: Arithmetic, U <: Data, V <: Data]
@@ -227,7 +117,7 @@ class GemminiModule[T <: Data: Arithmetic, U <: Data, V <: Data]
227117
// connecting to unified TL interface
228118
val source_counters = Seq.fill(4)(Counter(outer.num_ids))
229119

230-
if (outer.create_tl_mem) {
120+
if (outer.use_ext_tl_mem) {
231121
def connect(ext_mem: ExtMemIO, bank_base: Int, req_size: Int, r_node: TLBundle, r_edge: TLEdgeOut, r_source: Counter,
232122
w_node: TLBundle, w_edge: TLEdgeOut, w_source: Counter): Unit = {
233123
r_node.a.valid := ext_mem.read_req.valid
@@ -260,83 +150,6 @@ class GemminiModule[T <: Data: Arithmetic, U <: Data, V <: Data]
260150
r_node, r_edge, source_counters(0), w_node, w_edge, source_counters(1))
261151
}
262152

263-
outer.spad_tl_ram.foreach { case Seq(r, w) =>
264-
val mem_depth = outer.config.sp_bank_entries * outer.spad_data_len / outer.max_data_len
265-
val mem_width = outer.max_data_len
266-
267-
val mem = TwoPortSyncMem(
268-
n = mem_depth,
269-
t = UInt((mem_width * 8).W),
270-
mask_len = mem_width // byte level mask
271-
)
272-
273-
val (r_node, r_edge) = r.in.head
274-
val (w_node, w_edge) = w.in.head
275-
276-
// READ
277-
mem.io.ren := r_node.a.fire
278-
mem.io.raddr := (r_node.a.bits.address ^ outer.spad_base.U) >> log2Ceil(mem_width).U
279-
280-
val data_pipe_in = Wire(DecoupledIO(mem.io.rdata.cloneType))
281-
data_pipe_in.valid := RegNext(mem.io.ren)
282-
data_pipe_in.bits := mem.io.rdata
283-
284-
val metadata_pipe_in = Wire(DecoupledIO(new Bundle {
285-
val source = r_node.a.bits.source.cloneType
286-
val size = r_node.a.bits.size.cloneType
287-
}))
288-
metadata_pipe_in.valid := mem.io.ren
289-
metadata_pipe_in.bits.source := r_node.a.bits.source
290-
metadata_pipe_in.bits.size := r_node.a.bits.size
291-
292-
val sram_read_backup_reg = RegInit(0.U.asTypeOf(Valid(mem.io.rdata.cloneType)))
293-
294-
val data_pipe_inst = Module(new Pipeline(data_pipe_in.bits.cloneType, 1)())
295-
data_pipe_inst.io.in <> data_pipe_in
296-
val data_pipe = data_pipe_inst.io.out
297-
val metadata_pipe = Pipeline(metadata_pipe_in, 2)
298-
assert((data_pipe.valid || sram_read_backup_reg.valid) === metadata_pipe.valid)
299-
300-
// data pipe is filled, but D is not ready and SRAM read came back
301-
when (data_pipe.valid && !r_node.d.ready && data_pipe_in.valid) {
302-
assert(!data_pipe_in.ready) // we should fill backup reg only if data pipe is not enqueueing
303-
assert(!sram_read_backup_reg.valid) // backup reg should be empty
304-
assert(!metadata_pipe_in.ready) // metadata should be filled previous cycle
305-
sram_read_backup_reg.valid := true.B
306-
sram_read_backup_reg.bits := mem.io.rdata
307-
}.otherwise {
308-
assert(data_pipe_in.ready || !data_pipe_in.valid) // do not skip any response
309-
}
310-
311-
assert(metadata_pipe_in.fire || !mem.io.ren) // when requesting sram, metadata needs to be ready
312-
assert(r_node.d.fire === metadata_pipe.fire) // metadata dequeues iff D fires
313-
314-
// when D becomes ready, and data pipe has emptied, time for backup to empty
315-
when (r_node.d.ready && sram_read_backup_reg.valid && !data_pipe.valid) {
316-
sram_read_backup_reg.valid := false.B
317-
}
318-
assert(!(sram_read_backup_reg.valid && data_pipe.valid && data_pipe_in.fire)) // must empty backup before filling data pipe
319-
assert(data_pipe_in.valid === data_pipe_in.fire)
320-
321-
r_node.d.bits := r_edge.AccessAck(
322-
metadata_pipe.bits.source,
323-
metadata_pipe.bits.size,
324-
Mux(!data_pipe.valid, sram_read_backup_reg.bits, data_pipe.bits))
325-
r_node.d.valid := data_pipe.valid || sram_read_backup_reg.valid
326-
// r node A is not ready only if D is not ready and both slots filled
327-
r_node.a.ready := r_node.d.ready && !(data_pipe.valid && sram_read_backup_reg.valid)
328-
data_pipe.ready := r_node.d.ready
329-
metadata_pipe.ready := r_node.d.ready
330-
331-
// WRITE
332-
mem.io.wen := w_node.a.fire
333-
mem.io.waddr := (w_node.a.bits.address ^ outer.spad_base.U) >> log2Ceil(mem_width).U
334-
mem.io.wdata := w_node.a.bits.data
335-
mem.io.mask := w_node.a.bits.mask.asBools
336-
w_node.a.ready := w_node.d.ready// && (mem.io.waddr =/= mem.io.raddr)
337-
w_node.d.valid := w_node.a.valid
338-
w_node.d.bits := w_edge.AccessAck(w_node.a.bits)
339-
}
340153

341154
ext_mem_acc.foreach(_.foreach(x => {
342155
x.read_resp.bits := 0.U(1.W)
@@ -350,84 +163,6 @@ class GemminiModule[T <: Data: Arithmetic, U <: Data, V <: Data]
350163
// connect(ext_mem_acc(i)(0), log2Up(outer.acc_data_len),
351164
// r_node, r_edge, source_counters(2), w_node, w_edge, source_counters(3))
352165
// }
353-
354-
// hook up read/write for general unified mem nodes
355-
{
356-
val u_out = outer.unified_mem_node.out
357-
val u_in = outer.unified_mem_node.in
358-
assert(u_out.length == 2)
359-
println(f"gemmini unified memory node has ${u_in.length} incoming client(s)")
360-
361-
val r_out = u_out.head
362-
val w_out = u_out.last
363-
364-
val in_src = TLXbar.mapInputIds(u_in.map(_._2.client))
365-
val in_src_size = in_src.map(_.end).max
366-
assert(isPow2(in_src_size)) // should be checked already, but just to be sure
367-
368-
// arbitrate all reads into one read while assigning source prefix, same for write
369-
val a_arbiter_in = (u_in zip in_src).map { case ((in_node, _), src_range) =>
370-
val in_r: DecoupledIO[TLBundleA] =
371-
WireDefault(0.U.asTypeOf(Decoupled(new TLBundleA(in_node.a.bits.params.copy(
372-
sourceBits = log2Up(in_src_size) + 1
373-
)))))
374-
val in_w: DecoupledIO[TLBundleA] = WireDefault(0.U.asTypeOf(in_r.cloneType))
375-
376-
val req_is_read = in_node.a.bits.opcode === TLMessages.Get
377-
378-
(Seq(in_r.bits.user, in_r.bits.address, in_r.bits.opcode, in_r.bits.size,
379-
in_r.bits.mask, in_r.bits.param, in_r.bits.data)
380-
zip Seq(in_node.a.bits.user, in_node.a.bits.address, in_node.a.bits.opcode, in_node.a.bits.size,
381-
in_node.a.bits.mask, in_node.a.bits.param, in_node.a.bits.data))
382-
.foreach { case (x, y) => x := y }
383-
in_r.bits.source := in_node.a.bits.source | src_range.start.U | Mux(req_is_read, 0.U, in_src_size.U)
384-
in_w.bits := in_r.bits
385-
386-
in_r.valid := in_node.a.valid && req_is_read
387-
in_w.valid := in_node.a.valid && !req_is_read
388-
in_node.a.ready := Mux(req_is_read, in_r.ready, in_w.ready)
389-
390-
(in_r, in_w)
391-
}
392-
// we cannot use round robin because it might reorder requests, even from the same client
393-
val (a_arbiter_in_r_nodes, a_arbiter_in_w_nodes) = a_arbiter_in.unzip
394-
TLArbiter.lowest(r_out._2, r_out._1.a, a_arbiter_in_r_nodes:_*)
395-
TLArbiter.lowest(w_out._2, w_out._1.a, a_arbiter_in_w_nodes:_*)
396-
397-
def trim(id: UInt, size: Int): UInt = if (size <= 1) 0.U else id(log2Ceil(size)-1, 0) // from Xbar
398-
// for each unified mem node client, arbitrate read/write responses on d channel
399-
(u_in zip in_src).zipWithIndex.foreach { case (((in_node, in_edge), src_range), i) =>
400-
// assign d channel back based on source, invalid if source prefix mismatch
401-
val resp = Seq(r_out._1.d, w_out._1.d)
402-
val source_match = resp.zipWithIndex.map { case (r, i) =>
403-
(r.bits.source(r.bits.source.getWidth - 1) === i.U(1.W)) && // MSB indicates read(0)/write(1)
404-
src_range.contains(trim(r.bits.source, in_src_size))
405-
}
406-
val d_arbiter_in = resp.map(r => WireDefault(
407-
0.U.asTypeOf(Decoupled(new TLBundleD(r.bits.params.copy(
408-
sourceBits = in_node.d.bits.source.getWidth,
409-
sizeBits = in_node.d.bits.size.getWidth
410-
))))
411-
))
412-
413-
(d_arbiter_in lazyZip resp lazyZip source_match).foreach { case (arb_in, r, sm) =>
414-
(Seq(arb_in.bits.user, arb_in.bits.opcode, arb_in.bits.data, arb_in.bits.param,
415-
arb_in.bits.sink, arb_in.bits.denied, arb_in.bits.corrupt)
416-
zip Seq(r.bits.user, r.bits.opcode, r.bits.data, r.bits.param,
417-
r.bits.sink, r.bits.denied, r.bits.corrupt))
418-
.foreach { case (x, y) => x := y }
419-
arb_in.bits.source := trim(r.bits.source, 1 << in_node.d.bits.source.getWidth) // we can trim b/c isPow2(prefix)
420-
arb_in.bits.size := trim(r.bits.size, 1 << in_node.d.bits.size.getWidth) // FIXME: check truncation
421-
422-
arb_in.valid := r.valid && sm
423-
r.ready := arb_in.ready
424-
}
425-
426-
TLArbiter.robin(in_edge, in_node.d, d_arbiter_in:_*)
427-
}
428-
429-
}
430-
431166
} else if (use_shared_ext_mem) {
432167
ext_mem_io.foreach(_ <> outer.spad.module.io.ext_mem.get)
433168
}

0 commit comments

Comments
 (0)