@@ -61,6 +61,19 @@ class ExecuteController[T <: Data, U <: Data, V <: Data](xLen: Int, tagWidth: In
61
61
}
62
62
}
63
63
64
+ val quant_lut = Module (new QuantLut (1 , meshColumns, quantWidth, weightType.getWidth))
65
+ quant_lut.io.wraddr := DontCare
66
+ quant_lut.io.rdaddr := DontCare
67
+ quant_lut.io.wrdata := DontCare
68
+ quant_lut.io.wr := false .B
69
+
70
+ val compute_with_lut = RegInit (false .B )
71
+
72
+ val shift_for_lut = RegInit (0 .U (log2Up(quantWidth).W ))
73
+ when(compute_with_lut) {
74
+ shift_for_lut := log2Up(quantWidth).U
75
+ }.otherwise {shift_for_lut := 0 .U }
76
+
64
77
val unrolled_cmd = TransposePreloadUnroller (io.cmd, config, io.counter)
65
78
66
79
val cmd_q_heads = 3
@@ -83,6 +96,8 @@ class ExecuteController[T <: Data, U <: Data, V <: Data](xLen: Int, tagWidth: In
83
96
val DoConfig = functs(0 ) === CONFIG_CMD
84
97
val DoComputes = functs.map(f => f === COMPUTE_AND_FLIP_CMD || f === COMPUTE_AND_STAY_CMD )
85
98
val DoPreloads = functs.map(_ === PRELOAD_CMD )
99
+ val DoLutPreload1 = functs(0 ) === PRELOAD_LUT1
100
+ val DoLutPreload2 = functs(0 ) === PRELOAD_LUT2
86
101
87
102
val preload_cmd_place = Mux (DoPreloads (0 ), 0 .U , 1 .U )
88
103
// val a_address_place = Mux(current_dataflow === Dataflow.WS.id.U, 0.U, Mux(preload_cmd_place === 0.U, 1.U, 2.U))
@@ -436,13 +451,13 @@ class ExecuteController[T <: Data, U <: Data, V <: Data](xLen: Int, tagWidth: In
436
451
io.srams.read(i).req.bits.fromDMA := false .B
437
452
io.srams.read(i).req.bits.addr := MuxCase (a_address_rs1.sp_row() + a_fire_counter,
438
453
Seq (read_b -> (b_address_rs2.sp_row() + b_fire_counter),
439
- read_d -> (d_address_rs1.sp_row() + block_size.U - 1 .U - d_fire_counter_mulpre)))
454
+ read_d -> (( d_address_rs1.sp_row() + block_size.U - 1 .U - d_fire_counter_mulpre) >> shift_for_lut )))
440
455
441
456
// TODO this just overrides the previous line. Should we erase the previous line?
442
457
when(im2col_en === false .B ) {
443
458
io.srams.read(i).req.bits.addr := MuxCase (a_address.sp_row(),
444
459
Seq (read_b -> b_address.sp_row(),
445
- read_d -> d_address.sp_row()))
460
+ read_d -> ( d_address.sp_row() >> shift_for_lut )))
446
461
}
447
462
} else {
448
463
io.srams.read(i).req.valid := false .B
@@ -458,6 +473,8 @@ class ExecuteController[T <: Data, U <: Data, V <: Data](xLen: Int, tagWidth: In
458
473
val read_a_from_acc = a_valid && a_read_from_acc && dataABankAcc === i.U && start_inputting_a && ! multiply_garbage && a_row_is_not_all_zeros && ! (im2col_wire&& im2col_en)
459
474
val read_b_from_acc = b_valid && b_read_from_acc && dataBBankAcc === i.U && start_inputting_b && ! accumulate_zeros && b_row_is_not_all_zeros // && !im2col_wire
460
475
val read_d_from_acc = d_valid && d_read_from_acc && dataDBankAcc === i.U && start_inputting_d && ! preload_zeros && d_row_is_not_all_zeros // && !im2col_wire
476
+ // we do not support LUT and accumulator read for D matrices for now
477
+ assert(! (compute_with_lut && read_d_from_acc))
461
478
462
479
Seq ((read_a_from_acc, a_ready), (read_b_from_acc, b_ready), (read_d_from_acc, d_ready)).foreach { case (rd, r) =>
463
480
when(rd && ! io.acc.read_req(i).ready) {
@@ -559,6 +576,8 @@ class ExecuteController[T <: Data, U <: Data, V <: Data](xLen: Int, tagWidth: In
559
576
if (dataflow == Dataflow .BOTH ) {
560
577
current_dataflow := config_ex_rs1.dataflow
561
578
}
579
+ // use dequantization lut path
580
+ compute_with_lut := config_ex_rs1.use_lut
562
581
}
563
582
564
583
a_addr_stride := config_ex_rs1.a_stride // TODO this needs to be kept in sync with ROB.scala
@@ -581,6 +600,39 @@ class ExecuteController[T <: Data, U <: Data, V <: Data](xLen: Int, tagWidth: In
581
600
cmd.pop := 1 .U
582
601
}
583
602
603
+ .elsewhen(DoLutPreload1 ) {
604
+ val preload_lut1 = rs1s(0 ).asTypeOf(new PreloadLutRs1 ())
605
+ val preload_lut2 = rs2s(0 ).asTypeOf(new PreloadLutRs2 ())
606
+
607
+ quant_lut.io.wr := true .B
608
+
609
+ val (lut_wrcount, lut_wrdone) = Counter (quant_lut.io.wr, 8 )
610
+
611
+ when(! lut_wrdone) {
612
+ quant_lut.io.wraddr(0 ) := Cat (0 .U , lut_wrcount)
613
+ quant_lut.io.wrdata(0 ) := preload_lut1.lut_data(lut_wrcount)
614
+ }.otherwise {
615
+ io.completed := cmd.bits(0 ).rob_id
616
+ cmd.pop := 1 .U
617
+ }
618
+ }
619
+ .elsewhen(DoLutPreload2 ) {
620
+ val preload_lut1 = rs1s(0 ).asTypeOf(new PreloadLutRs1 ())
621
+ val preload_lut2 = rs2s(0 ).asTypeOf(new PreloadLutRs2 ())
622
+
623
+ quant_lut.io.wr := true .B
624
+
625
+ val (lut_wrcount, lut_wrdone) = Counter (quant_lut.io.wr, 8 )
626
+
627
+ when(! lut_wrdone) {
628
+ quant_lut.io.wraddr(0 ) := Cat (1 .U , lut_wrcount)
629
+ quant_lut.io.wrdata(0 ) := preload_lut2.lut_data(lut_wrcount)
630
+ }.otherwise {
631
+ io.completed := cmd.bits(0 ).rob_id
632
+ cmd.pop := 1 .U
633
+ }
634
+ }
635
+
584
636
// Preload
585
637
.elsewhen(DoPreloads (0 ) && cmd.valid(1 ) && (raw_hazards_are_impossible.B || ! raw_hazard_pre)) {
586
638
perform_single_preload := true .B
@@ -834,8 +886,43 @@ class ExecuteController[T <: Data, U <: Data, V <: Data](xLen: Int, tagWidth: In
834
886
val dataD_unpadded = MuxCase (readData(cntl.d_bank), Seq (cntl.preload_zeros -> 0 .U , cntl.d_read_from_acc -> accReadData(cntl.d_bank_acc)))
835
887
836
888
val dataA = VecInit (dataA_unpadded.asTypeOf(Vec (block_size, inputType)).zipWithIndex.map { case (d, i) => Mux (i.U < cntl.a_unpadded_cols, d, inputType.zero)}.map(d => d.asTypeOf(inputType).withWidthOf(spatialArrayInputType)))
837
- val dataB = VecInit (dataB_unpadded.asTypeOf(Vec (block_size, inputType)).zipWithIndex.map { case (d, i) => Mux (i.U < cntl.b_unpadded_cols, d, inputType.zero)}.map(d => d.asTypeOf(inputType).withWidthOf(spatialArrayWeightType)))
838
- val dataD = VecInit (dataD_unpadded.asTypeOf(Vec (block_size, inputType)).zipWithIndex.map { case (d, i) => Mux (i.U < cntl.d_unpadded_cols, d, inputType.zero)}.map(d => d.asTypeOf(inputType).withWidthOf(spatialArrayWeightType)))
889
+ val dataB = VecInit (dataB_unpadded.asTypeOf(Vec (block_size, inputType)).zipWithIndex.map { case (d, i) => Mux (i.U < cntl.b_unpadded_cols, d, inputType.zero)}.map(d => d.asTypeOf(inputType).withWidthOf(spatialArrayWeightType)))
890
+
891
+ val dataD_unpadded_dequant = Wire (Vec (meshColumns, UInt ((weightType.getWidth).W )))
892
+ val lut_d_to_mesh = cntl.d_fire && dataD_valid && cntl_valid
893
+ val (dcount, ddone) = Counter (lut_d_to_mesh, (weightType.getWidth/ quantWidth))
894
+
895
+ when(compute_with_lut) {
896
+ when(dcount === 0 .U ) {
897
+ for (i<- 0 until meshColumns) {
898
+ quant_lut.io.rdaddr(i) := (dataD_unpadded((meshColumns* quantWidth- 1 ), 0 ).asTypeOf(Vec (meshColumns, UInt (quantWidth.W ))))(i)
899
+ dataD_unpadded_dequant(i) := quant_lut.io.rddata(i)
900
+ }
901
+ }.elsewhen(dcount === 1 .U ) {
902
+ for (i<- 0 until meshColumns) {
903
+ quant_lut.io.rdaddr(i) := (dataD_unpadded((2 * meshColumns* quantWidth- 1 ), (meshColumns* quantWidth- 1 )).asTypeOf(Vec (meshColumns, UInt (quantWidth.W ))))(i)
904
+ dataD_unpadded_dequant(i) := quant_lut.io.rddata(i)
905
+ }
906
+ }.elsewhen(dcount === 2 .U ) {
907
+ for (i<- 0 until meshColumns) {
908
+ quant_lut.io.rdaddr(i) := (dataD_unpadded((3 * meshColumns* quantWidth- 1 ), (2 * meshColumns* quantWidth- 1 )).asTypeOf(Vec (meshColumns, UInt (quantWidth.W ))))(i)
909
+ dataD_unpadded_dequant(i) := quant_lut.io.rddata(i)
910
+ }
911
+ }.elsewhen(dcount === 3 .U ) {
912
+ for (i<- 0 until meshColumns) {
913
+ quant_lut.io.rdaddr(i) := (dataD_unpadded((4 * meshColumns* quantWidth- 1 ), (3 * meshColumns* quantWidth- 1 )).asTypeOf(Vec (meshColumns, UInt (quantWidth.W ))))(i)
914
+ dataD_unpadded_dequant(i) := quant_lut.io.rddata(i)
915
+ }
916
+ }.otherwise {
917
+ for (i<- 0 until meshColumns) {
918
+ dataD_unpadded_dequant(i) := 0 .U
919
+ }
920
+ }
921
+ }.otherwise{
922
+ dataD_unpadded_dequant := dataD_unpadded.asTypeOf(dataD_unpadded_dequant)
923
+ }
924
+
925
+ val dataD = VecInit (dataD_unpadded_dequant.asTypeOf(Vec (block_size, inputType)).zipWithIndex.map { case (d, i) => Mux (i.U < cntl.d_unpadded_cols, d, inputType.zero)}.map(d => d.asTypeOf(inputType).withWidthOf(spatialArrayWeightType)))
839
926
840
927
// Pop responses off the scratchpad io ports
841
928
when (mesh_cntl_signals_q.io.deq.fire) {
0 commit comments