@@ -55,7 +55,8 @@ class MxRequantizerIO(
5555 val inputnumLanes = config.numInputLanes
5656 val outputnumLanes = config.numOutputLanes
5757 val inputdataWidth = config.inputBits
58- val requant_data_in = Flipped (Decoupled (new RequantizerInBundle (inputnumLanes, inputdataWidth)))
58+ val requant_data_in = Flipped (Decoupled (new RequantizerInBundle (outputnumLanes, inputdataWidth)))
59+ val requant_data_in_gpu = Flipped (Decoupled (new RequantizerInBundle (config.numGPUInputLanes, inputdataWidth)))
5960 val scaleMem_write = Decoupled (new ScalingFactorWriteReq (scaleMem_addr_width, scaleMem_data_width))
6061 val requant_data_out = Decoupled (new RequantizerOutBundle (outputnumLanes))
6162 val lut0_write = Flipped (Decoupled (new QuantLutWriteBundle (lutConfig)))
@@ -66,7 +67,7 @@ class MxRequantizerIO(
6667 val fp8_mode = Input (Bool ()) // true for 64-lane mode, false for 16-lane mode
6768 val a_fire = Input (Bool ()) // from execute controller
6869 val b_fire = Input (Bool ()) // from execute controller
69- val scale_mem_mvout_base_addr_act = Input (UInt (33 .W )) // from execute controller
70+ val scale_mem_mvout_base_addr_act = Input (UInt (32 .W )) // from execute controller
7071 val counter_i = Input (UInt (iterator_bitwidth.W )) // from controller
7172 val counter_j = Input (UInt (iterator_bitwidth.W )) // from controller
7273 val counter_k = Input (UInt (iterator_bitwidth.W )) // from controller
@@ -105,6 +106,7 @@ class MxRequantizer[T <: Data: Arithmetic](
105106 iterator_bitwidth,
106107 config
107108 ))
109+ dontTouch(io)
108110 val scale_mem_mvout_base_addr_act = io.scale_mem_mvout_base_addr_act
109111
110112 val scales_per_write = scaleMem_data_width / 8
@@ -143,61 +145,41 @@ class MxRequantizer[T <: Data: Arithmetic](
143145 }
144146
145147 val (exp_bits, mant_bits, pmax, log2_pmax_floor) = MxFloatFormat (format_reg)
146-
147- val data_buffer = WireInit (VecInit (Seq .fill(io.outputnumLanes)(0 .U (io.inputdataWidth.W ))))
148148 val data_buffer_counter = RegInit (0 .U (1 .W ))
149-
150149 // buffer twice for 16-lane mode
151150 val half_lanes = 16
152151 val input_32_buffer = RegInit (VecInit (Seq .fill(io.outputnumLanes)(0 .U (io.inputdataWidth.W ))))
153-
154- val input_64_buffer = RegInit (VecInit (Seq .fill(io.inputnumLanes)(0 .U (io.inputdataWidth.W ))))
155- val batch_counter = RegInit (0 .U (1 .W ))
156- val processing_64lane = RegInit (false .B )
157-
158- val requant_data_in_valid_d = RegNext (io.requant_data_in.valid, false .B )
152+ val requant_data_in_valid_d = RegNext (io.requant_data_in.fire)
153+ val requant_data_in_gpu_valid_d = RegNext (io.requant_data_in_gpu.fire)
159154 val should_compute = Wire (Bool ())
160- val quantize_valid = RegNext (should_compute, false . B )
155+ val quantize_valid = RegNext (should_compute)
161156
157+ io.requant_data_in.ready := true .B
162158 should_compute := false .B
163- io.requant_data_in.ready := ! processing_64lane
164-
165- when(io.requant_data_in.fire) {
166- when(io.fp8_mode) { // 16 lanes at a time
167- for (i <- 0 until half_lanes) {
168- val idx = Mux (data_buffer_counter === 0 .U , i.U , (half_lanes + i).U )
169- input_32_buffer(idx) := io.requant_data_in.bits.asUInt.asTypeOf(Vec (half_lanes, UInt (io.inputdataWidth.W )))(i)
170- }
171- data_buffer_counter := data_buffer_counter ^ 1 .U
172- }.otherwise {
173- for (i <- 0 until 64 ) {
174- input_64_buffer(i) := io.requant_data_in.bits.data(i)
159+ io.requant_data_in_gpu.ready := true .B
160+
161+
162+ when(io.requant_data_in.fire) {{
163+ io.requant_data_in_gpu.ready := false .B
164+ for (i <- 0 until 32 ) {
165+ input_32_buffer(i) := io.requant_data_in.bits.data(i)
175166 }
176167 data_buffer_counter := 1 .U
177168 }
178- }
179-
180- when(io.fp8_mode) { // 16 lanes at a time
181- should_compute := data_buffer_counter === 0 .U && requant_data_in_valid_d
182- for (i <- 0 until io.outputnumLanes) {
183- data_buffer(i) := input_32_buffer(i)
184- }
185- }.otherwise {
186- processing_64lane := true .B
187- should_compute := processing_64lane
188- for (i <- 0 until io.outputnumLanes) {
189- val idx = Mux (batch_counter === 0 .U , i.U , (io.outputnumLanes + i).U )
190- data_buffer(i) := input_64_buffer(i)
191- }
192- when(quantize_valid){
193- batch_counter := 1 .U
194- }.otherwise {
195- processing_64lane := false .B
196- data_buffer_counter := 0 .U
197- batch_counter := 0 .U
169+ }.elsewhen(io.requant_data_in_gpu.fire) {
170+ for (i <- 0 until half_lanes) {
171+ val idx = Mux (data_buffer_counter === 0 .U , i.U , (half_lanes + i).U )
172+ input_32_buffer(idx) := io.requant_data_in_gpu.bits.data(i)
198173 }
174+ data_buffer_counter := ~ data_buffer_counter
199175 }
200176
177+ val data_buffer = WireInit (VecInit (Seq .fill(io.outputnumLanes)(0 .U (io.inputdataWidth.W ))))
178+ data_buffer := input_32_buffer
179+
180+ // when(requant_data_in_valid_d || (requant_data_in_gpu_valid_d && (data_buffer_counter === 0.U))) {
181+ // should_compute := true.B
182+ // }
201183
202184 val block_max = Wire (UInt (io.inputdataWidth.W ))
203185 block_max := 0 .U
@@ -356,14 +338,14 @@ class MxRequantizer[T <: Data: Arithmetic](
356338
357339 when(scale_buffer_full) {
358340 io.scaleMem_write.valid := true .B
359- io.scaleMem_write.bits.addr := scale_mem_mvout_base_addr_act +& (scale_write_addr_counter << 5 . U ) // byte address, scale 32B per write
341+ io.scaleMem_write.bits.addr := scale_mem_mvout_base_addr_act + (scale_write_addr_counter << 5 ) // byte address, scale 32B per write
360342 io.scaleMem_write.bits.data := Cat (scale_buffer.reverse)
361343
362344 when(io.scaleMem_write.fire) {
363345 val scale_buffer_packed = Cat (scale_buffer.reverse)
364346 printf(p " [MxScaleGen]: addr= ${scale_write_addr_counter}, data=0x ${Hexadecimal (scale_buffer_packed)}\n " )
365347
366- when(scale_write_addr_counter === ((1 << scaleMem_addr_width ) - 1 ).U ) {
348+ when(scale_write_addr_counter === ((1 << 10 ) - 1 ).U ) {
367349 scale_write_addr_counter := 0 .U
368350 }.otherwise {
369351 scale_write_addr_counter := scale_write_addr_counter + 1 .U
0 commit comments