@@ -55,15 +55,15 @@ class ExecuteController[T <: Data, U <: Data, V <: Data](xLen: Int, tagWidth: In
5555 })
5656
5757
58- def needsBuffering (mx_format : UInt ): Bool = {
59- mx_format === 2 .U // FP8 needs buffering
60- }
61-
62- def extractHalf (data : UInt , use_high_half : Bool ): UInt = {
63- Mux (use_high_half,
64- data(255 , 128 ), // high 128b
65- data(127 , 0 )) // low 128b
66- }
58+ // def needsBuffering(mx_format: UInt): Bool = {
59+ // mx_format === 2.U // FP8 needs buffering
60+ // }
61+
62+ // def extractHalf(data: UInt, use_high_half: Bool): UInt = {
63+ // Mux(use_high_half,
64+ // data(255, 128), // high 128b
65+ // data(127, 0)) // low 128b
66+ // }
6767
6868
6969 val block_size = meshRows* tileRows
@@ -307,13 +307,13 @@ def extractHalf(data: UInt, use_high_half: Bool): UInt = {
307307
308308 // MX format related
309309 // val b_data_buffer = Reg(UInt(sp_width.W))
310- val d_data_buffer = Reg (UInt (sp_width.W ))
310+ // val d_data_buffer = Reg(UInt(sp_width.W))
311311
312- // buffer valid indicators
313- val d_buffer_valid = RegInit (false .B )
312+ // // buffer valid indicators
313+ // val d_buffer_valid = RegInit(false.B)
314314
315- // half buffer indicators
316- val d_buffer_half = RegInit (false .B )
315+ // // half buffer indicators
316+ // val d_buffer_half = RegInit(false.B)
317317
318318
319319 // TODO merge these into one enum
@@ -468,9 +468,9 @@ def extractHalf(data: UInt, use_high_half: Bool): UInt = {
468468 val read_b = b_valid && ! b_read_from_acc && dataBbank === i.U && start_inputting_b && ! accumulate_zeros && b_row_is_not_all_zeros // && !im2col_wire
469469 val read_d = d_valid && ! d_read_from_acc && dataDbank === i.U && start_inputting_d && ! preload_zeros && d_row_is_not_all_zeros // && !im2col_wire
470470
471- val d_needs_sram_read = read_d && ! (needsBuffering(weight_mx_format) && d_buffer_valid && ! d_buffer_half)
471+ // val d_needs_sram_read = read_d && !(needsBuffering(weight_mx_format) && d_buffer_valid && !d_buffer_half)
472472
473- Seq ((read_a, a_ready), (read_b, b_ready), (d_needs_sram_read , d_ready)).foreach { case (rd, r) =>
473+ Seq ((read_a, a_ready), (read_b, b_ready), (read_d , d_ready)).foreach { case (rd, r) =>
474474 when (rd && ! io.srams.read(i).req.ready) {
475475 r := false .B
476476 }
@@ -884,17 +884,17 @@ def extractHalf(data: UInt, use_high_half: Bool): UInt = {
884884 cntl.b_read_from_acc -> accReadValid(cntl.b_bank_acc)
885885 ))
886886
887- // val dataD_valid = cntl.d_garbage || cntl.d_unpadded_cols === 0.U || MuxCase(readValid(cntl.d_bank), Seq(
888- // cntl.preload_zeros -> false.B,
889- // cntl.d_read_from_acc -> accReadValid(cntl.d_bank_acc)
890- // ))
891-
892- val dataD_valid = cntl.d_garbage || cntl.d_unpadded_cols === 0 .U ||
893- Mux (needsBuffering(weight_mx_format) && d_buffer_valid,
894- true .B ,
895- MuxCase (readValid(cntl.d_bank), Seq (
896- cntl.preload_zeros -> false .B ,
897- cntl.d_read_from_acc -> accReadValid(cntl.d_bank_acc))))
887+ val dataD_valid = cntl.d_garbage || cntl.d_unpadded_cols === 0 .U || MuxCase (readValid(cntl.d_bank), Seq (
888+ cntl.preload_zeros -> false .B ,
889+ cntl.d_read_from_acc -> accReadValid(cntl.d_bank_acc)
890+ ))
891+
892+ // val dataD_valid = cntl.d_garbage || cntl.d_unpadded_cols === 0.U ||
893+ // Mux(needsBuffering(weight_mx_format) && d_buffer_valid,
894+ // true.B,
895+ // MuxCase(readValid(cntl.d_bank), Seq(
896+ // cntl.preload_zeros -> false.B,
897+ // cntl.d_read_from_acc -> accReadValid(cntl.d_bank_acc))))
898898
899899 // added for negative bitshift
900900 val preload_zero_counter = RegInit (0 .U (5 .W ))
@@ -904,11 +904,11 @@ def extractHalf(data: UInt, use_high_half: Bool): UInt = {
904904 val dataA_unpadded = Mux (cntl.im2colling, im2ColData, Mux (cntl.a_read_from_acc, accReadData(cntl.a_bank_acc), readData(cntl.a_bank)))
905905 val dataB_unpadded = MuxCase (readData(cntl.b_bank), Seq (cntl.accumulate_zeros -> 0 .U , cntl.b_read_from_acc -> accReadData(cntl.b_bank_acc)))
906906
907- val dataD_from_sram = MuxCase (readData(cntl.d_bank), Seq (cntl.preload_zeros -> 0 .U , cntl.d_read_from_acc -> accReadData(cntl.d_bank_acc)))
907+ // val dataD_from_sram = MuxCase(readData(cntl.d_bank), Seq(cntl.preload_zeros -> 0.U, cntl.d_read_from_acc -> accReadData(cntl.d_bank_acc)))
908908
909- val dataD_unpadded = Mux (needsBuffering(weight_mx_format) && d_buffer_valid, extractHalf(d_data_buffer, d_buffer_half), dataD_from_sram)
909+ // val dataD_unpadded = Mux(needsBuffering(weight_mx_format) && d_buffer_valid, extractHalf(d_data_buffer, d_buffer_half), dataD_from_sram)
910910
911- // val dataD_unpadded = MuxCase(readData(cntl.d_bank), Seq(cntl.preload_zeros -> 0.U, cntl.d_read_from_acc -> accReadData(cntl.d_bank_acc)))
911+ val dataD_unpadded = MuxCase (readData(cntl.d_bank), Seq (cntl.preload_zeros -> 0 .U , cntl.d_read_from_acc -> accReadData(cntl.d_bank_acc)))
912912
913913 val dataA = VecInit (dataA_unpadded.asTypeOf(Vec (block_size, inputType)).zipWithIndex.map { case (d, i) => Mux (i.U < cntl.a_unpadded_cols, d, inputType.zero)}.map(d => d.asTypeOf(inputType).withWidthOf(spatialArrayInputType)))
914914 val dataB = VecInit (dataB_unpadded.asTypeOf(Vec (block_size, accType)).zipWithIndex.map { case (d, i) => Mux (i.U < cntl.b_unpadded_cols, d, accType.zero)}.map(d => d.asTypeOf(accType).withWidthOf(spatialArrayOutputType)))
@@ -933,25 +933,12 @@ def extractHalf(data: UInt, use_high_half: Bool): UInt = {
933933 }
934934
935935 when (cntl.d_fire && mesh.io.d.fire && ! cntl.d_garbage && ! cntl.preload_zeros && cntl.d_unpadded_cols > 0 .U ) {
936- when (needsBuffering(weight_mx_format)) {
937- when (! d_buffer_valid) {
938- when (! cntl.d_read_from_acc && readValid(cntl.d_bank)) {
939- d_data_buffer := readData(cntl.d_bank)
940- d_buffer_valid := true .B
941- d_buffer_half := false .B
942- }
943- }.elsewhen (! d_buffer_half) {
944- d_buffer_half := true .B
945- }.otherwise {
946- d_buffer_valid := false .B
947- d_buffer_half := false .B
948- }
936+ when (cntl.d_read_from_acc) {
937+ io.acc.read_resp(cntl.d_bank_acc).ready := ! io.acc.read_resp(cntl.d_bank_acc).bits.fromDMA
938+ }.otherwise {
939+ io.srams.read(cntl.d_bank).resp.ready := ! io.srams.read(cntl.d_bank).resp.bits.fromDMA
949940 }
950941 }
951- when (! firing) {
952- d_buffer_valid := false .B
953- d_buffer_half := false .B
954- }
955942 }
956943
957944 if (! ex_read_from_acc) {
@@ -960,6 +947,7 @@ def extractHalf(data: UInt, use_high_half: Bool): UInt = {
960947 }
961948 }
962949
950+
963951 when (cntl_valid) {
964952 // Default inputs
965953 mesh.io.a.valid := cntl.a_fire && dataA_valid
0 commit comments