Skip to content

Commit e744e53

Browse files
Amanda ShiAmanda Shi
authored andcommitted
change mx format definition for consistency
1 parent af9b8b4 commit e744e53

File tree

7 files changed

+16
-34
lines changed

7 files changed

+16
-34
lines changed

src/main/scala/gemmini/Controller.scala

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -320,7 +320,7 @@ class GemminiModule[T <: Data: Arithmetic, U <: Data, V <: Data]
320320
mx_requantizer.get.io.requant_data_in.valid := false.B
321321
mx_requantizer.get.io.requant_data_in.bits := DontCare
322322

323-
when(ex_controller.io.output_MxFormat === 2.U){
323+
when(ex_controller.io.output_MxFormat === 0.U){
324324
mx_requantizer.get.io.fp8_mode := true.B
325325
}.otherwise{
326326
mx_requantizer.get.io.fp8_mode := false.B}
@@ -375,11 +375,11 @@ class GemminiModule[T <: Data: Arithmetic, U <: Data, V <: Data]
375375
mx_requantizer.get.io.requant_data_in.bits.address :=
376376
Cat(first_valid_addr, 0.U(log2Ceil(outer.config.sp_banks).W))
377377

378-
when(ex_controller.io.output_MxFormat === 0.U) {
378+
when(ex_controller.io.output_MxFormat === 2.U) {
379379
mx_requantizer.get.io.requant_data_in.bits.dataType := RequantizerDataType.FP4
380380
}.elsewhen(ex_controller.io.output_MxFormat === 1.U) {
381381
mx_requantizer.get.io.requant_data_in.bits.dataType := RequantizerDataType.FP6
382-
}.elsewhen(ex_controller.io.output_MxFormat === 2.U) {
382+
}.elsewhen(ex_controller.io.output_MxFormat === 0.U) {
383383
mx_requantizer.get.io.requant_data_in.bits.dataType := RequantizerDataType.FP8
384384
}
385385

src/main/scala/gemmini/ExecuteController.scala

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -52,7 +52,7 @@ class ExecuteController[T <: Data, U <: Data, V <: Data](xLen: Int, tagWidth: In
5252
val counter = new CounterEventIO()
5353
val b_fire = Output(Bool())
5454
val a_fire = Output(Bool())
55-
val scale_mem_mvout_base_addr_act = Output(UInt(33.W))
55+
val scale_mem_mvout_base_addr_act = Output(UInt(32.W))
5656
val scaleMemCntl = Output(new ScalingFactorCntl(meshRows*tileRows))
5757
})
5858

src/main/scala/gemmini/GemminiISA.scala

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -226,7 +226,7 @@ object GemminiISA {
226226
val b_transpose = UInt(CONFIG_EX_RS1_B_TRANSPOSE_WIDTH.W)
227227
val a_transpose = UInt(CONFIG_EX_RS1_A_TRANSPOSE_WIDTH.W)
228228
val set_only_strides = UInt(CONFIG_EX_RS1_SET_ONLY_STRIDES_WIDTH.W)
229-
//val enable_mxquant = UInt(CONFIG_EX_RS1_ENABLE_MXQUANT_WIDTH.W)
229+
val _spacer0 = UInt(1.W)
230230
val uselut = UInt(CONFIG_EX_RS1_LUT_ENABLE_WIDTH.W)
231231
val activation = UInt(CONFIG_EX_RS1_ACTIVATION_WIDTH.W)
232232
val dataflow = UInt(CONFIG_EX_RS1_DATAFLOW_WIDTH.W)

src/main/scala/gemmini/MxConfigFragments.scala

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -45,7 +45,7 @@ case class GemminiLUTConfig(
4545
)
4646

4747
object RequantizerDataType extends ChiselEnum {
48-
val FP4, FP6, FP8 = Value
48+
val FP8, FP6, FP4 = Value
4949

5050
def widthBits(x: Type): UInt = {
5151
Mux(x === FP4, 4.U(4.W), 8.U(4.W))

src/main/scala/gemmini/MxRequantizer.scala

Lines changed: 9 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -5,9 +5,9 @@ import chisel3.util._
55

66
object MxFloatFormat {
77
// Format encoding
8-
val FP4 = 0.U(2.W)
8+
val FP8 = 0.U(2.W)
99
val FP6 = 1.U(2.W)
10-
val FP8 = 2.U(2.W)
10+
val FP4 = 2.U(2.W)
1111

1212
def apply(bits: UInt): (UInt, UInt, UInt, UInt) = {
1313
val exp_bits = MuxLookup(bits, 4.U)(Seq(
@@ -336,20 +336,17 @@ class MxRequantizer[T <: Data: Arithmetic](
336336
scale_buffer_full := false.B
337337
}
338338

339-
when(scale_buffer_full) {
340-
io.scaleMem_write.valid := true.B
341-
io.scaleMem_write.bits.addr := scale_mem_mvout_base_addr_act + (scale_write_addr_counter << 5) //byte address, scale 32B per write
342-
io.scaleMem_write.bits.data := Cat(scale_buffer.reverse)
343-
344-
when(io.scaleMem_write.fire) {
345-
val scale_buffer_packed = Cat(scale_buffer.reverse)
346-
printf(p"[MxScaleGen]: addr=${scale_write_addr_counter}, data=0x${Hexadecimal(scale_buffer_packed)}\n")
347-
339+
when(io.scaleMem_write.fire) {
348340
when(scale_write_addr_counter === ((1 << 10) - 1).U) {
349341
scale_write_addr_counter := 0.U
350342
}.otherwise {
351343
scale_write_addr_counter := scale_write_addr_counter + 1.U
352344
}
353-
}
345+
}
346+
347+
when(scale_buffer_full) {
348+
io.scaleMem_write.valid := true.B
349+
io.scaleMem_write.bits.addr := scale_mem_mvout_base_addr_act + (scale_write_addr_counter << 5) //byte address, scale 32B per write
350+
io.scaleMem_write.bits.data := Cat(scale_buffer.reverse)
354351
}
355352
}

src/main/scala/gemmini/ReservationStation.scala

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -311,7 +311,7 @@ class ReservationStation[T <: Data : Arithmetic, U <: Data, V <: Data](config: G
311311
}
312312

313313
val is_load = funct === LOAD_CMD || funct === LOAD2_CMD || funct === LOAD3_CMD || (funct === CONFIG_CMD && config_cmd_type === CONFIG_LOAD)
314-
val is_ex = funct === PRELOAD_CMD || funct_is_compute || (funct === CONFIG_CMD && config_cmd_type === CONFIG_EX)
314+
val is_ex = funct === PRELOAD_CMD || funct_is_compute || (funct === CONFIG_CMD && config_cmd_type === CONFIG_EX) || (funct === CONFIG_SCALE_MEM)
315315
val is_store = funct === STORE_CMD || funct === STORE_SPAD_CMD || (funct === CONFIG_CMD && (config_cmd_type === CONFIG_STORE || config_cmd_type === CONFIG_NORM))
316316
val is_norm = funct === CONFIG_CMD && config_cmd_type === CONFIG_NORM // normalization commands are a subset of store commands, so they still go in the store queue
317317

src/main/scala/gemmini/Scratchpad.scala

Lines changed: 0 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -124,21 +124,6 @@ class ScratchpadBank(n: Int, w: Int, aligned_to: Int, single_ported: Boolean, us
124124
val input_mx_format = io.read.req.bits.input_mx_format
125125

126126

127-
val bits_per_element = MuxLookup(input_mx_format, 8.U)(Seq(
128-
0.U -> 8.U, // FP8
129-
1.U -> 4.U, // FP6
130-
2.U -> 4.U // FP4
131-
))
132-
133-
val elements_per_row = (w / 8).U
134-
val total_bits_needed = elements_per_row * bits_per_element
135-
val bytes_needed = (total_bits_needed + 7.U) >> 3.U
136-
val addresses_needed = (bytes_needed + (w/8 - 1).U) / (w/8).U
137-
138-
139-
140-
141-
142127
// Make a queue which buffers the result of an SRAM read if it can't immediately be consumed
143128
val q = Module(new Queue(new ScratchpadReadResp(w), 1, true, true))
144129
val q_will_be_empty = (q.io.count +& q.io.enq.fire) - q.io.deq.fire === 0.U

0 commit comments

Comments
 (0)