move ext scratchpad sram instantiation outside
richardyrh committed Mar 26, 2024
1 parent 5a10e96 commit a92e231
Showing 1 changed file with 35 additions and 300 deletions.
335 changes: 35 additions & 300 deletions src/main/scala/gemmini/Controller.scala
@@ -35,30 +35,9 @@ class Gemmini[T <: Data : Arithmetic, U <: Data, V <: Data](val config: GemminiA
val xLen = p(XLen)
val spad = LazyModule(new Scratchpad(config))

val create_tl_mem = config.use_shared_ext_mem && config.use_tl_ext_mem

val use_ext_tl_mem = config.use_shared_ext_mem && config.use_tl_ext_mem
val num_ids = 32 // TODO (richard): move to config
val spad_base = config.tl_ext_mem_base

val unified_mem_read_node = TLIdentityNode()
val spad_read_nodes = if (create_tl_mem) TLClientNode(Seq.tabulate(config.sp_banks) {i =>
TLMasterPortParameters.v1(Seq(TLMasterParameters.v1(name = s"spad_read_node_$i", sourceId = IdRange(0, num_ids))))
}) else TLIdentityNode()
// val acc_read_nodes = if (create_tl_mem) TLClientNode(Seq.tabulate(config.acc_banks) { i =>
// TLMasterPortParameters.v1(Seq(TLMasterParameters.v1(name = s"acc_read_node_$i", sourceId = IdRange(0, numIDs))))
// }) else TLIdentityNode()

val unified_mem_write_node = TLIdentityNode()
val spad_write_nodes = if (create_tl_mem) TLClientNode(Seq.tabulate(config.sp_banks) { i =>
TLMasterPortParameters.v1(Seq(TLMasterParameters.v1(name = s"spad_write_node_$i", sourceId = IdRange(0, num_ids))))
}) else TLIdentityNode()

// val spad_dma_write_node = TLClientNode(Seq(
// TLMasterPortParameters.v1(Seq(TLMasterParameters.v1(name = s"spad_dma_write_node", sourceId = IdRange(0, num_ids))))))
// val acc_write_nodes = if (create_tl_mem) TLClientNode(Seq.tabulate(config.acc_banks) { i =>
// TLMasterPortParameters.v1(Seq(TLMasterParameters.v1(name = s"acc_write_node_$i", sourceId = IdRange(0, numIDs))))
// }) else TLIdentityNode()

val spad_data_len = config.sp_width / 8
val acc_data_len = config.sp_width / config.inputType.getWidth * config.accType.getWidth / 8
val max_data_len = spad_data_len // max acc_data_len
@@ -68,127 +47,41 @@ class Gemmini[T <: Data : Arithmetic, U <: Data, V <: Data](val config: GemminiA
require(mem_depth * mem_width * config.sp_banks == 1 << 14, f"memory size is ${mem_depth}, ${mem_width}")
println(f"unified shared memory size: ${mem_depth}x${mem_width}x${config.sp_banks}")

// this node accepts both read and write requests,
// splits & arbitrates them into one client node per type of operation
val unified_mem_node = TLNexusNode(
clientFn = { seq =>
val in_mapping = TLXbar.mapInputIds(seq)
val read_src_range = IdRange(in_mapping.map(_.start).min, in_mapping.map(_.end).max)
assert((read_src_range.start == 0) && isPow2(read_src_range.end))
val write_src_range = read_src_range.shift(read_src_range.size)

seq(0).v1copy(
echoFields = BundleField.union(seq.flatMap(_.echoFields)),
requestFields = BundleField.union(seq.flatMap(_.requestFields)),
responseKeys = seq.flatMap(_.responseKeys).distinct,
minLatency = seq.map(_.minLatency).min,
clients = Seq(
TLMasterParameters.v1(
name = "unified_mem_read_client",
sourceId = read_src_range,
supportsProbe = TransferSizes.mincover(seq.map(_.anyEmitClaims.get)),
supportsGet = TransferSizes.mincover(seq.map(_.anyEmitClaims.get)),
supportsPutFull = TransferSizes.none,
supportsPutPartial = TransferSizes.none
),
TLMasterParameters.v1(
name = "unified_mem_write_client",
sourceId = write_src_range,
supportsProbe = TransferSizes.mincover(
seq.map(_.anyEmitClaims.putFull) ++ seq.map(_.anyEmitClaims.putPartial)),
supportsGet = TransferSizes.none,
supportsPutFull = TransferSizes.mincover(seq.map(_.anyEmitClaims.putFull)),
supportsPutPartial = TransferSizes.mincover(seq.map(_.anyEmitClaims.putPartial))
)
)
)
},
managerFn = { seq =>
// val fifoIdFactory = TLXbar.relabeler()
seq(0).v1copy(
responseFields = BundleField.union(seq.flatMap(_.responseFields)),
requestKeys = seq.flatMap(_.requestKeys).distinct,
minLatency = seq.map(_.minLatency).min,
endSinkId = TLXbar.mapOutputIds(seq).map(_.end).max,
managers = Seq(TLSlaveParameters.v2(
name = Some(f"unified_mem_manager"),
address = Seq(AddressSet(spad_base, mem_depth * mem_width * config.sp_banks - 1)),
supports = TLMasterToSlaveTransferSizes(
get = TransferSizes(1, mem_width),
putFull = TransferSizes(1, mem_width),
putPartial = TransferSizes(1, mem_width)),
fifoId = Some(0)
))
)
}
)

unified_mem_read_node := TLWidthWidget(spad_data_len) := unified_mem_node
unified_mem_write_node := TLWidthWidget(spad_data_len) := unified_mem_node

val spad_tl_ram : Seq[Seq[TLManagerNode]] = if (config.use_shared_ext_mem && config.use_tl_ext_mem) {
unified_mem_read_node :=* TLWidthWidget(spad_data_len) :=* spad_read_nodes
// unified_mem_read_node :=* TLWidthWidget(acc_data_len) :=* acc_read_nodes
unified_mem_write_node :=* TLWidthWidget(spad_data_len) :=* spad_write_nodes
// unified_mem_write_node :=* TLWidthWidget(acc_data_len) :=* acc_write_nodes

val stride_by_word = false // TODO (richard): move to config

require(isPow2(config.sp_banks))
val banks : Seq[Seq[TLManagerNode]] =
if (stride_by_word) {
assert(false, "TODO under construction")
assert((config.sp_capacity match { case CapacityInKilobytes(kb) => kb * 1024}) ==
config.sp_bank_entries * spad_data_len / max_data_len * config.sp_banks * max_data_len)
(0 until config.sp_banks).map { bank =>
LazyModule(new TLRAM(
address = AddressSet(max_data_len * bank,
((config.sp_bank_entries * spad_data_len / max_data_len - 1) * config.sp_banks + bank)
* max_data_len + (max_data_len - 1)),
beatBytes = max_data_len
))
}.map(x => Seq(x.node))
} else {
(0 until config.sp_banks).map { bank =>
Seq(TLManagerNode(Seq(TLSlavePortParameters.v1(
managers = Seq(TLSlaveParameters.v2(
name = Some(f"sp_bank${bank}_read_mgr"),
address = Seq(AddressSet(spad_base + (mem_depth * mem_width * bank),
mem_depth * mem_width - 1)),
supports = TLMasterToSlaveTransferSizes(
get = TransferSizes(1, mem_width)),
fifoId = Some(0)
)),
beatBytes = mem_width
))),
TLManagerNode(Seq(TLSlavePortParameters.v1(
managers = Seq(TLSlaveParameters.v2(
name = Some(f"sp_bank${bank}_write_mgr"),
address = Seq(AddressSet(spad_base + (mem_depth * mem_width * bank),
mem_depth * mem_width - 1)),
supports = TLMasterToSlaveTransferSizes(
putFull = TransferSizes(1, mem_width),
putPartial = TransferSizes(1, mem_width)),
fifoId = Some(0)
)),
beatBytes = mem_width
))))
}
}
// make scratchpad read and write clients, per bank
// _____ ________ _______ ___ ___
// / __/ |/_/_ __/ / __/ _ \/ _ | / _ \
// / _/_> < / / _\ \/ ___/ __ |/ // /
// /___/_/|_| /_/ /___/_/ /_/ |_/____/
// ***************************************
// HOW TO USE THE EXTERNAL SCRATCHPAD:
// The scratchpad MUST BE INSTANTIATED ELSEWHERE when use_ext_tl_mem is enabled,
// or elaboration will fail. The scratchpad must be dual-ported and able to
// serve an entire scratchpad row (config.sp_width bits) in one cycle.
// Three nodes must be hooked up correctly: spad_read_nodes, spad_write_nodes,
// and spad.spad_writer.node. To avoid deadlock, reads and writes must not
// share a single channel anywhere before reaching the SRAMs.
// See RadianceCluster.scala for an example.
val spad_read_nodes = if (use_ext_tl_mem) TLClientNode(Seq.tabulate(config.sp_banks) {i =>
TLMasterPortParameters.v1(Seq(TLMasterParameters.v1(
name = s"spad_read_node_$i",
sourceId = IdRange(0, num_ids),
visibility = Seq(AddressSet(spad_base + i * mem_width * mem_depth, mem_width * mem_depth - 1))
)))
}) else TLIdentityNode()

require(!config.sp_singleported, "external scratchpad must be dual ported")
val r_xbar = TLXbar()
val w_xbar = TLXbar()
r_xbar :=* unified_mem_read_node
w_xbar :=* unified_mem_write_node
banks.foreach { mem =>
require(mem.length == 2)
mem.head := r_xbar
mem.last := TLFragmenter(spad_data_len, spad.maxBytes) := w_xbar
}
val spad_write_nodes = if (use_ext_tl_mem) TLClientNode(Seq.tabulate(config.sp_banks) { i =>
TLMasterPortParameters.v1(Seq(TLMasterParameters.v1(
name = s"spad_write_node_$i",
sourceId = IdRange(0, num_ids),
visibility = Seq(AddressSet(spad_base + i * mem_width * mem_depth, mem_width * mem_depth - 1))
)))
}) else TLIdentityNode()

banks
} else Seq()
// val acc_read_nodes = if (create_tl_mem) TLClientNode(Seq.tabulate(config.acc_banks) { i =>
// TLMasterPortParameters.v1(Seq(TLMasterParameters.v1(name = s"acc_read_node_$i", sourceId = IdRange(0, numIDs))))
// }) else TLIdentityNode()
// val acc_write_nodes = if (create_tl_mem) TLClientNode(Seq.tabulate(config.acc_banks) { i =>
// TLMasterPortParameters.v1(Seq(TLMasterParameters.v1(name = s"acc_write_node_$i", sourceId = IdRange(0, numIDs))))
// }) else TLIdentityNode()
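This commit moves the SRAM instantiation out of Gemmini, so the enclosing design must now provide the banks itself, as the HOW TO comment above describes. Below is a minimal sketch of that external hookup, loosely modeled on the per-bank read/write manager pattern deleted above; the surrounding config handle, node names, and address math are illustrative assumptions, not the actual RadianceCluster.scala code.

import freechips.rocketchip.diplomacy._
import freechips.rocketchip.tilelink._

// Hypothetical enclosing design (illustrative sketch, not RadianceCluster.scala).
// One read manager and one write manager per bank, so reads and writes never
// share a channel before reaching the dual-ported SRAMs.
val gemmini   = LazyModule(new Gemmini(config))
val beatBytes = config.sp_width / 8
val bankBytes = config.sp_bank_entries * beatBytes

val extSpadBanks = Seq.tabulate(config.sp_banks) { i =>
  val base = config.tl_ext_mem_base + i * bankBytes
  val rMgr = TLManagerNode(Seq(TLSlavePortParameters.v1(
    managers = Seq(TLSlaveParameters.v2(
      name     = Some(s"ext_sp_bank${i}_read_mgr"),
      address  = Seq(AddressSet(base, bankBytes - 1)),
      supports = TLMasterToSlaveTransferSizes(get = TransferSizes(1, beatBytes)),
      fifoId   = Some(0))),
    beatBytes = beatBytes)))
  val wMgr = TLManagerNode(Seq(TLSlavePortParameters.v1(
    managers = Seq(TLSlaveParameters.v2(
      name     = Some(s"ext_sp_bank${i}_write_mgr"),
      address  = Seq(AddressSet(base, bankBytes - 1)),
      supports = TLMasterToSlaveTransferSizes(
        putFull    = TransferSizes(1, beatBytes),
        putPartial = TransferSizes(1, beatBytes)),
      fifoId   = Some(0))),
    beatBytes = beatBytes)))
  rMgr := gemmini.spad_read_nodes    // binds this bank's read channel
  wMgr := gemmini.spad_write_nodes   // binds this bank's write channel
  (rMgr, wMgr)
}
// gemmini.spad.spad_writer.node (the DMA write path) must also reach the same
// SRAMs, e.g. arbitrated into each write manager through a TLXbar.

In the module implementation, each (rMgr, wMgr) pair would then drive the two ports of one SRAM, much like the TwoPortSyncMem wiring this commit removes from GemminiModule further down.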

override lazy val module = new GemminiModule(this)
override val tlNode = if (config.use_dedicated_tl_port) spad.id_node else TLIdentityNode()
@@ -204,9 +97,6 @@ class Gemmini[T <: Data : Arithmetic, U <: Data, V <: Data](val config: GemminiA
concurrency = 1)

regNode := TLFragmenter(8, 64) := stlNode

unified_mem_write_node := spad.spad_writer.node

}

class GemminiModule[T <: Data: Arithmetic, U <: Data, V <: Data]
@@ -227,7 +117,7 @@ class GemminiModule[T <: Data: Arithmetic, U <: Data, V <: Data]
// connecting to unified TL interface
val source_counters = Seq.fill(4)(Counter(outer.num_ids))

if (outer.create_tl_mem) {
if (outer.use_ext_tl_mem) {
def connect(ext_mem: ExtMemIO, bank_base: Int, req_size: Int, r_node: TLBundle, r_edge: TLEdgeOut, r_source: Counter,
w_node: TLBundle, w_edge: TLEdgeOut, w_source: Counter): Unit = {
r_node.a.valid := ext_mem.read_req.valid
@@ -260,83 +150,6 @@ class GemminiModule[T <: Data: Arithmetic, U <: Data, V <: Data]
r_node, r_edge, source_counters(0), w_node, w_edge, source_counters(1))
}
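The connect helper above (its body is partly elided by the diff view) bridges the raw ExtMemIO request/response ports onto TileLink A/D channels. A minimal sketch of the read half of such a bridge, written against the helper's parameters; the ExtMemIO field shapes assumed here are illustrative, and this is not the elided body itself.

// Illustrative read-path bridge: turn a read request into a TL Get, and the
// returning AccessAckData beat into a read response.
val (readLegal, getBits) = r_edge.Get(
  fromSource = r_source.value,                             // rotating source ID
  toAddress  = bank_base.U + (ext_mem.read_req.bits << log2Ceil(req_size)),
  lgSize     = log2Ceil(req_size).U)
r_node.a.valid         := ext_mem.read_req.valid
r_node.a.bits          := getBits
ext_mem.read_req.ready := r_node.a.ready                   // assumed ready field
when (r_node.a.fire) { r_source.inc() }                    // advance source ID

ext_mem.read_resp.valid := r_node.d.valid                  // data comes back on D
ext_mem.read_resp.bits  := r_node.d.bits.data
r_node.d.ready          := ext_mem.read_resp.ready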

outer.spad_tl_ram.foreach { case Seq(r, w) =>
val mem_depth = outer.config.sp_bank_entries * outer.spad_data_len / outer.max_data_len
val mem_width = outer.max_data_len

val mem = TwoPortSyncMem(
n = mem_depth,
t = UInt((mem_width * 8).W),
mask_len = mem_width // byte level mask
)

val (r_node, r_edge) = r.in.head
val (w_node, w_edge) = w.in.head

// READ
mem.io.ren := r_node.a.fire
mem.io.raddr := (r_node.a.bits.address ^ outer.spad_base.U) >> log2Ceil(mem_width).U

val data_pipe_in = Wire(DecoupledIO(mem.io.rdata.cloneType))
data_pipe_in.valid := RegNext(mem.io.ren)
data_pipe_in.bits := mem.io.rdata

val metadata_pipe_in = Wire(DecoupledIO(new Bundle {
val source = r_node.a.bits.source.cloneType
val size = r_node.a.bits.size.cloneType
}))
metadata_pipe_in.valid := mem.io.ren
metadata_pipe_in.bits.source := r_node.a.bits.source
metadata_pipe_in.bits.size := r_node.a.bits.size

val sram_read_backup_reg = RegInit(0.U.asTypeOf(Valid(mem.io.rdata.cloneType)))

val data_pipe_inst = Module(new Pipeline(data_pipe_in.bits.cloneType, 1)())
data_pipe_inst.io.in <> data_pipe_in
val data_pipe = data_pipe_inst.io.out
val metadata_pipe = Pipeline(metadata_pipe_in, 2)
assert((data_pipe.valid || sram_read_backup_reg.valid) === metadata_pipe.valid)

// data pipe is filled, but D is not ready and SRAM read came back
when (data_pipe.valid && !r_node.d.ready && data_pipe_in.valid) {
assert(!data_pipe_in.ready) // we should fill backup reg only if data pipe is not enqueueing
assert(!sram_read_backup_reg.valid) // backup reg should be empty
assert(!metadata_pipe_in.ready) // metadata should be filled previous cycle
sram_read_backup_reg.valid := true.B
sram_read_backup_reg.bits := mem.io.rdata
}.otherwise {
assert(data_pipe_in.ready || !data_pipe_in.valid) // do not skip any response
}

assert(metadata_pipe_in.fire || !mem.io.ren) // when requesting sram, metadata needs to be ready
assert(r_node.d.fire === metadata_pipe.fire) // metadata dequeues iff D fires

// when D becomes ready, and data pipe has emptied, time for backup to empty
when (r_node.d.ready && sram_read_backup_reg.valid && !data_pipe.valid) {
sram_read_backup_reg.valid := false.B
}
assert(!(sram_read_backup_reg.valid && data_pipe.valid && data_pipe_in.fire)) // must empty backup before filling data pipe
assert(data_pipe_in.valid === data_pipe_in.fire)

r_node.d.bits := r_edge.AccessAck(
metadata_pipe.bits.source,
metadata_pipe.bits.size,
Mux(!data_pipe.valid, sram_read_backup_reg.bits, data_pipe.bits))
r_node.d.valid := data_pipe.valid || sram_read_backup_reg.valid
// r node A is not ready only if D is not ready and both slots filled
r_node.a.ready := r_node.d.ready && !(data_pipe.valid && sram_read_backup_reg.valid)
data_pipe.ready := r_node.d.ready
metadata_pipe.ready := r_node.d.ready

// WRITE
mem.io.wen := w_node.a.fire
mem.io.waddr := (w_node.a.bits.address ^ outer.spad_base.U) >> log2Ceil(mem_width).U
mem.io.wdata := w_node.a.bits.data
mem.io.mask := w_node.a.bits.mask.asBools
w_node.a.ready := w_node.d.ready // && (mem.io.waddr =/= mem.io.raddr)
w_node.d.valid := w_node.a.valid
w_node.d.bits := w_edge.AccessAck(w_node.a.bits)
}
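The read-side logic deleted above implements a backup-register (skid buffer) pattern: the synchronous-read SRAM returns data one cycle after a request fires, but the TileLink D channel may stall in that cycle, so the in-flight word must be parked rather than dropped. A stripped-down, generic sketch of the same idea follows; it is simpler than the deleted code, which additionally pipelines source/size metadata to sustain throughput.

import chisel3._
import chisel3.util._

// One-entry skid register for a synchronous-read memory (illustrative sketch).
class SramReadSkid(w: Int, depth: Int) extends Module {
  val io = IO(new Bundle {
    val req  = Flipped(Decoupled(UInt(log2Ceil(depth).W))) // read address
    val resp = Decoupled(UInt(w.W))                        // read data
  })
  val mem    = SyncReadMem(depth, UInt(w.W))
  val fresh  = RegNext(io.req.fire, false.B)               // SRAM data arrives this cycle
  val rdata  = mem.read(io.req.bits, io.req.fire)
  val backup = RegInit(0.U.asTypeOf(Valid(UInt(w.W))))     // parked word

  // Accept a request only if its response will have somewhere to land:
  // the backup slot is free, and any in-flight word drains this cycle.
  io.req.ready  := !backup.valid && (io.resp.ready || !fresh)
  io.resp.valid := fresh || backup.valid
  io.resp.bits  := Mux(backup.valid, backup.bits, rdata)   // drain backup first

  when (fresh && !io.resp.ready) {
    backup.valid := true.B                                 // park the stalled word
    backup.bits  := rdata
  }.elsewhen (io.resp.fire && backup.valid) {
    backup.valid := false.B                                // parked word consumed
  }
}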

ext_mem_acc.foreach(_.foreach(x => {
x.read_resp.bits := 0.U(1.W)
@@ -350,84 +163,6 @@ class GemminiModule[T <: Data: Arithmetic, U <: Data, V <: Data]
// connect(ext_mem_acc(i)(0), log2Up(outer.acc_data_len),
// r_node, r_edge, source_counters(2), w_node, w_edge, source_counters(3))
// }

// hook up read/write for general unified mem nodes
{
val u_out = outer.unified_mem_node.out
val u_in = outer.unified_mem_node.in
assert(u_out.length == 2)
println(f"gemmini unified memory node has ${u_in.length} incoming client(s)")

val r_out = u_out.head
val w_out = u_out.last

val in_src = TLXbar.mapInputIds(u_in.map(_._2.client))
val in_src_size = in_src.map(_.end).max
assert(isPow2(in_src_size)) // should be checked already, but just to be sure

// arbitrate all reads into one read while assigning source prefix, same for write
val a_arbiter_in = (u_in zip in_src).map { case ((in_node, _), src_range) =>
val in_r: DecoupledIO[TLBundleA] =
WireDefault(0.U.asTypeOf(Decoupled(new TLBundleA(in_node.a.bits.params.copy(
sourceBits = log2Up(in_src_size) + 1
)))))
val in_w: DecoupledIO[TLBundleA] = WireDefault(0.U.asTypeOf(in_r.cloneType))

val req_is_read = in_node.a.bits.opcode === TLMessages.Get

(Seq(in_r.bits.user, in_r.bits.address, in_r.bits.opcode, in_r.bits.size,
in_r.bits.mask, in_r.bits.param, in_r.bits.data)
zip Seq(in_node.a.bits.user, in_node.a.bits.address, in_node.a.bits.opcode, in_node.a.bits.size,
in_node.a.bits.mask, in_node.a.bits.param, in_node.a.bits.data))
.foreach { case (x, y) => x := y }
in_r.bits.source := in_node.a.bits.source | src_range.start.U | Mux(req_is_read, 0.U, in_src_size.U)
in_w.bits := in_r.bits

in_r.valid := in_node.a.valid && req_is_read
in_w.valid := in_node.a.valid && !req_is_read
in_node.a.ready := Mux(req_is_read, in_r.ready, in_w.ready)

(in_r, in_w)
}
// we cannot use round robin because it might reorder requests, even from the same client
val (a_arbiter_in_r_nodes, a_arbiter_in_w_nodes) = a_arbiter_in.unzip
TLArbiter.lowest(r_out._2, r_out._1.a, a_arbiter_in_r_nodes:_*)
TLArbiter.lowest(w_out._2, w_out._1.a, a_arbiter_in_w_nodes:_*)

def trim(id: UInt, size: Int): UInt = if (size <= 1) 0.U else id(log2Ceil(size)-1, 0) // from Xbar
// for each unified mem node client, arbitrate read/write responses on d channel
(u_in zip in_src).zipWithIndex.foreach { case (((in_node, in_edge), src_range), i) =>
// assign d channel back based on source, invalid if source prefix mismatch
val resp = Seq(r_out._1.d, w_out._1.d)
val source_match = resp.zipWithIndex.map { case (r, i) =>
(r.bits.source(r.bits.source.getWidth - 1) === i.U(1.W)) && // MSB indicates read(0)/write(1)
src_range.contains(trim(r.bits.source, in_src_size))
}
val d_arbiter_in = resp.map(r => WireDefault(
0.U.asTypeOf(Decoupled(new TLBundleD(r.bits.params.copy(
sourceBits = in_node.d.bits.source.getWidth,
sizeBits = in_node.d.bits.size.getWidth
))))
))

(d_arbiter_in lazyZip resp lazyZip source_match).foreach { case (arb_in, r, sm) =>
(Seq(arb_in.bits.user, arb_in.bits.opcode, arb_in.bits.data, arb_in.bits.param,
arb_in.bits.sink, arb_in.bits.denied, arb_in.bits.corrupt)
zip Seq(r.bits.user, r.bits.opcode, r.bits.data, r.bits.param,
r.bits.sink, r.bits.denied, r.bits.corrupt))
.foreach { case (x, y) => x := y }
arb_in.bits.source := trim(r.bits.source, 1 << in_node.d.bits.source.getWidth) // we can trim b/c isPow2(prefix)
arb_in.bits.size := trim(r.bits.size, 1 << in_node.d.bits.size.getWidth) // FIXME: check truncation

arb_in.valid := r.valid && sm
r.ready := arb_in.ready
}

TLArbiter.robin(in_edge, in_node.d, d_arbiter_in:_*)
}

}

} else if (use_shared_ext_mem) {
ext_mem_io.foreach(_ <> outer.spad.module.io.ext_mem.get)
}
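The arbitration logic deleted above tells reads and writes apart by widening the source ID by one bit: TLXbar.mapInputIds first remaps each client's sources into a disjoint range, and writes are then tagged with an extra MSB so that D-channel responses can be steered back to the right client. A small runnable model of that encoding (simplified widths, illustrative only):

// Simplified model of the source-ID scheme. inSrcSize is the power-of-two
// count of remapped client source IDs (as produced by TLXbar.mapInputIds).
def encode(src: Int, rangeStart: Int, inSrcSize: Int, isWrite: Boolean): Int =
  src | rangeStart | (if (isWrite) inSrcSize else 0)       // extra MSB tags writes

def isWriteResp(src: Int, inSrcSize: Int): Boolean = (src & inSrcSize) != 0
def trim(src: Int, inSrcSize: Int): Int = src & (inSrcSize - 1) // drop the tag

// e.g. inSrcSize = 8, a client whose remapped range starts at 4, local source 2:
assert(encode(2, 4, 8, isWrite = true) == 14)              // 0b1110
assert(isWriteResp(14, 8))                                 // MSB says "write"
assert(trim(14, 8) == 6)                                   // 6 is in range [4, 8)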
