Skip to content

Commit ed2ddd4

Browse files
Amanda Shi
authored and committed
remove the half row reading logic from spad
1 parent 7cfba71 commit ed2ddd4

File tree

4 files changed

+41
-70
lines changed

4 files changed

+41
-70
lines changed

src/main/scala/gemmini/Controller.scala

Lines changed: 3 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -248,9 +248,9 @@ class GemminiModule[T <: Data: Arithmetic, U <: Data, V <: Data]
248248
}
249249
}
250250
if(outer.config.lut.isDefined){
251-
mx_requantizer.get.io.lut_write_0 := DontCare
252-
mx_requantizer.get.io.lut_write_1 := DontCare
253-
mx_requantizer.get.io.lut_write_2 := DontCare
251+
mx_requantizer.get.io.lut0_write := DontCare
252+
mx_requantizer.get.io.lut1_write := DontCare
253+
mx_requantizer.get.io.lut2_write := DontCare
254254
}
255255
}
256256

src/main/scala/gemmini/ExecuteController.scala

Lines changed: 35 additions & 47 deletions
Original file line number | Diff line number | Diff line change
@@ -55,15 +55,15 @@ class ExecuteController[T <: Data, U <: Data, V <: Data](xLen: Int, tagWidth: In
5555
})
5656

5757

58-
def needsBuffering(mx_format: UInt): Bool = {
59-
mx_format === 2.U // FP8 needs buffering
60-
}
61-
62-
def extractHalf(data: UInt, use_high_half: Bool): UInt = {
63-
Mux(use_high_half,
64-
data(255, 128), // high 128b
65-
data(127, 0)) // low 128b
66-
}
58+
// def needsBuffering(mx_format: UInt): Bool = {
59+
// mx_format === 2.U // FP8 needs buffering
60+
// }
61+
62+
// def extractHalf(data: UInt, use_high_half: Bool): UInt = {
63+
// Mux(use_high_half,
64+
// data(255, 128), // high 128b
65+
// data(127, 0)) // low 128b
66+
// }
6767

6868

