Skip to content

Commit 350d8d6

Browse files
Amanda Shi
authored and committed
Unify requantizer interface across modes
1 parent d5c827a commit 350d8d6

File tree

5 files changed

+42
-95
lines changed

5 files changed

+42
-95
lines changed

software/gemmini-rocc-tests

src/main/scala/gemmini/Controller.scala

Lines changed: 9 additions & 44 deletions
Original file line number | Diff line number | Diff line change
@@ -224,7 +224,7 @@ class GemminiModule[T <: Data: Arithmetic, U <: Data, V <: Data]
224224
spad.module.io.scale_mem_write_w.get <> mx_io.scale_mem_write_w
225225
spad.module.io.scale_mem_write_act.get <> mx_io.scale_mem_write_act
226226

227-
227+
mx_requantizer.get.io.requant_data_in_gpu <> mx_io.requant_in_gpu
228228
mx_io.scale_factor_out <> mx_requantizer.get.io.scaleMem_write
229229
mx_io.requant_out <> mx_requantizer.get.io.requant_data_out
230230
mx_requantizer.get.io.lut0_write <> mx_io.lut0
@@ -234,15 +234,6 @@ class GemminiModule[T <: Data: Arithmetic, U <: Data, V <: Data]
234234
mx_io
235235
}
236236

237-
238-
// spad.module.io.scale_mem_write_act.foreach { ch =>
239-
// ch.valid := false.B
240-
// ch.bits := DontCare
241-
// }
242-
// spad.module.io.scale_mem_write_w.foreach { ch =>
243-
// ch.valid := false.B
244-
// ch.bits := DontCare
245-
// }
246237

247238
val lut_deprojected_data = Wire(Vec(sp_banks, new ScratchpadReadIO(sp_bank_entries, sp_width)))
248239
lut_deprojected_data := 0.U.asTypeOf(lut_deprojected_data)
@@ -257,11 +248,6 @@ class GemminiModule[T <: Data: Arithmetic, U <: Data, V <: Data]
257248
mx_sel(bank) := (ex_controller.io.output_MxFormat === 1.U)
258249
}
259250
}
260-
// if(outer.config.lut.isDefined){
261-
// mx_requantizer.get.io.lut0_write := DontCare
262-
// mx_requantizer.get.io.lut1_write := DontCare
263-
// mx_requantizer.get.io.lut2_write := DontCare
264-
// }
265251
}
266252

