@@ -192,6 +192,7 @@ class GemminiModule[T <: Data: Arithmetic, U <: Data, V <: Data]
192192 sp_banks = outer.config.sp_banks,
193193 sp_width = outer.config.sp_width,
194194 sp_width_projected = outer.config.sp_width_projected,
195+ iterator_bitwidth = 16 ,
195196 config = q
196197 ))
197198 }
@@ -201,6 +202,7 @@ class GemminiModule[T <: Data: Arithmetic, U <: Data, V <: Data]
201202 req.io.fp8_mode := false .B
202203}
203204
205+
204206 val mx_io = Option .when(outer.config.use_mx_scaling && outer.config.requantizer.isDefined && outer.config.lut.isDefined) {
205207 val q = outer.config.requantizer.get
206208 val l = outer.config.lut.get
@@ -213,7 +215,7 @@ class GemminiModule[T <: Data: Arithmetic, U <: Data, V <: Data]
213215 val lut2 = Flipped (Decoupled (new QuantLutWriteBundle (l)))
214216 })
215217
216- mx_io.scale_mem <> spad.module.io.scale_mem.get
218+ // mx_io.scale_mem <> spad.module.io.scale_mem.get
217219
218220 mx_io.requant_out <> mx_requantizer.get.io.requant_data_out
219221 mx_requantizer.get.io.lut0_write <> mx_io.lut0
@@ -245,22 +247,27 @@ class GemminiModule[T <: Data: Arithmetic, U <: Data, V <: Data]
245247 mx_sel(bank) := (ex_controller.io.output_MxFormat === 1 .U )
246248 }
247249 }
248- // if(outer.config.lut.isDefined){
249- // mx_requantizer.get.io.lut_write := DontCare
250- // }
250+ if (outer.config.lut.isDefined){
251+ mx_requantizer.get.io.lut_write_0 := DontCare
252+ mx_requantizer.get.io.lut_write_1 := DontCare
253+ mx_requantizer.get.io.lut_write_2 := DontCare
254+ }
251255 }
252256
253257 for (bank <- 0 until sp_banks) {
254258 val useMxB = mx_requantizer.isDefined.B && mx_sel(bank)
255259 // Requests
256260 // default
261+
262+
257263 read_projected(bank).req.valid := sram_read_buffer(bank).req.valid && ! useMxB
258264 read_projected(bank).req.bits := sram_read_buffer(bank).req.bits
259265 sram_read_buffer(bank).req.ready := Mux (useMxB, mx_requantizer.get.io.spad_projected_data(bank).req.ready, read_projected(bank).req.ready)
260266 // mx
267+
261268 mx_requantizer.get.io.spad_deprojected_data(bank).req.valid := sram_read_buffer(bank).req.valid && useMxB
262269 mx_requantizer.get.io.spad_deprojected_data(bank).req.bits := sram_read_buffer(bank).req.bits
263-
270+
264271 // Responses
265272
266273 read_projected(bank).resp.valid := Mux (useMxB, mx_requantizer.get.io.spad_projected_data(bank).resp.valid, sram_read_buffer(bank).resp.valid)
@@ -293,7 +300,7 @@ class GemminiModule[T <: Data: Arithmetic, U <: Data, V <: Data]
293300 requantized_writes(i).mask := VecInit (Seq .fill(requantized_writes(i).mask.length)(true .B ))
294301 }
295302
296- mx_io.get.requant_in_gpu.ready := false .B
303+ // mx_io.get.requant_in_gpu.ready := false.B
297304 mx_requantizer.get.io.requant_data_in.valid := false .B
298305 mx_requantizer.get.io.requant_data_in.bits := DontCare
299306 mx_requantizer.get.io.scaleMem_write.ready := false .B
@@ -326,7 +333,7 @@ class GemminiModule[T <: Data: Arithmetic, U <: Data, V <: Data]
326333
327334
328335 }.elsewhen(any_valid) {
329- mx_io.get.requant_in_gpu.ready := false .B
336+ // mx_io.get.requant_in_gpu.ready := false.B
330337
331338 val collected_data = Wire (Vec (outer.config.requantizer.get.numInputLanes, UInt (outer.config.weightType.getWidth.W )))
332339 collected_data := DontCare
@@ -363,33 +370,33 @@ class GemminiModule[T <: Data: Arithmetic, U <: Data, V <: Data]
363370 mx_requantizer.get.io.requant_data_in.bits.dataType := RequantizerDataType .FP8
364371 }
365372
366- }.elsewhen(mx_io.get.requant_in_gpu.valid) {
367- // .elsewhen(!any_valid) {
373+ }// .elsewhen(mx_io.get.requant_in_gpu.valid) {
374+ .elsewhen(! any_valid) {
368375
369- val padded_data = VecInit (mx_io.get.requant_in_gpu.bits.data ++
370- Seq .fill(64 - 16 )(0 .U (16 .W )))
371- mx_requantizer.get.io.requant_data_in.valid := true .B
372- mx_requantizer.get.io.requant_data_in.bits.data := padded_data
373- mx_requantizer.get.io.requant_data_in.bits.address := mx_io.get.requant_in_gpu.bits.address
374- mx_requantizer.get.io.requant_data_in.bits.dataType := mx_io.get.requant_in_gpu.bits.dataType
375- mx_io.get.requant_in_gpu.ready := true .B
376- // val padded_data = VecInit(Seq.fill(64)(0.U(16.W)))
376+ // val padded_data = VecInit(mx_io.get.requant_in_gpu.bits.data ++
377+ // Seq.fill(64 - 16)(0.U(16.W)))
377378 // mx_requantizer.get.io.requant_data_in.valid := true.B
378379 // mx_requantizer.get.io.requant_data_in.bits.data := padded_data
379- // mx_requantizer.get.io.requant_data_in.bits.address := 0.U
380- // mx_requantizer.get.io.requant_data_in.bits.dataType := RequantizerDataType.FP8
380+ // mx_requantizer.get.io.requant_data_in.bits.address := mx_io.get.requant_in_gpu.bits.address
381+ // mx_requantizer.get.io.requant_data_in.bits.dataType := mx_io.get.requant_in_gpu.bits.dataType
382+ // mx_io.get.requant_in_gpu.ready := true.B
383+ val padded_data = VecInit (Seq .fill(64 )(0 .U (16 .W )))
384+ mx_requantizer.get.io.requant_data_in.valid := true .B
385+ mx_requantizer.get.io.requant_data_in.bits.data := padded_data
386+ mx_requantizer.get.io.requant_data_in.bits.address := 0 .U
387+ mx_requantizer.get.io.requant_data_in.bits.dataType := RequantizerDataType .FP8
381388
382389 }.otherwise {
383390 mx_requantizer.get.io.requant_data_in.valid := false .B
384391 mx_requantizer.get.io.requant_data_in.bits := DontCare
385- mx_io.get.requant_in_gpu.ready := false .B
392+ // mx_io.get.requant_in_gpu.ready := false.B
386393 }
387394
388395 }.otherwise {
389396 // enable_MXQuant == 0
390397 mx_requantizer.get.io.requant_data_in.valid := false .B
391398 mx_requantizer.get.io.requant_data_in.bits := DontCare
392- mx_io.get.requant_in_gpu.ready := false .B
399+ // mx_io.get.requant_in_gpu.ready := false.B
393400 }
394401
395402 requantized_writes
@@ -502,12 +509,19 @@ class GemminiModule[T <: Data: Arithmetic, U <: Data, V <: Data]
502509 has_training_convs, has_max_pool, has_first_layer_optimizations, has_dw_convs) }
503510 else (raw_cmd, false .B )
504511
505- val (loop_cmd, loop_matmul_unroller_busy, loop_completed) = withClock (gated_clock) { LoopMatmul (if (has_loop_conv) conv_cmd else raw_cmd, reservation_station.io.matmul_ld_completed, reservation_station.io.matmul_st_completed, reservation_station.io.matmul_ex_completed,
512+ val (loop_cmd, loop_matmul_unroller_busy, loop_completed, loop_matmul ) = withClock (gated_clock) { LoopMatmul (if (has_loop_conv) conv_cmd else raw_cmd, reservation_station.io.matmul_ld_completed, reservation_station.io.matmul_st_completed, reservation_station.io.matmul_ex_completed,
506513 meshRows* tileRows, coreMaxAddrBits, reservation_station_entries, max_lds, max_exs, max_sts, sp_banks * sp_bank_entries, acc_banks * acc_bank_entries,
507514 inputType.getWidth, accType.getWidth, dma_maxbytes, new MvinRs2 (mvin_rows_bits, mvin_cols_bits, local_addr_t),
508515 new PreloadRs (mvin_rows_bits, mvin_cols_bits, local_addr_t), new PreloadRs (mvout_rows_bits, mvout_cols_bits, local_addr_t),
509516 new ComputeRs (mvin_rows_bits, mvin_cols_bits, local_addr_t), new ComputeRs (mvin_rows_bits, mvin_cols_bits, local_addr_t),
510517 new MvoutSpadRs1 (32 , local_addr_t), new MvoutRs2 (mvout_rows_bits, mvout_cols_bits, local_addr_t)) }
518+
519+ mx_requantizer.get.io.counter_i := loop_matmul.io.counter_i
520+ mx_requantizer.get.io.counter_j := loop_matmul.io.counter_j
521+ mx_requantizer.get.io.counter_k := loop_matmul.io.counter_k
522+ mx_requantizer.get.io.a_fire := ex_controller.io.a_fire
523+ mx_requantizer.get.io.b_fire := ex_controller.io.b_fire
524+
511525
512526 val unrolled_cmd = Queue (loop_cmd)
513527 unrolled_cmd.ready := false .B
0 commit comments