Skip to content

Commit 7cfba71

Browse files
Amanda ShiAmanda Shi
authored andcommitted
change the QuantLut as double buffer
1 parent c4fc2c6 commit 7cfba71

File tree

8 files changed

+435
-355
lines changed

8 files changed

+435
-355
lines changed

chipyard/GemminiConfigs.scala

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -60,6 +60,6 @@ class TestMxGemminiRocketConfig extends Config(
6060

6161
class TestRequantizerLutMxGemminiRocketConfig extends Config(
6262
new gemmini.GemminiRequantizerLutMxFPTestConfig ++ // use FP32Gemmini systolic array GEMM accelerator
63-
new freechips.rocketchip.rocket.WithNHugeCores(1) ++
64-
new chipyard.config.WithSystemBusWidth(128) ++
63+
new freechips.rocketchip.rocket.WithNSmallCores(1) ++
64+
new chipyard.config.WithSystemBusWidth(256) ++
6565
new chipyard.config.AbstractConfig)

src/main/scala/gemmini/Controller.scala

Lines changed: 36 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -192,6 +192,7 @@ class GemminiModule[T <: Data: Arithmetic, U <: Data, V <: Data]
192192
sp_banks = outer.config.sp_banks,
193193
sp_width = outer.config.sp_width,
194194
sp_width_projected = outer.config.sp_width_projected,
195+
iterator_bitwidth = 16,
195196
config = q
196197
))
197198
}
@@ -201,6 +202,7 @@ class GemminiModule[T <: Data: Arithmetic, U <: Data, V <: Data]
201202
req.io.fp8_mode := false.B
202203
}
203204

205+
204206
val mx_io = Option.when(outer.config.use_mx_scaling && outer.config.requantizer.isDefined && outer.config.lut.isDefined) {
205207
val q = outer.config.requantizer.get
206208
val l = outer.config.lut.get
@@ -213,7 +215,7 @@ class GemminiModule[T <: Data: Arithmetic, U <: Data, V <: Data]
213215
val lut2 = Flipped(Decoupled(new QuantLutWriteBundle(l)))
214216
})
215217

216-
mx_io.scale_mem <> spad.module.io.scale_mem.get
218+
//mx_io.scale_mem <> spad.module.io.scale_mem.get
217219

218220
mx_io.requant_out <> mx_requantizer.get.io.requant_data_out
219221
mx_requantizer.get.io.lut0_write <> mx_io.lut0
@@ -245,22 +247,27 @@ class GemminiModule[T <: Data: Arithmetic, U <: Data, V <: Data]
245247
mx_sel(bank) := (ex_controller.io.output_MxFormat === 1.U)
246248
}
247249
}
248-
// if(outer.config.lut.isDefined){
249-
// mx_requantizer.get.io.lut_write := DontCare
250-
// }
250+
if(outer.config.lut.isDefined){
251+
mx_requantizer.get.io.lut_write_0 := DontCare
252+
mx_requantizer.get.io.lut_write_1 := DontCare
253+
mx_requantizer.get.io.lut_write_2 := DontCare
254+
}
251255
}
252256

253257
for (bank <- 0 until sp_banks) {
254258
val useMxB = mx_requantizer.isDefined.B && mx_sel(bank)
255259
// Requests
256260
// default
261+
262+
257263
read_projected(bank).req.valid := sram_read_buffer(bank).req.valid && !useMxB
258264
read_projected(bank).req.bits := sram_read_buffer(bank).req.bits
259265
sram_read_buffer(bank).req.ready := Mux(useMxB, mx_requantizer.get.io.spad_projected_data(bank).req.ready, read_projected(bank).req.ready)
260266
// mx
267+
261268
mx_requantizer.get.io.spad_deprojected_data(bank).req.valid := sram_read_buffer(bank).req.valid && useMxB
262269
mx_requantizer.get.io.spad_deprojected_data(bank).req.bits := sram_read_buffer(bank).req.bits
263-
270+
264271
// Responses
265272

266273
read_projected(bank).resp.valid := Mux(useMxB, mx_requantizer.get.io.spad_projected_data(bank).resp.valid, sram_read_buffer(bank).resp.valid)
@@ -293,7 +300,7 @@ class GemminiModule[T <: Data: Arithmetic, U <: Data, V <: Data]
293300
requantized_writes(i).mask := VecInit(Seq.fill(requantized_writes(i).mask.length)(true.B))
294301
}
295302