267253
/*
@@ -321,6 +307,7 @@ class GemminiModule[T <: Data: Arithmetic, U <: Data, V <: Data]
321307
(outer.config.sp_width_projected / (outer.config.aligned_to * 8)) max 1)))
322308

323309
val elements_per_bank = outer.config.sp_width_projected / outer.config.weightType.getWidth
310+
324311
when(ex_controller.io.enable_MXQuant =/= 0.U) {
325312
for (i <- 0 until outer.config.sp_banks) {
326313
requantized_writes(i).valid := false.B
@@ -329,13 +316,10 @@ class GemminiModule[T <: Data: Arithmetic, U <: Data, V <: Data]
329316
requantized_writes(i).mask := VecInit(Seq.fill(requantized_writes(i).mask.length)(true.B))
330317
}
331318

332-
if (!outer.config.testConfig) {
333-
mx_io.get.requant_in_gpu.ready := false.B
334-
}
319+
335320
mx_requantizer.get.io.requant_data_in.valid := false.B
336321
mx_requantizer.get.io.requant_data_in.bits := DontCare
337-
mx_requantizer.get.io.scaleMem_write.ready := false.B
338-
322+
339323
when(ex_controller.io.output_MxFormat === 2.U){
340324
mx_requantizer.get.io.fp8_mode := true.B
341325
}.otherwise{
@@ -364,11 +348,7 @@ class GemminiModule[T <: Data: Arithmetic, U <: Data, V <: Data]
364348

365349

366350
}.elsewhen(any_valid) {
367-
if (!outer.config.testConfig) {
368-
mx_io.get.requant_in_gpu.ready := false.B
369-
}
370-
371-
val collected_data = Wire(Vec(outer.config.requantizer.get.numInputLanes, UInt(outer.config.weightType.getWidth.W)))
351+
val collected_data = Wire(Vec(outer.config.requantizer.get.numOutputLanes, UInt(outer.config.weightType.getWidth.W)))
372352
collected_data := DontCare
373353

374354
var data_offset = 0
@@ -378,7 +358,7 @@ class GemminiModule[T <: Data: Arithmetic, U <: Data, V <: Data]
378358

379359
when(ex_controller.io.srams.write(bank).valid) {
380360
for (elem_idx <- 0 until elements_per_bank) {
381-
if (data_offset + elem_idx < outer.config.requantizer.get.numInputLanes) {
361+
if (data_offset + elem_idx < outer.config.requantizer.get.numOutputLanes) {
382362
collected_data(data_offset + elem_idx) := bank_data(elem_idx)
383363
}
384364
}
@@ -403,19 +383,9 @@ class GemminiModule[T <: Data: Arithmetic, U <: Data, V <: Data]
403383
mx_requantizer.get.io.requant_data_in.bits.dataType := RequantizerDataType.FP8
404384
}
405385

406-
}//.elsewhen(mx_io.get.requant_in_gpu.valid) {
386+
}
407387
.elsewhen(!any_valid) {
408-
409-
// val padded_data = VecInit(mx_io.get.requant_in_gpu.bits.data ++
410-
// Seq.fill(64 - 16)(0.U(16.W)))
411-
// mx_requantizer.get.io.requant_data_in.valid := true.B
412-
// mx_requantizer.get.io.requant_data_in.bits.data := padded_data
413-
// mx_requantizer.get.io.requant_data_in.bits.address := mx_io.get.requant_in_gpu.bits.address
414-
// mx_requantizer.get.io.requant_data_in.bits.dataType := mx_io.get.requant_in_gpu.bits.dataType
415-
if (!outer.config.testConfig) {
416-
mx_io.get.requant_in_gpu.ready := true.B
417-
}
418-
val padded_data = VecInit(Seq.fill(64)(0.U(16.W)))
388+
val padded_data = VecInit(Seq.fill(32)(0.U(16.W)))
419389
mx_requantizer.get.io.requant_data_in.valid := true.B
420390
mx_requantizer.get.io.requant_data_in.bits.data := padded_data
421391
mx_requantizer.get.io.requant_data_in.bits.address := 0.U
@@ -424,18 +394,13 @@ class GemminiModule[T <: Data: Arithmetic, U <: Data, V <: Data]
424394
}.otherwise {
425395
mx_requantizer.get.io.requant_data_in.valid := false.B
426396
mx_requantizer.get.io.requant_data_in.bits := DontCare
427-
if (!outer.config.testConfig) {
428-
mx_io.get.requant_in_gpu.ready := false.B
429-
}
397+
430398
}
431399

432400
}.otherwise {
433401
// enable_MXQuant == 0
434402
mx_requantizer.get.io.requant_data_in.valid := false.B
435403
mx_requantizer.get.io.requant_data_in.bits := DontCare
436-
if (!outer.config.testConfig) {
437-
mx_io.get.requant_in_gpu.ready := false.B
438-
}
439404
}
440405

441406
requantized_writes

src/main/scala/gemmini/ExecuteController.scala

Lines changed: 3 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -122,9 +122,9 @@ class ExecuteController[T <: Data, U <: Data, V <: Data](xLen: Int, tagWidth: In
122122
val preload_cmd_place = Mux(DoPreloads(0), 0.U, 1.U)
123123
// val a_address_place = Mux(current_dataflow === Dataflow.WS.id.U, 0.U, Mux(preload_cmd_place === 0.U, 1.U, 2.U))
124124

125-
val scale_mem_mvin_base_addr_act = RegInit(0.U(33.W))
126-
val scale_mem_mvin_base_addr_w = RegInit(0.U(33.W))
127-
val scale_mem_mvout_base_addr_act = RegInit(0.U(33.W))
125+
val scale_mem_mvin_base_addr_act = RegInit(0.U(32.W))
126+
val scale_mem_mvin_base_addr_w = RegInit(0.U(32.W))
127+
val scale_mem_mvout_base_addr_act = RegInit(0.U(32.W))
128128

129129
when(functs(0) === CONFIG_SCALE_MEM) {
130130
val direction = rs2s(0)(63)

src/main/scala/gemmini/GemminiISA.scala

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -66,7 +66,7 @@ object GemminiISA {
6666
//==========================================================================
6767
val GARBAGE_ADDR = "hffffffff".U(32.W)
6868

69-
val CONFIG_SCALE_MEM_RS1_ADDR_WIDTH = 33
69+
val CONFIG_SCALE_MEM_RS1_ADDR_WIDTH = 32
7070
val CONFIG_SCALE_MEM_SPACER_WIDTH = 64 - 1 - CONFIG_SCALE_MEM_RS1_ADDR_WIDTH
7171

7272
class ConfigScaleMemRs1 extends Bundle {

src/main/scala/gemmini/MxRequantizer.scala

Lines changed: 28 additions & 46 deletions
Original file line number | Diff line number | Diff line change
@@ -55,7 +55,8 @@ class MxRequantizerIO(
5555
val inputnumLanes = config.numInputLanes
5656
val outputnumLanes = config.numOutputLanes
5757
val inputdataWidth = config.inputBits
58-
val requant_data_in = Flipped(Decoupled(new RequantizerInBundle(inputnumLanes, inputdataWidth)))
58+
val requant_data_in = Flipped(Decoupled(new RequantizerInBundle(outputnumLanes, inputdataWidth)))
59+
val requant_data_in_gpu = Flipped(Decoupled(new RequantizerInBundle(config.numGPUInputLanes, inputdataWidth)))
5960
val scaleMem_write = Decoupled(new ScalingFactorWriteReq(scaleMem_addr_width, scaleMem_data_width))
6061
val requant_data_out = Decoupled(new RequantizerOutBundle(outputnumLanes))
6162
val lut0_write = Flipped(Decoupled(new QuantLutWriteBundle(lutConfig)))
@@ -66,7 +67,7 @@ class MxRequantizerIO(
6667
val fp8_mode = Input(Bool()) // true for 64-lane mode, false for 16-lane mode
6768
val a_fire = Input(Bool()) // from execute controller
6869
val b_fire = Input(Bool()) // from execute controller
69-
val scale_mem_mvout_base_addr_act = Input(UInt(33.W)) // from execute controller
70+
val scale_mem_mvout_base_addr_act = Input(UInt(32.W)) // from execute controller
7071
val counter_i = Input(UInt(iterator_bitwidth.W)) // from controller
7172
val counter_j = Input(UInt(iterator_bitwidth.W)) // from controller
7273
val counter_k = Input(UInt(iterator_bitwidth.W)) // from controller
@@ -105,6 +106,7 @@ class MxRequantizer[T <: Data: Arithmetic](
105106
iterator_bitwidth,
106107
config
107108
))
109+
dontTouch(io)
108110
val scale_mem_mvout_base_addr_act = io.scale_mem_mvout_base_addr_act
109111

110112
val scales_per_write = scaleMem_data_width / 8
@@ -143,61 +145,41 @@ class MxRequantizer[T <: Data: Arithmetic](
143145
}
144146

145147
val (exp_bits, mant_bits, pmax, log2_pmax_floor) = MxFloatFormat(format_reg)
146-
147-
val data_buffer = WireInit(VecInit(Seq.fill(io.outputnumLanes)(0.U(io.inputdataWidth.W))))
148148
val data_buffer_counter = RegInit(0.U(1.W))
149-
150149
//buffer twice for 16-lane mode
151150
val half_lanes = 16
152151
val input_32_buffer = RegInit(VecInit(Seq.fill(io.outputnumLanes)(0.U(io.inputdataWidth.W))))
153-
154-
val input_64_buffer = RegInit(VecInit(Seq.fill(io.inputnumLanes)(0.U(io.inputdataWidth.W))))
155-
val batch_counter = RegInit(0.U(1.W))
156-
val processing_64lane = RegInit(false.B)
157-
158-
val requant_data_in_valid_d = RegNext(io.requant_data_in.valid, false.B)
152+
val requant_data_in_valid_d = RegNext(io.requant_data_in.fire)
153+
val requant_data_in_gpu_valid_d = RegNext(io.requant_data_in_gpu.fire)
159154
val should_compute = Wire(Bool())
160-
val quantize_valid = RegNext(should_compute, false.B)
155+
val quantize_valid = RegNext(should_compute)
161156

157+
io.requant_data_in.ready := true.B
162158
should_compute := false.B
163-
io.requant_data_in.ready := !processing_64lane
164-
165-
when(io.requant_data_in.fire) {
166-
when(io.fp8_mode) { //16 lanes at a time
167-
for (i <- 0 until half_lanes) {
168-
val idx = Mux(data_buffer_counter === 0.U, i.U, (half_lanes + i).U)
169-
input_32_buffer(idx) := io.requant_data_in.bits.asUInt.asTypeOf(Vec(half_lanes, UInt(io.inputdataWidth.W)))(i)
170-
}
171-
data_buffer_counter := data_buffer_counter ^ 1.U
172-
}.otherwise {
173-
for (i <- 0 until 64) {
174-
input_64_buffer(i) := io.requant_data_in.bits.data(i)
159+
io.requant_data_in_gpu.ready := true.B
160+
161+
162+
when(io.requant_data_in.fire) {{
163+
io.requant_data_in_gpu.ready := false.B
164+
for (i <- 0 until 32) {
165+
input_32_buffer(i) := io.requant_data_in.bits.data(i)
175166
}
176167
data_buffer_counter := 1.U
177168
}
178-
}
179-
180-
when(io.fp8_mode) { //16 lanes at a time
181-
should_compute := data_buffer_counter === 0.U && requant_data_in_valid_d
182-
for (i <- 0 until io.outputnumLanes) {
183-
data_buffer(i) := input_32_buffer(i)
184-
}
185-
}.otherwise {
186-
processing_64lane := true.B
187-
should_compute := processing_64lane
188-
for (i <- 0 until io.outputnumLanes) {
189-
val idx = Mux(batch_counter === 0.U, i.U, (io.outputnumLanes + i).U)
190-
data_buffer(i) := input_64_buffer(i)
191-
}
192-
when(quantize_valid){
193-
batch_counter := 1.U
194-
}.otherwise {
195-
processing_64lane := false.B
196-
data_buffer_counter := 0.U
197-
batch_counter := 0.U
169+
}.elsewhen(io.requant_data_in_gpu.fire) {
170+
for (i <- 0 until half_lanes) {
171+
val idx = Mux(data_buffer_counter === 0.U, i.U, (half_lanes + i).U)
172+
input_32_buffer(idx) := io.requant_data_in_gpu.bits.data(i)
198173
}
174+
data_buffer_counter := ~data_buffer_counter
199175
}
200176

177+
val data_buffer = WireInit(VecInit(Seq.fill(io.outputnumLanes)(0.U(io.inputdataWidth.W))))
178+
data_buffer := input_32_buffer
179+
180+
// when(requant_data_in_valid_d || (requant_data_in_gpu_valid_d && (data_buffer_counter === 0.U))) {
181+
// should_compute := true.B
182+
// }
201183

202184
val block_max = Wire(UInt(io.inputdataWidth.W))
203185
block_max := 0.U
@@ -356,14 +338,14 @@ class MxRequantizer[T <: Data: Arithmetic](
356338

357339
when(scale_buffer_full) {
358340
io.scaleMem_write.valid := true.B
359-
io.scaleMem_write.bits.addr := scale_mem_mvout_base_addr_act +& (scale_write_addr_counter << 5.U) //byte address, scale 32B per write
341+
io.scaleMem_write.bits.addr := scale_mem_mvout_base_addr_act + (scale_write_addr_counter << 5) //byte address, scale 32B per write
360342
io.scaleMem_write.bits.data := Cat(scale_buffer.reverse)
361343

362344
when(io.scaleMem_write.fire) {
363345
val scale_buffer_packed = Cat(scale_buffer.reverse)
364346
printf(p"[MxScaleGen]: addr=${scale_write_addr_counter}, data=0x${Hexadecimal(scale_buffer_packed)}\n")
365347

366-
when(scale_write_addr_counter === ((1 << scaleMem_addr_width) - 1).U) {
348+
when(scale_write_addr_counter === ((1 << 10) - 1).U) {
367349
scale_write_addr_counter := 0.U
368350
}.otherwise {
369351
scale_write_addr_counter := scale_write_addr_counter + 1.U

0 commit comments

Comments
 (0)