6969
val block_size = meshRows*tileRows
@@ -307,13 +307,13 @@ def extractHalf(data: UInt, use_high_half: Bool): UInt = {
307307

308308
//MX format related
309309
//val b_data_buffer = Reg(UInt(sp_width.W))
310-
val d_data_buffer = Reg(UInt(sp_width.W))
310+
// val d_data_buffer = Reg(UInt(sp_width.W))
311311

312-
// buffer valid indicators
313-
val d_buffer_valid = RegInit(false.B)
312+
// // buffer valid indicators
313+
// val d_buffer_valid = RegInit(false.B)
314314

315-
// half buffer indicators
316-
val d_buffer_half = RegInit(false.B)
315+
// // half buffer indicators
316+
// val d_buffer_half = RegInit(false.B)
317317

318318

319319
// TODO merge these into one enum
@@ -468,9 +468,9 @@ def extractHalf(data: UInt, use_high_half: Bool): UInt = {
468468
val read_b = b_valid && !b_read_from_acc && dataBbank === i.U && start_inputting_b && !accumulate_zeros && b_row_is_not_all_zeros //&& !im2col_wire
469469
val read_d = d_valid && !d_read_from_acc && dataDbank === i.U && start_inputting_d && !preload_zeros && d_row_is_not_all_zeros //&& !im2col_wire
470470

471-
val d_needs_sram_read = read_d && !(needsBuffering(weight_mx_format) && d_buffer_valid && !d_buffer_half)
471+
//val d_needs_sram_read = read_d && !(needsBuffering(weight_mx_format) && d_buffer_valid && !d_buffer_half)
472472

473-
Seq((read_a, a_ready), (read_b, b_ready), (d_needs_sram_read, d_ready)).foreach { case (rd, r) =>
473+
Seq((read_a, a_ready), (read_b, b_ready), (read_d, d_ready)).foreach { case (rd, r) =>
474474
when (rd && !io.srams.read(i).req.ready) {
475475
r := false.B
476476
}
@@ -884,17 +884,17 @@ def extractHalf(data: UInt, use_high_half: Bool): UInt = {
884884
cntl.b_read_from_acc -> accReadValid(cntl.b_bank_acc)
885885
))
886886

887-
//val dataD_valid = cntl.d_garbage || cntl.d_unpadded_cols === 0.U || MuxCase(readValid(cntl.d_bank), Seq(
888-
// cntl.preload_zeros -> false.B,
889-
// cntl.d_read_from_acc -> accReadValid(cntl.d_bank_acc)
890-
//))
891-
892-
val dataD_valid = cntl.d_garbage || cntl.d_unpadded_cols === 0.U ||
893-
Mux(needsBuffering(weight_mx_format) && d_buffer_valid,
894-
true.B,
895-
MuxCase(readValid(cntl.d_bank), Seq(
896-
cntl.preload_zeros -> false.B,
897-
cntl.d_read_from_acc -> accReadValid(cntl.d_bank_acc))))
887+
val dataD_valid = cntl.d_garbage || cntl.d_unpadded_cols === 0.U || MuxCase(readValid(cntl.d_bank), Seq(
888+
cntl.preload_zeros -> false.B,
889+
cntl.d_read_from_acc -> accReadValid(cntl.d_bank_acc)
890+
))
891+
892+
// val dataD_valid = cntl.d_garbage || cntl.d_unpadded_cols === 0.U ||
893+
// Mux(needsBuffering(weight_mx_format) && d_buffer_valid,
894+
// true.B,
895+
// MuxCase(readValid(cntl.d_bank), Seq(
896+
// cntl.preload_zeros -> false.B,
897+
// cntl.d_read_from_acc -> accReadValid(cntl.d_bank_acc))))
898898

899899
//added for negative bitshift
900900
val preload_zero_counter = RegInit(0.U(5.W))
@@ -904,11 +904,11 @@ def extractHalf(data: UInt, use_high_half: Bool): UInt = {
904904
val dataA_unpadded = Mux(cntl.im2colling, im2ColData, Mux(cntl.a_read_from_acc, accReadData(cntl.a_bank_acc), readData(cntl.a_bank)))
905905
val dataB_unpadded = MuxCase(readData(cntl.b_bank), Seq(cntl.accumulate_zeros -> 0.U, cntl.b_read_from_acc -> accReadData(cntl.b_bank_acc)))
906906

907-
val dataD_from_sram = MuxCase(readData(cntl.d_bank), Seq(cntl.preload_zeros -> 0.U, cntl.d_read_from_acc -> accReadData(cntl.d_bank_acc)))
907+
// val dataD_from_sram = MuxCase(readData(cntl.d_bank), Seq(cntl.preload_zeros -> 0.U, cntl.d_read_from_acc -> accReadData(cntl.d_bank_acc)))
908908

909-
val dataD_unpadded = Mux(needsBuffering(weight_mx_format) && d_buffer_valid, extractHalf(d_data_buffer, d_buffer_half), dataD_from_sram)
909+
// val dataD_unpadded = Mux(needsBuffering(weight_mx_format) && d_buffer_valid, extractHalf(d_data_buffer, d_buffer_half), dataD_from_sram)
910910

911-
//val dataD_unpadded = MuxCase(readData(cntl.d_bank), Seq(cntl.preload_zeros -> 0.U, cntl.d_read_from_acc -> accReadData(cntl.d_bank_acc)))
911+
val dataD_unpadded = MuxCase(readData(cntl.d_bank), Seq(cntl.preload_zeros -> 0.U, cntl.d_read_from_acc -> accReadData(cntl.d_bank_acc)))
912912

913913
val dataA = VecInit(dataA_unpadded.asTypeOf(Vec(block_size, inputType)).zipWithIndex.map { case (d, i) => Mux(i.U < cntl.a_unpadded_cols, d, inputType.zero)}.map(d => d.asTypeOf(inputType).withWidthOf(spatialArrayInputType)))
914914
val dataB = VecInit(dataB_unpadded.asTypeOf(Vec(block_size, accType)).zipWithIndex.map { case (d, i) => Mux(i.U < cntl.b_unpadded_cols, d, accType.zero)}.map(d => d.asTypeOf(accType).withWidthOf(spatialArrayOutputType)))
@@ -933,25 +933,12 @@ def extractHalf(data: UInt, use_high_half: Bool): UInt = {
933933
}
934934

935935
when (cntl.d_fire && mesh.io.d.fire && !cntl.d_garbage && !cntl.preload_zeros && cntl.d_unpadded_cols > 0.U) {
936-
when (needsBuffering(weight_mx_format)) {
937-
when (!d_buffer_valid) {
938-
when (!cntl.d_read_from_acc && readValid(cntl.d_bank)) {
939-
d_data_buffer := readData(cntl.d_bank)
940-
d_buffer_valid := true.B
941-
d_buffer_half := false.B
942-
}
943-
}.elsewhen (!d_buffer_half) {
944-
d_buffer_half := true.B
945-
}.otherwise {
946-
d_buffer_valid := false.B
947-
d_buffer_half := false.B
948-
}
936+
when (cntl.d_read_from_acc) {
937+
io.acc.read_resp(cntl.d_bank_acc).ready := !io.acc.read_resp(cntl.d_bank_acc).bits.fromDMA
938+
}.otherwise {
939+
io.srams.read(cntl.d_bank).resp.ready := !io.srams.read(cntl.d_bank).resp.bits.fromDMA
949940
}
950941
}
951-
when (!firing) {
952-
d_buffer_valid := false.B
953-
d_buffer_half := false.B
954-
}
955942
}
956943

957944
if (!ex_read_from_acc) {
@@ -960,6 +947,7 @@ def extractHalf(data: UInt, use_high_half: Bool): UInt = {
960947
}
961948
}
962949

950+
963951
when (cntl_valid) {
964952
// Default inputs
965953
mesh.io.a.valid := cntl.a_fire && dataA_valid

src/main/scala/gemmini/MxRequantizer.scala

Lines changed: 3 additions & 10 deletions
Original file line number | Diff line number | Diff line change
@@ -288,18 +288,11 @@ class MxRequantizer[T <: Data: Arithmetic](
288288
// quantLut.io.lut_write.ready := false.B
289289
// }
290290

291-
<<<<<<< HEAD
292-
quantLut.io.lut_write <> io.lut0_write
293-
io.lut1_write.ready := false.B
294-
io.lut2_write.ready := false.B
295-
296-
=======
297-
quantLut.io.lut_write_weight <> io.lut_write_0
298-
quantLut.io.lut_write_act_in <> io.lut_write_1
299-
quantLut.io.lut_write_act_out <> io.lut_write_2
291+
quantLut.io.lut_write_weight <> io.lut0_write
292+
quantLut.io.lut_write_act_in <> io.lut1_write
293+
quantLut.io.lut_write_act_out <> io.lut2_write
300294
val quant_fp6_buffer = RegInit(VecInit(Seq.fill(io.outputnumLanes)(0.U(6.W))))
301295
val quant_fp6_hang = RegInit(false.B)
302-
>>>>>>> e1e04af (change the QuantLut as double buffer)
303296
when(quantize_valid && (total_bits_per_element === 6.U)) {
304297
when{quantLut.io.lut_write_act_out.ready}{ //hang here when write is finished
305298
for (i <- 0 until io.outputnumLanes) {

src/main/scala/gemmini/QuantLut.scala

Lines changed: 0 additions & 10 deletions
Original file line number | Diff line number | Diff line change
@@ -16,20 +16,11 @@ class QuantLutIO(
1616
sp_width_projected: Int,
1717
iterator_bitwidth: Int,
1818
) extends Bundle {
19-
<<<<<<< HEAD
20-
val lutReadEnable = Output(Bool())
21-
val lut_write = Flipped(Decoupled(new QuantLutWriteBundle(lutConfig))) //input
22-
val quant_fp6 = Flipped(Valid(Vec(outputnumLanes, UInt(lutConfig.rdataWidth.W)))) //input
23-
val projected_data = Valid(Vec(outputnumLanes, UInt(lutConfig.raddrWidth.W))) //output
24-
// val spad_projected_data = Flipped(Decoupled(Vec(sp_banks, new ScratchpadReadIO(sp_bank_entries, sp_width_projected))))
25-
// val spad_deprojected_data = Decoupled(Vec(sp_banks, new ScratchpadReadIO(sp_bank_entries, sp_width)))
26-
=======
2719
val lut_write_weight = Flipped(Decoupled(new QuantLutWriteBundle(lutConfig))) //input
2820
val lut_write_act_in = Flipped(Decoupled(new QuantLutWriteBundle(lutConfig))) //input
2921
val lut_write_act_out = Flipped(Decoupled(new QuantLutWriteBundle(lutConfig))) //input
3022
val quant_fp6 = Flipped(Valid(Vec(outputnumLanes, UInt(lutConfig.rdataWidth.W)))) //input
3123
val projected_data = Valid(Vec(outputnumLanes, UInt(lutConfig.raddrWidth.W))) //output
32-
>>>>>>> e1e04af (change the QuantLut as double buffer)
3324
val spad_projected_data = Vec(sp_banks, new ScratchpadReadIO(sp_bank_entries, sp_width_projected))
3425
val spad_deprojected_data = Vec(sp_banks, Flipped(new ScratchpadReadIO(sp_bank_entries, sp_width)))
3526
val counter_j = Input(UInt(iterator_bitwidth.W))
@@ -159,7 +150,6 @@ class QuantLut(
159150
val lutCache_act_out_buffer_0_read_enable = RegInit(false.B)
160151
val lutCache_act_out_buffer_1_read_enable = RegInit(false.B)
161152
val lutCache_act_out_buffer_select = RegInit(false.B)
162-
val counter_k_reg = RegNext(io.counter_k)
163153

164154
when(io.lut_write_act_out.fire){
165155
when(lutCache_act_out_flag === false.B){

0 commit comments

Comments
 (0)