296-
mx_io.get.requant_in_gpu.ready := false.B
303+
//mx_io.get.requant_in_gpu.ready := false.B
297304
mx_requantizer.get.io.requant_data_in.valid := false.B
298305
mx_requantizer.get.io.requant_data_in.bits := DontCare
299306
mx_requantizer.get.io.scaleMem_write.ready := false.B
@@ -326,7 +333,7 @@ class GemminiModule[T <: Data: Arithmetic, U <: Data, V <: Data]
326333

327334

328335
}.elsewhen(any_valid) {
329-
mx_io.get.requant_in_gpu.ready := false.B
336+
//mx_io.get.requant_in_gpu.ready := false.B
330337

331338
val collected_data = Wire(Vec(outer.config.requantizer.get.numInputLanes, UInt(outer.config.weightType.getWidth.W)))
332339
collected_data := DontCare
@@ -363,33 +370,33 @@ class GemminiModule[T <: Data: Arithmetic, U <: Data, V <: Data]
363370
mx_requantizer.get.io.requant_data_in.bits.dataType := RequantizerDataType.FP8
364371
}
365372

366-
}.elsewhen(mx_io.get.requant_in_gpu.valid) {
367-
//.elsewhen(!any_valid) {
373+
}//.elsewhen(mx_io.get.requant_in_gpu.valid) {
374+
.elsewhen(!any_valid) {
368375

369-
val padded_data = VecInit(mx_io.get.requant_in_gpu.bits.data ++
370-
Seq.fill(64 - 16)(0.U(16.W)))
371-
mx_requantizer.get.io.requant_data_in.valid := true.B
372-
mx_requantizer.get.io.requant_data_in.bits.data := padded_data
373-
mx_requantizer.get.io.requant_data_in.bits.address := mx_io.get.requant_in_gpu.bits.address
374-
mx_requantizer.get.io.requant_data_in.bits.dataType := mx_io.get.requant_in_gpu.bits.dataType
375-
mx_io.get.requant_in_gpu.ready := true.B
376-
// val padded_data = VecInit(Seq.fill(64)(0.U(16.W)))
376+
// val padded_data = VecInit(mx_io.get.requant_in_gpu.bits.data ++
377+
// Seq.fill(64 - 16)(0.U(16.W)))
377378
// mx_requantizer.get.io.requant_data_in.valid := true.B
378379
// mx_requantizer.get.io.requant_data_in.bits.data := padded_data
379-
// mx_requantizer.get.io.requant_data_in.bits.address := 0.U
380-
// mx_requantizer.get.io.requant_data_in.bits.dataType := RequantizerDataType.FP8
380+
// mx_requantizer.get.io.requant_data_in.bits.address := mx_io.get.requant_in_gpu.bits.address
381+
// mx_requantizer.get.io.requant_data_in.bits.dataType := mx_io.get.requant_in_gpu.bits.dataType
382+
// mx_io.get.requant_in_gpu.ready := true.B
383+
val padded_data = VecInit(Seq.fill(64)(0.U(16.W)))
384+
mx_requantizer.get.io.requant_data_in.valid := true.B
385+
mx_requantizer.get.io.requant_data_in.bits.data := padded_data
386+
mx_requantizer.get.io.requant_data_in.bits.address := 0.U
387+
mx_requantizer.get.io.requant_data_in.bits.dataType := RequantizerDataType.FP8
381388

382389
}.otherwise {
383390
mx_requantizer.get.io.requant_data_in.valid := false.B
384391
mx_requantizer.get.io.requant_data_in.bits := DontCare
385-
mx_io.get.requant_in_gpu.ready := false.B
392+
//mx_io.get.requant_in_gpu.ready := false.B
386393
}
387394

388395
}.otherwise {
389396
// enable_MXQuant == 0
390397
mx_requantizer.get.io.requant_data_in.valid := false.B
391398
mx_requantizer.get.io.requant_data_in.bits := DontCare
392-
mx_io.get.requant_in_gpu.ready := false.B
399+
//mx_io.get.requant_in_gpu.ready := false.B
393400
}
394401

