Skip to content

Commit f678ab7

Browse files
committed
Merge branch 'gemmini-mx' of https://github.com/ucb-bar/gemmini into gemmini-mx
2 parents e04188a + b620c1e commit f678ab7

File tree

9 files changed

+155
-89
lines changed

9 files changed

+155
-89
lines changed

software/gemmini-rocc-tests

src/main/scala/gemmini/AccumulatorMem.scala

Lines changed: 11 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -47,7 +47,7 @@ class AccumulatorWriteReq[T <: Data: Arithmetic](n: Int, t: Vec[Vec[T]]) extends
4747

4848

4949
class AccumulatorMemIO [T <: Data: Arithmetic, U <: Data](n: Int, t: Vec[Vec[T]], scale_t: U,
50-
acc_sub_banks: Int, use_shared_ext_mem: Boolean, use_mx_scaling: Boolean, meshRows: Int, tileRows: Int, sramLineSizeInBytes: Int
50+
acc_sub_banks: Int, use_shared_ext_mem: Boolean, use_mx_scaling: Boolean, meshRows: Int, tileRows: Int, bankWidthBits: Int
5151
) extends Bundle {
5252
val read = Flipped(new AccumulatorReadIO(n, t, scale_t))
5353
val write = Flipped(Decoupled(new AccumulatorWriteReq(n, t)))
@@ -66,10 +66,10 @@ class AccumulatorMemIO [T <: Data: Arithmetic, U <: Data](n: Int, t: Vec[Vec[T]]
6666

6767
val dataType = Input(UInt(2.W)) //this is the input mxformat datatype
6868
val scale_mem_write_act = if (use_mx_scaling) {
69-
Some(Flipped(Decoupled(new ScalingFactorWriteReq(9, 256))))
69+
Some(Flipped(Decoupled(new ScalingFactorWriteReq(13, 64))))
7070
} else None
7171
val scale_mem_write_w = if (use_mx_scaling) {
72-
Some(Flipped(Decoupled(new ScalingFactorWriteReq(9, 256))))
72+
Some(Flipped(Decoupled(new ScalingFactorWriteReq(13, 64))))
7373
} else None
7474
val scaleMemCntl = if (use_mx_scaling) {
7575
Some(Input(new ScalingFactorCntl(meshRows * tileRows)))
@@ -127,17 +127,22 @@ class AccumulatorMem[T <: Data, U <: Data](
127127
import ev._
128128

129129
// TODO unify this with TwoPortSyncMemIO
130-
val io = IO(new AccumulatorMemIO(n, t, scale_t, acc_sub_banks, use_shared_ext_mem, use_mx_scaling, meshRows, tileRows, scale_mem.get.sramLineSizeInBytes))
130+
val io = IO(new AccumulatorMemIO(n, t, scale_t, acc_sub_banks, use_shared_ext_mem, use_mx_scaling, meshRows, tileRows, scale_mem.get.bankWidthBits))
131131

132132
val scaleFactorMem = scale_mem.map { conf =>
133+
// println(s"[ScalingFactorMem Config]")
134+
// println(s" depth = ${conf.depth}")
135+
// println(s" subbankLineSizeInBytes = ${conf.subbankLineSizeInBytes}")
136+
// println(s" bankWidthBits = ${conf.bankWidthBits}")
137+
// println(s" numBanks = ${conf.numBanks}")
133138
Module(new ScalingFactorMem(
134139
depth = conf.depth,
135-
bankWidth = conf.bankWidthBits,
140+
sramWidth = conf.subbankLineSizeInBytes*8,
136141
actOutputScalingWidth = 8,
137142
numBanks = conf.numBanks,
138143
testConfig = testConfig,
139144
meshRows = meshRows,
140-
tileRows = tileRows,
145+
tileRows = tileRows
141146
))
142147
}
143148

src/main/scala/gemmini/ConfigsFP.scala

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -274,7 +274,7 @@ object GemminiMxFPConfigs {
274274

275275
// 16x16 mesh with varying precisions
276276
scaleMem_data_width = 128,
277-
scaleMem_bank_entries = 256,
277+
scaleMem_bank_entries = 8192,
278278
scaleSize = 32,
279279
enable_lut = true,
280280
mvin_scale_args = None,
@@ -305,6 +305,7 @@ object GemminiMxFPConfigs {
305305

306306
num_counter = 8,
307307
requantizer = Some(GemminiRequantizerConfig(
308+
baseAddr = 0x10000000L,
308309
numInputLanes = 64,
309310
numOutputLanes = 32,
310311
gpuMaxFactor = 2,
@@ -315,8 +316,10 @@ object GemminiMxFPConfigs {
315316
outputIdBits = 3
316317
)),
317318
scale_mem = Some(GemminiScalingFactorMemConfig(
319+
baseAddr = 0x10000000L + 0x8000,
318320
sizeInBytes = 16 << 10,
319-
sramLineSizeInBytes = 256 / 8,
321+
subbankLineSizeInBytes = 16,
322+
subbanksPerBank = 2,
320323
numBanks = 8,
321324
))
322325
)

src/main/scala/gemmini/Controller.scala

Lines changed: 15 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -217,32 +217,32 @@ class GemminiModule[T <: Data: Arithmetic, U <: Data, V <: Data]
217217
val lut0 = Flipped(Decoupled(new QuantLutWriteBundle(l)))
218218
val lut1 = Flipped(Decoupled(new QuantLutWriteBundle(l)))
219219
val lut2 = Flipped(Decoupled(new QuantLutWriteBundle(l)))
220+
//val scaleFactorOut = Decoupled(new ScalingFactorWriteReq(scaleMem_addr_width, scaleMem_data_width))
220221
})
221222

222-
if (!outer.config.testConfig) {
223-
mx_io.scale_mem_write_w <> spad.module.io.scale_mem_write_w.get
224-
mx_io.scale_mem_write_act <> spad.module.io.scale_mem_write_act.get
225-
}
226-
223+
224+
spad.module.io.scale_mem_write_w.get <> mx_io.scale_mem_write_w
225+
spad.module.io.scale_mem_write_act.get <> mx_io.scale_mem_write_act
226+
227+
228+
//mx_io.scaleFactorOut <> mx_requantizer.get.io.scaleMem_write
227229
mx_io.requant_out <> mx_requantizer.get.io.requant_data_out
228230
mx_requantizer.get.io.lut0_write <> mx_io.lut0
229231
mx_requantizer.get.io.lut1_write <> mx_io.lut1
230232
mx_requantizer.get.io.lut2_write <> mx_io.lut2
231233

232-
// Seq(mx_io.requant_in_gpu, mx_io.requant_out, mx_io.lut0, mx_io.lut1, mx_io.lut2).foreach(dontTouch(_))
233-
//Seq( mx_io.requant_out).foreach(dontTouch(_))
234234
mx_io
235235
}
236236

237237

238-
spad.module.io.scale_mem_write_act.foreach { ch =>
239-
ch.valid := false.B
240-
ch.bits := DontCare
241-
}
242-
spad.module.io.scale_mem_write_w.foreach { ch =>
243-
ch.valid := false.B
244-
ch.bits := DontCare
245-
}
238+
// spad.module.io.scale_mem_write_act.foreach { ch =>
239+
// ch.valid := false.B
240+
// ch.bits := DontCare
241+
// }
242+
// spad.module.io.scale_mem_write_w.foreach { ch =>
243+
// ch.valid := false.B
244+
// ch.bits := DontCare
245+
// }
246246

247247
val lut_deprojected_data = Wire(Vec(sp_banks, new ScratchpadReadIO(sp_bank_entries, sp_width)))
248248
lut_deprojected_data := 0.U.asTypeOf(lut_deprojected_data)

src/main/scala/gemmini/ExecuteController.scala

Lines changed: 6 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -127,14 +127,12 @@ class ExecuteController[T <: Data, U <: Data, V <: Data](xLen: Int, tagWidth: In
127127
val scale_mem_mvout_base_addr_act = RegInit(0.U(33.W))
128128

129129
when(functs(0) === CONFIG_SCALE_MEM) {
130-
val addr = rs1s(0).asTypeOf(new ConfigScaleMemRs1)
131-
val addr_direction = addr.mem_direction
132-
when(addr_direction === 0.U) { // mvin
133-
val act_scale_address_rs1 = addr.mem_address
134-
scale_mem_mvin_base_addr_act := act_scale_address_rs1
135-
scale_mem_mvin_base_addr_w := act_scale_address_rs1 + (config.scale_mem.get.sizeInBytes >> 1).U
136-
}.elsewhen(addr_direction === 1.U) { // mvout
137-
scale_mem_mvout_base_addr_act := addr.mem_address
130+
val direction = rs2s(0)(63)
131+
when(direction === 1.U) { // mvin
132+
scale_mem_mvin_base_addr_act := rs1s(0)
133+
scale_mem_mvin_base_addr_w := rs1s(0) + (config.scale_mem.get.sizeInBytes >> 1).U
134+
}.elsewhen(direction === 0.U) { // mvout
135+
scale_mem_mvout_base_addr_act := rs1s(0)
138136
}
139137
}
140138
io.scale_mem_mvout_base_addr_act := scale_mem_mvout_base_addr_act

src/main/scala/gemmini/GemminiConfigs.scala

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -57,7 +57,7 @@ case class GemminiArrayConfig[T <: Data : Arithmetic, U <: Data, V <: Data](
5757

5858

5959
scaleMem_data_width: Int = 128,
60-
scaleMem_bank_entries: Int = 256,
60+
scaleMem_bank_entries: Int = 8192,
6161
scaleSize: Int = 32,
6262

6363
dma_maxbytes: Int = 64, // TODO get this from cacheblockbytes

src/main/scala/gemmini/MxConfigFragments.scala

Lines changed: 10 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -6,13 +6,16 @@ import chisel3.util._
66
case class GemminiScalingFactorMemConfig(
77
baseAddr: BigInt = 0x10000000L,
88
sizeInBytes: BigInt = 16 << 10,
9-
sramLineSizeInBytes: Int = 32,
9+
subbankLineSizeInBytes: Int = 16,
10+
subbanksPerBank: Int = 2,
11+
gpuInputWidthBytes: Int = 8,
1012
numBanks: Int = 8,
1113
) {
12-
def depth: Int = (sizeInBytes / sramLineSizeInBytes / numBanks).toInt
13-
def bankWidthBits = sramLineSizeInBytes * 8
14+
def depth: Int = (sizeInBytes / (subbankLineSizeInBytes) / numBanks).toInt
15+
def bankWidthBytes = subbankLineSizeInBytes * subbanksPerBank
16+
def bankWidthBits = bankWidthBytes * 8
1417
def addrBits = log2Ceil(sizeInBytes)
15-
def lineOffsetBits = log2Ceil(sramLineSizeInBytes)
18+
def lineOffsetBits = log2Ceil(bankWidthBytes)
1619
}
1720

1821
case class GemminiRequantizerConfig(
@@ -47,15 +50,17 @@ object RequantizerDataType extends ChiselEnum {
4750
}
4851
}
4952

53+
5054
class ScalingFactorWriteReq(addrWidth: Int, dataWidth: Int) extends Bundle {
5155
val addr = UInt(addrWidth.W)
5256
val data = UInt(dataWidth.W)
5357
def this(config: GemminiScalingFactorMemConfig) = {
5458
// writes two interleaved banks at once
55-
this(config.addrBits, config.bankWidthBits * 2)
59+
this(config.addrBits, 8*config.gpuInputWidthBytes*8)
5660
}
5761
}
5862

63+
5964
class ScalingFactorCntl(max_block: Int) extends Bundle {
6065
val counter_a = UInt(log2Up(max_block).W)
6166
val counter_b = UInt(log2Up(max_block).W)

src/main/scala/gemmini/MxRequantizer.scala

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -294,7 +294,7 @@ class MxRequantizer[T <: Data: Arithmetic](
294294
// }.otherwise {
295295
// quantLut.io.lut_write.ready := false.B
296296
// }
297-
297+
298298
quantLut.io.lut_write_weight <> io.lut0_write
299299
quantLut.io.lut_write_act_in <> io.lut1_write
300300
quantLut.io.lut_write_act_out <> io.lut2_write
@@ -335,7 +335,6 @@ class MxRequantizer[T <: Data: Arithmetic](
335335

336336
val scale_write_counter = RegInit(0.U(log2Ceil(scaleSize).W))
337337
val scale_buffer_full = RegInit(false.B)
338-
339338
when(should_compute) {
340339
for (i <- 0 until scaleSize) {
341340
when(i.U === scale_write_counter) {
@@ -347,8 +346,12 @@ class MxRequantizer[T <: Data: Arithmetic](
347346
scale_buffer_full := true.B
348347
}.otherwise {
349348
scale_write_counter := scale_write_counter + 1.U
350-
scale_buffer_full := false.B
349+
when(io.scaleMem_write.ready){
350+
scale_buffer_full := false.B
351+
}
351352
}
353+
}.elsewhen(io.scaleMem_write.ready){
354+
scale_buffer_full := false.B
352355
}
353356

354357
when(scale_buffer_full) {

0 commit comments

Comments
 (0)