Skip to content

Commit d0ed63a

Browse files
Amanda ShiAmanda Shi
authored andcommitted
scale mem mmio debug
1 parent 11e9ae6 commit d0ed63a

File tree

8 files changed

+63
-49
lines changed

8 files changed

+63
-49
lines changed

src/main/scala/gemmini/AccumulatorMem.scala

Lines changed: 10 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -47,7 +47,7 @@ class AccumulatorWriteReq[T <: Data: Arithmetic](n: Int, t: Vec[Vec[T]]) extends
4747

4848

4949
class AccumulatorMemIO [T <: Data: Arithmetic, U <: Data](n: Int, t: Vec[Vec[T]], scale_t: U,
50-
acc_sub_banks: Int, use_shared_ext_mem: Boolean, use_mx_scaling: Boolean, meshRows: Int, tileRows: Int, sramLineSizeInBytes: Int
50+
acc_sub_banks: Int, use_shared_ext_mem: Boolean, use_mx_scaling: Boolean, meshRows: Int, tileRows: Int, bankWidthBits: Int
5151
) extends Bundle {
5252
val read = Flipped(new AccumulatorReadIO(n, t, scale_t))
5353
val write = Flipped(Decoupled(new AccumulatorWriteReq(n, t)))
@@ -66,10 +66,10 @@ class AccumulatorMemIO [T <: Data: Arithmetic, U <: Data](n: Int, t: Vec[Vec[T]]
6666

6767
val dataType = Input(UInt(2.W)) //this is the input mxformat datatype
6868
val scale_mem_write_act = if (use_mx_scaling) {
69-
Some(Flipped(Decoupled(new ScalingFactorWriteReq(9, 256))))
69+
Some(Flipped(Decoupled(new ScalingFactorWriteReq(9, bankWidthBits))))
7070
} else None
7171
val scale_mem_write_w = if (use_mx_scaling) {
72-
Some(Flipped(Decoupled(new ScalingFactorWriteReq(9, 256))))
72+
Some(Flipped(Decoupled(new ScalingFactorWriteReq(9, bankWidthBits))))
7373
} else None
7474
val scaleMemCntl = if (use_mx_scaling) {
7575
Some(Input(new ScalingFactorCntl(meshRows * tileRows)))
@@ -127,12 +127,17 @@ class AccumulatorMem[T <: Data, U <: Data](
127127
import ev._
128128

129129
// TODO unify this with TwoPortSyncMemIO
130-
val io = IO(new AccumulatorMemIO(n, t, scale_t, acc_sub_banks, use_shared_ext_mem, use_mx_scaling, meshRows, tileRows, scale_mem.get.sramLineSizeInBytes))
130+
val io = IO(new AccumulatorMemIO(n, t, scale_t, acc_sub_banks, use_shared_ext_mem, use_mx_scaling, meshRows, tileRows, scale_mem.get.bankWidthBits))
131131

132132
val scaleFactorMem = scale_mem.map { conf =>
133+
// println(s"[ScalingFactorMem Config]")
134+
// println(s" depth = ${conf.depth}")
135+
// println(s" subbankLineSizeInBytes = ${conf.subbankLineSizeInBytes}")
136+
// println(s" bankWidthBits = ${conf.bankWidthBits}")
137+
// println(s" numBanks = ${conf.numBanks}")
133138
Module(new ScalingFactorMem(
134139
depth = conf.depth,
135-
bankWidth = conf.bankWidthBits,
140+
sramWidth = conf.subbankLineSizeInBytes*8,
136141
actOutputScalingWidth = 8,
137142
numBanks = conf.numBanks,
138143
testConfig = testConfig,

src/main/scala/gemmini/ConfigsFP.scala

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -274,7 +274,7 @@ object GemminiMxFPConfigs {
274274

275275
// 16x16 mesh with varying precisions
276276
scaleMem_data_width = 128,
277-
scaleMem_bank_entries = 256,
277+
scaleMem_bank_entries = 8192,
278278
scaleSize = 32,
279279
enable_lut = true,
280280
mvin_scale_args = None,
@@ -305,6 +305,7 @@ object GemminiMxFPConfigs {
305305

306306
num_counter = 8,
307307
requantizer = Some(GemminiRequantizerConfig(
308+
baseAddr = 0x10000000L,
308309
numInputLanes = 64,
309310
numOutputLanes = 32,
310311
gpuMaxFactor = 2,
@@ -315,8 +316,10 @@ object GemminiMxFPConfigs {
315316
outputIdBits = 3
316317
)),
317318
scale_mem = Some(GemminiScalingFactorMemConfig(
319+
baseAddr = 0x10000000L + 0x8000,
318320
sizeInBytes = 16 << 10,
319-
sramLineSizeInBytes = 256 / 8,
321+
subbankLineSizeInBytes = 16,
322+
subbanksPerBank = 2,
320323
numBanks = 8,
321324
))
322325
)

src/main/scala/gemmini/Controller.scala

Lines changed: 16 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -217,32 +217,34 @@ class GemminiModule[T <: Data: Arithmetic, U <: Data, V <: Data]
217217
val lut0 = Flipped(Decoupled(new QuantLutWriteBundle(l)))
218218
val lut1 = Flipped(Decoupled(new QuantLutWriteBundle(l)))
219219
val lut2 = Flipped(Decoupled(new QuantLutWriteBundle(l)))
220+
//val scaleFactorOut = Decoupled(new ScalingFactorWriteReq(scaleMem_addr_width, scaleMem_data_width))
220221
})
221222

222-
if (!outer.config.testConfig) {
223-
mx_io.scale_mem_write_w <> spad.module.io.scale_mem_write_w.get
224-
mx_io.scale_mem_write_act <> spad.module.io.scale_mem_write_act.get
225-
}
226-
223+
224+
spad.module.io.scale_mem_write_w.get <> mx_io.scale_mem_write_w
225+
spad.module.io.scale_mem_write_act.get <> mx_io.scale_mem_write_act
226+
227+
228+
//mx_io.scaleFactorOut <> mx_requantizer.get.io.scaleMem_write
227229
mx_io.requant_out <> mx_requantizer.get.io.requant_data_out
228230
mx_requantizer.get.io.lut0_write <> mx_io.lut0
229231
mx_requantizer.get.io.lut1_write <> mx_io.lut1
230232
mx_requantizer.get.io.lut2_write <> mx_io.lut2
231233

232-
Seq(mx_io.requant_in_gpu, mx_io.requant_out, mx_io.lut0, mx_io.lut1, mx_io.lut2).foreach(dontTouch(_))
234+
Seq(mx_io.requant_in_gpu, mx_io.requant_out, mx_io.lut0, mx_io.lut1, mx_io.lut2, mx_io.scale_mem_write_w, mx_io.scale_mem_write_act).foreach(dontTouch(_))
233235
//Seq( mx_io.requant_out).foreach(dontTouch(_))
234236
mx_io
235237
}
236238

237239

238-
spad.module.io.scale_mem_write_act.foreach { ch =>
239-
ch.valid := false.B
240-
ch.bits := DontCare
241-
}
242-
spad.module.io.scale_mem_write_w.foreach { ch =>
243-
ch.valid := false.B
244-
ch.bits := DontCare
245-
}
240+
// spad.module.io.scale_mem_write_act.foreach { ch =>
241+
// ch.valid := false.B
242+
// ch.bits := DontCare
243+
// }
244+
// spad.module.io.scale_mem_write_w.foreach { ch =>
245+
// ch.valid := false.B
246+
// ch.bits := DontCare
247+
// }
246248

247249
val lut_deprojected_data = Wire(Vec(sp_banks, new ScratchpadReadIO(sp_bank_entries, sp_width)))
248250
lut_deprojected_data := 0.U.asTypeOf(lut_deprojected_data)

src/main/scala/gemmini/ExecuteController.scala

Lines changed: 6 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -127,14 +127,12 @@ class ExecuteController[T <: Data, U <: Data, V <: Data](xLen: Int, tagWidth: In
127127
val scale_mem_mvout_base_addr_act = RegInit(0.U(33.W))
128128

129129
when(functs(0) === CONFIG_SCALE_MEM) {
130-
val addr = rs1s(0).asTypeOf(new ConfigScaleMemRs1)
131-
val addr_direction = addr.mem_direction
132-
when(addr_direction === 0.U) { // mvin
133-
val act_scale_address_rs1 = addr.mem_address
134-
scale_mem_mvin_base_addr_act := act_scale_address_rs1
135-
scale_mem_mvin_base_addr_w := act_scale_address_rs1 + (config.scale_mem.get.sizeInBytes >> 1).U
136-
}.elsewhen(addr_direction === 1.U) { // mvout
137-
scale_mem_mvout_base_addr_act := addr.mem_address
130+
val direction = rs2s(0)(63)
131+
when(direction === 1.U) { // mvin
132+
scale_mem_mvin_base_addr_act := rs1s(0)
133+
scale_mem_mvin_base_addr_w := rs1s(0) + (config.scale_mem.get.sizeInBytes >> 1).U
134+
}.elsewhen(direction === 0.U) { // mvout
135+
scale_mem_mvout_base_addr_act := rs1s(0)
138136
}
139137
}
140138
io.scale_mem_mvout_base_addr_act := scale_mem_mvout_base_addr_act

src/main/scala/gemmini/GemminiConfigs.scala

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -57,7 +57,7 @@ case class GemminiArrayConfig[T <: Data : Arithmetic, U <: Data, V <: Data](
5757

5858

5959
scaleMem_data_width: Int = 128,
60-
scaleMem_bank_entries: Int = 256,
60+
scaleMem_bank_entries: Int = 8192,
6161
scaleSize: Int = 32,
6262

6363
dma_maxbytes: Int = 64, // TODO get this from cacheblockbytes

src/main/scala/gemmini/MxConfigFragments.scala

Lines changed: 8 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -6,13 +6,15 @@ import chisel3.util._
66
case class GemminiScalingFactorMemConfig(
77
baseAddr: BigInt = 0x10000000L,
88
sizeInBytes: BigInt = 16 << 10,
9-
sramLineSizeInBytes: Int = 32,
9+
subbankLineSizeInBytes: Int = 16,
10+
subbanksPerBank: Int = 2,
1011
numBanks: Int = 8,
1112
) {
12-
def depth: Int = (sizeInBytes / sramLineSizeInBytes / numBanks).toInt
13-
def bankWidthBits = sramLineSizeInBytes * 8
13+
def depth: Int = (sizeInBytes / (subbankLineSizeInBytes) / numBanks).toInt
14+
def bankWidthBytes = subbankLineSizeInBytes * subbanksPerBank
15+
def bankWidthBits = bankWidthBytes * 8
1416
def addrBits = log2Ceil(sizeInBytes)
15-
def lineOffsetBits = log2Ceil(sramLineSizeInBytes)
17+
def lineOffsetBits = log2Ceil(bankWidthBytes)
1618
}
1719

1820
case class GemminiRequantizerConfig(
@@ -47,12 +49,13 @@ object RequantizerDataType extends ChiselEnum {
4749
}
4850
}
4951

52+
5053
class ScalingFactorWriteReq(addrWidth: Int, dataWidth: Int) extends Bundle {
5154
val addr = UInt(addrWidth.W)
5255
val data = UInt(dataWidth.W)
5356
def this(config: GemminiScalingFactorMemConfig) = {
5457
// writes two interleaved banks at once
55-
this(config.addrBits, config.bankWidthBits * 2)
58+
this(config.addrBits, config.bankWidthBits)
5659
}
5760
}
5861

src/main/scala/gemmini/MxRequantizer.scala

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -289,7 +289,7 @@ class MxRequantizer[T <: Data: Arithmetic](
289289
// }.otherwise {
290290
// quantLut.io.lut_write.ready := false.B
291291
// }
292-
292+
293293
quantLut.io.lut_write_weight <> io.lut0_write
294294
quantLut.io.lut_write_act_in <> io.lut1_write
295295
quantLut.io.lut_write_act_out <> io.lut2_write
@@ -330,7 +330,6 @@ class MxRequantizer[T <: Data: Arithmetic](
330330

331331
val scale_write_counter = RegInit(0.U(log2Ceil(scaleSize).W))
332332
val scale_buffer_full = RegInit(false.B)
333-
334333
when(should_compute) {
335334
for (i <- 0 until scaleSize) {
336335
when(i.U === scale_write_counter) {
@@ -342,8 +341,12 @@ class MxRequantizer[T <: Data: Arithmetic](
342341
scale_buffer_full := true.B
343342
}.otherwise {
344343
scale_write_counter := scale_write_counter + 1.U
345-
scale_buffer_full := false.B
344+
when(io.scaleMem_write.ready){
345+
scale_buffer_full := false.B
346+
}
346347
}
348+
}.elsewhen(io.scaleMem_write.ready){
349+
scale_buffer_full := false.B
347350
}
348351

349352
when(scale_buffer_full) {

src/main/scala/gemmini/ScaleFactorMem.scala

Lines changed: 11 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,7 @@ class ScalingFactorMemIO(addrWidth: Int, dataWidth: Int, numRows: Int, numCols:
2828

2929
class ScalingFactorMem(
3030
depth: Int = 128,
31-
bankWidth: Int = 128,
31+
sramWidth: Int = 128,
3232
actOutputScalingWidth: Int = 8,
3333
numBanks: Int = 8,
3434
testConfig: Boolean = false ,
@@ -37,12 +37,12 @@ class ScalingFactorMem(
3737
) extends Module {
3838

3939
val rowAddrWidth = log2Ceil(depth)
40-
val bytesPerBank = bankWidth / 8
40+
val bytesPerBank = sramWidth / 8
4141
val AddrWidth = rowAddrWidth + log2Ceil(numBanks)
4242
val bankaddressWidth = log2Ceil(numBanks)
4343
val totalScales = 32
4444
val counterWidth = log2Ceil(totalScales)
45-
val writeDataWidth = bankWidth * 2
45+
val writeDataWidth = sramWidth * 2
4646
val io = IO(new ScalingFactorMemIO(
4747
AddrWidth,
4848
writeDataWidth,
@@ -77,7 +77,7 @@ class ScalingFactorMem(
7777
val weight_buffer_0_read_enable = RegInit(false.B)
7878
val weight_buffer_1_read_enable = RegInit(false.B)
7979
val weight_write_counter = RegInit(0.U(8.W))
80-
val write_row_addr_w = io.scale_mem_write_w.bits.addr + (write_baseAddr_w >> (log2Ceil(2*meshRows*tileRows)))
80+
val write_row_addr_w = io.scale_mem_write_w.bits.addr
8181

8282
val weight_buffer_write_full = RegInit(false.B)
8383
when(io.scale_mem_write_w.fire) {
@@ -110,7 +110,7 @@ class ScalingFactorMem(
110110
val act_buffer_0_read_enable = RegInit(false.B)
111111
val act_buffer_1_read_enable = RegInit(false.B)
112112
val act_write_counter = RegInit(0.U(8.W))
113-
val write_row_addr_act = io.scale_mem_write_act.bits.addr + (write_baseAddr_act >> (log2Ceil(2*meshRows*tileRows)))
113+
val write_row_addr_act = io.scale_mem_write_act.bits.addr
114114
when(io.scale_mem_write_act.fire) {
115115
val write_bytes_low = io.scale_mem_write_act.bits.data(bytesPerBank * 8 - 1, 0).asTypeOf(bankDataT)
116116
val write_bytes_high = io.scale_mem_write_act.bits.data(bytesPerBank * 2 * 8 - 1, bytesPerBank * 8).asTypeOf(bankDataT)
@@ -142,17 +142,17 @@ class ScalingFactorMem(
142142
when(io.read_req.fire && io.read_req.bits.scaling_enable){
143143
act_read_buffer_select := ~act_read_buffer_select
144144
weight_read_buffer_select := ~weight_read_buffer_select
145-
when(act_buffer_0_read_enable && ((act_write_counter === read_row_addr))){
145+
when(act_buffer_0_read_enable && ((write_row_addr_act === read_row_addr))){
146146
act_buffer_0_read_enable := false.B
147147
}
148-
when(act_buffer_1_read_enable && ((act_write_counter === read_row_addr))){
148+
when(act_buffer_1_read_enable && ((write_row_addr_act === read_row_addr))){
149149
act_buffer_1_read_enable := false.B
150150
}
151151

152-
when(weight_buffer_0_read_enable && ((weight_write_counter === read_row_addr))){
152+
when(weight_buffer_0_read_enable && ((write_row_addr_w === read_row_addr))){
153153
weight_buffer_0_read_enable := false.B
154154
}
155-
when(weight_buffer_1_read_enable && ((weight_write_counter === read_row_addr))){
155+
when(weight_buffer_1_read_enable && ((write_row_addr_w === read_row_addr))){
156156
weight_buffer_1_read_enable := false.B
157157
}
158158
}
@@ -188,8 +188,8 @@ class ScalingFactorMem(
188188
read_fire_real && act_buffer_1_read_enable && weight_buffer_1_read_enable && (act_bank_sel === 3.U), // bank 3
189189
read_fire_real && weight_buffer_0_read_enable && act_buffer_0_read_enable && (weight_bank_sel === 0.U), // bank 4
190190
read_fire_real && weight_buffer_0_read_enable && act_buffer_0_read_enable && (weight_bank_sel === 1.U), // bank 5
191-
read_fire_real && weight_buffer_1_read_enable && act_buffer_0_read_enable && (weight_bank_sel === 2.U), // bank 6
192-
read_fire_real && weight_buffer_1_read_enable && act_buffer_0_read_enable && (weight_bank_sel === 3.U) // bank 7
191+
read_fire_real && weight_buffer_1_read_enable && weight_buffer_1_read_enable && (weight_bank_sel === 2.U), // bank 6
192+
read_fire_real && weight_buffer_1_read_enable && weight_buffer_1_read_enable && (weight_bank_sel === 3.U) // bank 7
193193
))
194194

195195

0 commit comments

Comments
 (0)