395402
requantized_writes
@@ -502,12 +509,19 @@ class GemminiModule[T <: Data: Arithmetic, U <: Data, V <: Data]
502509
has_training_convs, has_max_pool, has_first_layer_optimizations, has_dw_convs) }
503510
else (raw_cmd, false.B)
504511

505-
val (loop_cmd, loop_matmul_unroller_busy, loop_completed) = withClock (gated_clock) { LoopMatmul(if (has_loop_conv) conv_cmd else raw_cmd, reservation_station.io.matmul_ld_completed, reservation_station.io.matmul_st_completed, reservation_station.io.matmul_ex_completed,
512+
val (loop_cmd, loop_matmul_unroller_busy, loop_completed, loop_matmul) = withClock (gated_clock) { LoopMatmul(if (has_loop_conv) conv_cmd else raw_cmd, reservation_station.io.matmul_ld_completed, reservation_station.io.matmul_st_completed, reservation_station.io.matmul_ex_completed,
506513
meshRows*tileRows, coreMaxAddrBits, reservation_station_entries, max_lds, max_exs, max_sts, sp_banks * sp_bank_entries, acc_banks * acc_bank_entries,
507514
inputType.getWidth, accType.getWidth, dma_maxbytes, new MvinRs2(mvin_rows_bits, mvin_cols_bits, local_addr_t),
508515
new PreloadRs(mvin_rows_bits, mvin_cols_bits, local_addr_t), new PreloadRs(mvout_rows_bits, mvout_cols_bits, local_addr_t),
509516
new ComputeRs(mvin_rows_bits, mvin_cols_bits, local_addr_t), new ComputeRs(mvin_rows_bits, mvin_cols_bits, local_addr_t),
510517
new MvoutSpadRs1(32, local_addr_t), new MvoutRs2(mvout_rows_bits, mvout_cols_bits, local_addr_t)) }
518+
519+
mx_requantizer.get.io.counter_i := loop_matmul.io.counter_i
520+
mx_requantizer.get.io.counter_j := loop_matmul.io.counter_j
521+
mx_requantizer.get.io.counter_k := loop_matmul.io.counter_k
522+
mx_requantizer.get.io.a_fire := ex_controller.io.a_fire
523+
mx_requantizer.get.io.b_fire := ex_controller.io.b_fire
524+
511525

512526
val unrolled_cmd = Queue(loop_cmd)
513527
unrolled_cmd.ready := false.B

src/main/scala/gemmini/ExecuteController.scala

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -50,6 +50,8 @@ class ExecuteController[T <: Data, U <: Data, V <: Data](xLen: Int, tagWidth: In
5050
val enable_MXQuant = Output(Bool())
5151

5252
val counter = new CounterEventIO()
53+
val b_fire = Output(Bool())
54+
val a_fire = Output(Bool())
5355
})
5456

5557

@@ -196,7 +198,7 @@ def extractHalf(data: UInt, use_high_half: Bool): UInt = {
196198
// Dependency stuff
197199
io.completed.valid := false.B
198200
io.completed.bits := DontCare
199-
201+
200202
// val pending_completed_rob_id = Reg(UDValid(UInt(log2Up(rob_entries).W)))
201203
val pending_completed_rob_ids = Reg(Vec(2, UDValid(UInt(log2Up(reservation_station_entries).W))))
202204

@@ -499,6 +501,10 @@ def extractHalf(data: UInt, use_high_half: Bool): UInt = {
499501
io.srams.read(i).resp.ready := false.B
500502
}
501503

504+
505+
io.a_fire := a_fire
506+
io.b_fire := b_fire
507+
502508
// Accumulator read
503509
for (i <- 0 until acc_banks) {
504510
val read_a_from_acc = a_valid && a_read_from_acc && dataABankAcc === i.U && start_inputting_a && !multiply_garbage && a_row_is_not_all_zeros && !(im2col_wire&&im2col_en)

src/main/scala/gemmini/LoopMatmul.scala

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -906,6 +906,9 @@ class LoopMatmul(block_size: Int, coreMaxAddrBits: Int, reservation_station_size
906906
val ex_completed = Input(UInt(log2Up(reservation_station_size+1).W))
907907
val busy = Output(Bool())
908908
val completed = Output(Vec(2, Bool()))
909+
val counter_i = Output(UInt(16.W))
910+
val counter_j = Output(UInt(16.W))
911+
val counter_k = Output(UInt(16.W))
909912
})
910913

911914
// Create states
@@ -934,6 +937,9 @@ class LoopMatmul(block_size: Int, coreMaxAddrBits: Int, reservation_station_size
934937

935938
// Create command queue
936939
val cmd = Queue(io.in)
940+
io.counter_i := ex.io.i
941+
io.counter_j := ex.io.j
942+
io.counter_k := ex.io.k
937943

938944
io.busy := cmd.valid || loop_configured
939945

@@ -1306,15 +1312,15 @@ object LoopMatmul {
13061312
max_addr: Int, max_acc_addr: Int, input_w: Int, acc_w: Int, dma_max_bytes: Int,
13071313
mvin_rs2_t: MvinRs2, preload_rs1_t: PreloadRs, preload_rs2_t: PreloadRs,
13081314
compute_rs1_t: ComputeRs, compute_rs2_t: ComputeRs, mvout_spad_rs1_t: MvoutSpadRs1, mvout_rs2_t: MvoutRs2)
1309-
(implicit p: Parameters): (DecoupledIO[GemminiCmd], Bool, Vec[Bool]) = {
1315+
(implicit p: Parameters): (DecoupledIO[GemminiCmd], Bool, Vec[Bool], LoopMatmul) = {
13101316
val mod = Module(new LoopMatmul(block_size, coreMaxAddrBits, rob_size, max_lds, max_exs, max_sts,
13111317
max_addr, max_acc_addr, input_w, acc_w, dma_max_bytes,
13121318
mvin_rs2_t, preload_rs1_t, preload_rs2_t, compute_rs1_t, compute_rs2_t, mvout_spad_rs1_t, mvout_rs2_t))
13131319
mod.io.in <> in
13141320
mod.io.ld_completed := ld_completed
13151321
mod.io.st_completed := st_completed
13161322
mod.io.ex_completed := ex_completed
1317-
(mod.io.out, mod.io.busy, mod.io.completed)
1323+
(mod.io.out, mod.io.busy, mod.io.completed, mod)
13181324
}
13191325

13201326
def castDramOffset(dram_offset: UInt): UInt = {

src/main/scala/gemmini/MxConfigFragments.scala

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,9 @@ case class GemminiRequantizerConfig(
2626
minOutputBits: Int = 4,
2727
maxOutputBits: Int = 8,
2828
outputIdBits: Int = 3,
29+
lutUpdateRegularityW : Int = 128,
30+
lutUpdateRegularityActIn : Int = 128,
31+
lutUpdateRegularityActOut : Int = 128,
2932
)
3033

3134
case class GemminiLUTConfig(

src/main/scala/gemmini/MxRequantizer.scala

Lines changed: 42 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -49,6 +49,7 @@ class MxRequantizerIO(
4949
sp_banks: Int,
5050
sp_width: Int,
5151
sp_width_projected: Int,
52+
iterator_bitwidth: Int,
5253
config: GemminiRequantizerConfig
5354
) extends Bundle {
5455
val inputnumLanes = config.numInputLanes
@@ -63,6 +64,11 @@ class MxRequantizerIO(
6364
val spad_projected_data = Vec(sp_banks, new ScratchpadReadIO(sp_bank_entries, sp_width_projected))
6465
val spad_deprojected_data = Vec(sp_banks, Flipped(new ScratchpadReadIO(sp_bank_entries, sp_width)))
6566
val fp8_mode = Input(Bool()) // true for 64-lane mode, false for 16-lane mode
67+
val a_fire = Input(Bool()) // from execute controller
68+
val b_fire = Input(Bool()) // from execute controller
69+
val counter_i = Input(UInt(iterator_bitwidth.W)) // from controller
70+
val counter_j = Input(UInt(iterator_bitwidth.W)) // from controller
71+
val counter_k = Input(UInt(iterator_bitwidth.W)) // from controller
6672
}
6773

6874
class MxRequantizer[T <: Data: Arithmetic](
@@ -77,6 +83,7 @@ class MxRequantizer[T <: Data: Arithmetic](
7783
sp_banks: Int,
7884
sp_width: Int,
7985
sp_width_projected: Int,
86+
iterator_bitwidth: Int,
8087
config: GemminiRequantizerConfig
8188
)(implicit ev: Arithmetic[T]) extends Module {
8289

@@ -94,6 +101,7 @@ class MxRequantizer[T <: Data: Arithmetic](
94101
sp_banks,
95102
sp_width,
96103
sp_width_projected,
104+
iterator_bitwidth,
97105
config
98106
))
99107

@@ -252,12 +260,19 @@ class MxRequantizer[T <: Data: Arithmetic](
252260
sp_bank_entries = sp_bank_entries,
253261
sp_banks = sp_banks,
254262
sp_width = sp_width,
255-
sp_width_projected = sp_width_projected
263+
sp_width_projected = sp_width_projected,
264+
lut_update_regularity_w = config.lutUpdateRegularityW,
265+
lut_update_regularity_act_in = config.lutUpdateRegularityActIn,
266+
lut_update_regularity_act_out = config.lutUpdateRegularityActOut,
267+
iterator_bitwidth = iterator_bitwidth
256268
))
257269

258270
quantLut.io.spad_projected_data <> io.spad_projected_data
259271
quantLut.io.spad_deprojected_data <> io.spad_deprojected_data
260-
272+
quantLut.io.a_fire := io.a_fire
273+
quantLut.io.b_fire := io.b_fire
274+
quantLut.io.counter_i := io.counter_i
275+
quantLut.io.counter_j := io.counter_j
261276
// quantLut.io.lut_write.valid := false.B
262277
// quantLut.io.lut_write.bits := DontCare
263278
quantLut.io.quant_fp6.valid := false.B
@@ -273,15 +288,37 @@ class MxRequantizer[T <: Data: Arithmetic](
273288
// quantLut.io.lut_write.ready := false.B
274289
// }
275290

291+
<<<<<<< HEAD
276292
quantLut.io.lut_write <> io.lut0_write
277293
io.lut1_write.ready := false.B
278294
io.lut2_write.ready := false.B
279295

296+
=======
297+
quantLut.io.lut_write_weight <> io.lut_write_0
298+
quantLut.io.lut_write_act_in <> io.lut_write_1
299+
quantLut.io.lut_write_act_out <> io.lut_write_2
300+
val quant_fp6_buffer = RegInit(VecInit(Seq.fill(io.outputnumLanes)(0.U(6.W))))
301+
val quant_fp6_hang = RegInit(false.B)
302+
>>>>>>> e1e04af (change the QuantLut as double buffer)
280303
when(quantize_valid && (total_bits_per_element === 6.U)) {
281-
quantLut.io.quant_fp6.valid := true.B
282-
quantLut.io.quant_fp6.bits := quant_fp6
304+
when{quantLut.io.lut_write_act_out.ready}{ //hang here when write is finished
305+
for (i <- 0 until io.outputnumLanes) {
306+
quant_fp6_buffer(i) := quant_fp6(i)
307+
}
308+
quantLut.io.quant_fp6.valid := false.B
309+
quant_fp6_hang := true.B
310+
}
311+
when{!quantLut.io.lut_write_act_out.ready } {
312+
when (quant_fp6_hang) {
313+
quantLut.io.quant_fp6.valid := true.B
314+
quantLut.io.quant_fp6.bits := quant_fp6_buffer
315+
quant_fp6_hang := false.B
316+
}.otherwise {
317+
quantLut.io.quant_fp6.valid := true.B
318+
quantLut.io.quant_fp6.bits := quant_fp6
319+
}
320+
}
283321
}
284-
285322
when(quantLut.io.projected_data.valid && (total_bits_per_element === 6.U)) {
286323
io.requant_data_out.valid := true.B
287324
io.requant_data_out.bits.dataType := quant_dataType

0 commit comments

Comments
 (0)