@@ -61,6 +61,19 @@ class ExecuteController[T <: Data, U <: Data, V <: Data](xLen: Int, tagWidth: In
6161 }
6262 }
6363
64+ val quant_lut = Module (new QuantLut (1 , meshColumns, quantWidth, weightType.getWidth))
65+ quant_lut.io.wraddr := DontCare
66+ quant_lut.io.rdaddr := DontCare
67+ quant_lut.io.wrdata := DontCare
68+ quant_lut.io.wr := false .B
69+
70+ val compute_with_lut = RegInit (false .B )
71+
72+ val shift_for_lut = RegInit (0 .U (log2Up(quantWidth).W ))
73+ when(compute_with_lut) {
74+ shift_for_lut := log2Up(quantWidth).U
75+ }.otherwise {shift_for_lut := 0 .U }
76+
6477 val unrolled_cmd = TransposePreloadUnroller (io.cmd, config, io.counter)
6578
6679 val cmd_q_heads = 3
@@ -83,6 +96,8 @@ class ExecuteController[T <: Data, U <: Data, V <: Data](xLen: Int, tagWidth: In
8396 val DoConfig = functs(0 ) === CONFIG_CMD
8497 val DoComputes = functs.map(f => f === COMPUTE_AND_FLIP_CMD || f === COMPUTE_AND_STAY_CMD )
8598 val DoPreloads = functs.map(_ === PRELOAD_CMD )
99+ val DoLutPreload1 = functs(0 ) === PRELOAD_LUT1
100+ val DoLutPreload2 = functs(0 ) === PRELOAD_LUT2
86101
87102 val preload_cmd_place = Mux (DoPreloads (0 ), 0 .U , 1 .U )
88103 // val a_address_place = Mux(current_dataflow === Dataflow.WS.id.U, 0.U, Mux(preload_cmd_place === 0.U, 1.U, 2.U))
@@ -436,13 +451,13 @@ class ExecuteController[T <: Data, U <: Data, V <: Data](xLen: Int, tagWidth: In
436451 io.srams.read(i).req.bits.fromDMA := false .B
437452 io.srams.read(i).req.bits.addr := MuxCase (a_address_rs1.sp_row() + a_fire_counter,
438453 Seq (read_b -> (b_address_rs2.sp_row() + b_fire_counter),
439- read_d -> (d_address_rs1.sp_row() + block_size.U - 1 .U - d_fire_counter_mulpre)))
454+ read_d -> (( d_address_rs1.sp_row() + block_size.U - 1 .U - d_fire_counter_mulpre) >> shift_for_lut )))
440455
441456 // TODO this just overrides the previous line. Should we erase the previous line?
442457 when(im2col_en === false .B ) {
443458 io.srams.read(i).req.bits.addr := MuxCase (a_address.sp_row(),
444459 Seq (read_b -> b_address.sp_row(),
445- read_d -> d_address.sp_row()))
460+ read_d -> ( d_address.sp_row() >> shift_for_lut )))
446461 }
447462 } else {
448463 io.srams.read(i).req.valid := false .B
@@ -458,6 +473,8 @@ class ExecuteController[T <: Data, U <: Data, V <: Data](xLen: Int, tagWidth: In
458473 val read_a_from_acc = a_valid && a_read_from_acc && dataABankAcc === i.U && start_inputting_a && ! multiply_garbage && a_row_is_not_all_zeros && ! (im2col_wire&& im2col_en)
459474 val read_b_from_acc = b_valid && b_read_from_acc && dataBBankAcc === i.U && start_inputting_b && ! accumulate_zeros && b_row_is_not_all_zeros // && !im2col_wire
460475 val read_d_from_acc = d_valid && d_read_from_acc && dataDBankAcc === i.U && start_inputting_d && ! preload_zeros && d_row_is_not_all_zeros // && !im2col_wire
476+ // we do not support LUT and accumulator read for D matrices for now
477+ assert(! (compute_with_lut && read_d_from_acc))
461478
462479 Seq ((read_a_from_acc, a_ready), (read_b_from_acc, b_ready), (read_d_from_acc, d_ready)).foreach { case (rd, r) =>
463480 when(rd && ! io.acc.read_req(i).ready) {
@@ -559,6 +576,8 @@ class ExecuteController[T <: Data, U <: Data, V <: Data](xLen: Int, tagWidth: In
559576 if (dataflow == Dataflow .BOTH ) {
560577 current_dataflow := config_ex_rs1.dataflow
561578 }
579+ // use dequantization lut path
580+ compute_with_lut := config_ex_rs1.use_lut
562581 }
563582
564583 a_addr_stride := config_ex_rs1.a_stride // TODO this needs to be kept in sync with ROB.scala
@@ -581,6 +600,39 @@ class ExecuteController[T <: Data, U <: Data, V <: Data](xLen: Int, tagWidth: In
581600 cmd.pop := 1 .U
582601 }
583602
603+ .elsewhen(DoLutPreload1 ) {
604+ val preload_lut1 = rs1s(0 ).asTypeOf(new PreloadLutRs1 ())
605+ val preload_lut2 = rs2s(0 ).asTypeOf(new PreloadLutRs2 ())
606+
607+ quant_lut.io.wr := true .B
608+
609+ val (lut_wrcount, lut_wrdone) = Counter (quant_lut.io.wr, 8 )
610+
611+ when(! lut_wrdone) {
612+ quant_lut.io.wraddr(0 ) := Cat (0 .U , lut_wrcount)
613+ quant_lut.io.wrdata(0 ) := preload_lut1.lut_data(lut_wrcount)
614+ }.otherwise {
615+ io.completed := cmd.bits(0 ).rob_id
616+ cmd.pop := 1 .U
617+ }
618+ }
619+ .elsewhen(DoLutPreload2 ) {
620+ val preload_lut1 = rs1s(0 ).asTypeOf(new PreloadLutRs1 ())
621+ val preload_lut2 = rs2s(0 ).asTypeOf(new PreloadLutRs2 ())
622+
623+ quant_lut.io.wr := true .B
624+
625+ val (lut_wrcount, lut_wrdone) = Counter (quant_lut.io.wr, 8 )
626+
627+ when(! lut_wrdone) {
628+ quant_lut.io.wraddr(0 ) := Cat (1 .U , lut_wrcount)
629+ quant_lut.io.wrdata(0 ) := preload_lut2.lut_data(lut_wrcount)
630+ }.otherwise {
631+ io.completed := cmd.bits(0 ).rob_id
632+ cmd.pop := 1 .U
633+ }
634+ }
635+
584636 // Preload
585637 .elsewhen(DoPreloads (0 ) && cmd.valid(1 ) && (raw_hazards_are_impossible.B || ! raw_hazard_pre)) {
586638 perform_single_preload := true .B
@@ -834,8 +886,43 @@ class ExecuteController[T <: Data, U <: Data, V <: Data](xLen: Int, tagWidth: In
834886 val dataD_unpadded = MuxCase (readData(cntl.d_bank), Seq (cntl.preload_zeros -> 0 .U , cntl.d_read_from_acc -> accReadData(cntl.d_bank_acc)))
835887
836888 val dataA = VecInit (dataA_unpadded.asTypeOf(Vec (block_size, inputType)).zipWithIndex.map { case (d, i) => Mux (i.U < cntl.a_unpadded_cols, d, inputType.zero)}.map(d => d.asTypeOf(inputType).withWidthOf(spatialArrayInputType)))
837- val dataB = VecInit (dataB_unpadded.asTypeOf(Vec (block_size, inputType)).zipWithIndex.map { case (d, i) => Mux (i.U < cntl.b_unpadded_cols, d, inputType.zero)}.map(d => d.asTypeOf(inputType).withWidthOf(spatialArrayWeightType)))
838- val dataD = VecInit (dataD_unpadded.asTypeOf(Vec (block_size, inputType)).zipWithIndex.map { case (d, i) => Mux (i.U < cntl.d_unpadded_cols, d, inputType.zero)}.map(d => d.asTypeOf(inputType).withWidthOf(spatialArrayWeightType)))
889+ val dataB = VecInit (dataB_unpadded.asTypeOf(Vec (block_size, inputType)).zipWithIndex.map { case (d, i) => Mux (i.U < cntl.b_unpadded_cols, d, inputType.zero)}.map(d => d.asTypeOf(inputType).withWidthOf(spatialArrayWeightType)))
890+
891+ val dataD_unpadded_dequant = Wire (Vec (meshColumns, UInt ((weightType.getWidth).W )))
892+ val lut_d_to_mesh = cntl.d_fire && dataD_valid && cntl_valid
893+ val (dcount, ddone) = Counter (lut_d_to_mesh, (weightType.getWidth/ quantWidth))
894+
895+ when(compute_with_lut) {
896+ when(dcount === 0 .U ) {
897+ for (i<- 0 until meshColumns) {
898+ quant_lut.io.rdaddr(i) := (dataD_unpadded((meshColumns* quantWidth- 1 ), 0 ).asTypeOf(Vec (meshColumns, UInt (quantWidth.W ))))(i)
899+ dataD_unpadded_dequant(i) := quant_lut.io.rddata(i)
900+ }
901+ }.elsewhen(dcount === 1 .U ) {
902+ for (i<- 0 until meshColumns) {
903+ quant_lut.io.rdaddr(i) := (dataD_unpadded((2 * meshColumns* quantWidth- 1 ), (meshColumns* quantWidth- 1 )).asTypeOf(Vec (meshColumns, UInt (quantWidth.W ))))(i)
904+ dataD_unpadded_dequant(i) := quant_lut.io.rddata(i)
905+ }
906+ }.elsewhen(dcount === 2 .U ) {
907+ for (i<- 0 until meshColumns) {
908+ quant_lut.io.rdaddr(i) := (dataD_unpadded((3 * meshColumns* quantWidth- 1 ), (2 * meshColumns* quantWidth- 1 )).asTypeOf(Vec (meshColumns, UInt (quantWidth.W ))))(i)
909+ dataD_unpadded_dequant(i) := quant_lut.io.rddata(i)
910+ }
911+ }.elsewhen(dcount === 3 .U ) {
912+ for (i<- 0 until meshColumns) {
913+ quant_lut.io.rdaddr(i) := (dataD_unpadded((4 * meshColumns* quantWidth- 1 ), (3 * meshColumns* quantWidth- 1 )).asTypeOf(Vec (meshColumns, UInt (quantWidth.W ))))(i)
914+ dataD_unpadded_dequant(i) := quant_lut.io.rddata(i)
915+ }
916+ }.otherwise {
917+ for (i<- 0 until meshColumns) {
918+ dataD_unpadded_dequant(i) := 0 .U
919+ }
920+ }
921+ }.otherwise{
922+ dataD_unpadded_dequant := dataD_unpadded.asTypeOf(dataD_unpadded_dequant)
923+ }
924+
925+ val dataD = VecInit (dataD_unpadded_dequant.asTypeOf(Vec (block_size, inputType)).zipWithIndex.map { case (d, i) => Mux (i.U < cntl.d_unpadded_cols, d, inputType.zero)}.map(d => d.asTypeOf(inputType).withWidthOf(spatialArrayWeightType)))
839926
840927 // Pop responses off the scratchpad io ports
841928 when (mesh_cntl_signals_q.io.deq.fire) {
0 commit comments