@@ -61,6 +61,21 @@ class ExecuteController[T <: Data, U <: Data, V <: Data](xLen: Int, tagWidth: In
61
61
}
62
62
}
63
63
64
// Lookup table used by the dequantization path. The read-data port is marked
// dontTouch so the signal survives optimization.
val quant_lut = Module(new QuantLut(1, 8, 4, 16))
dontTouch(quant_lut.io.rddata)

// Idle defaults: write strobe low, address and data ports left unconstrained.
quant_lut.io.wr := false.B
quant_lut.io.wrdata := DontCare
quant_lut.io.wraddr := DontCare
quant_lut.io.rdaddr := DontCare
71
+
72
// Latched flag selecting the dequantization LUT path.
val compute_with_lut = RegInit(false.B)

// Right-shift amount used when forming D-matrix scratchpad row addresses:
// 2 while the LUT path is enabled and 0 otherwise. The register follows
// compute_with_lut in both directions (one cycle later); it used to be set to
// 2 and never cleared, which left D addresses shifted after the LUT path had
// been switched off again.
val shift_for_lut = RegInit(0.U(2.W))
shift_for_lut := Mux(compute_with_lut, 2.U, 0.U)
78
+
64
79
// Command stream produced by TransposePreloadUnroller from io.cmd, the config and io.counter.
val unrolled_cmd = TransposePreloadUnroller (io.cmd, config, io.counter)
65
80
66
81
// NOTE(review): presumably the number of command-queue head entries examined at once — confirm against its uses.
val cmd_q_heads = 3
@@ -83,6 +98,8 @@ class ExecuteController[T <: Data, U <: Data, V <: Data](xLen: Int, tagWidth: In
83
98
// Decode flags derived from the funct fields in `functs`.
val DoConfig = functs.head === CONFIG_CMD
val DoComputes = functs.map { funct =>
  funct === COMPUTE_AND_FLIP_CMD || funct === COMPUTE_AND_STAY_CMD
}
val DoPreloads = functs.map(funct => funct === PRELOAD_CMD)
// The LUT preload commands are only decoded from entry 0.
val DoLutPreload1 = functs.head === PRELOAD_LUT1
val DoLutPreload2 = functs.head === PRELOAD_LUT2

// 0 when entry 0 is a preload command, 1 otherwise.
val preload_cmd_place = Mux(DoPreloads.head, 0.U, 1.U)
88
105
// val a_address_place = Mux(current_dataflow === Dataflow.WS.id.U, 0.U, Mux(preload_cmd_place === 0.U, 1.U, 2.U))
@@ -436,13 +453,13 @@ class ExecuteController[T <: Data, U <: Data, V <: Data](xLen: Int, tagWidth: In
436
453
// This read request is flagged as not coming from the DMA.
io.srams.read(i).req.bits.fromDMA := false .B
437
454
// Read address selection: defaults to A's row (plus its fire counter); read_b
// selects B's row, read_d selects D's row. In the added (+) line the D row
// address is right-shifted by shift_for_lut.
io.srams.read(i).req.bits.addr := MuxCase (a_address_rs1.sp_row() + a_fire_counter,
438
455
Seq (read_b -> (b_address_rs2.sp_row() + b_fire_counter),
439
- read_d -> (d_address_rs1.sp_row() + block_size.U - 1 .U - d_fire_counter_mulpre)))
456
+ read_d -> (( d_address_rs1.sp_row() + block_size.U - 1 .U - d_fire_counter_mulpre) >> shift_for_lut )))
440
457
441
458
// TODO this just overrides the previous line. Should we erase the previous line?
442
459
// When im2col is disabled the address is re-driven from a_address / b_address /
// d_address; the D row address is again right-shifted by shift_for_lut.
when(im2col_en === false .B ) {
443
460
io.srams.read(i).req.bits.addr := MuxCase (a_address.sp_row(),
444
461
Seq (read_b -> b_address.sp_row(),
445
- read_d -> d_address.sp_row()))
462
+ read_d -> ( d_address.sp_row() >> shift_for_lut )))
446
463
}
447
464
} else {
448
465
io.srams.read(i).req.valid := false .B
@@ -458,6 +475,8 @@ class ExecuteController[T <: Data, U <: Data, V <: Data](xLen: Int, tagWidth: In
458
475
// Per-bank accumulator read requests for the A, B and D operands (bank index i).
val read_a_from_acc = a_valid && a_read_from_acc && dataABankAcc === i.U && start_inputting_a && !multiply_garbage && a_row_is_not_all_zeros && !(im2col_wire && im2col_en)
val read_b_from_acc = b_valid && b_read_from_acc && dataBBankAcc === i.U && start_inputting_b && !accumulate_zeros && b_row_is_not_all_zeros // && !im2col_wire
val read_d_from_acc = d_valid && d_read_from_acc && dataDBankAcc === i.U && start_inputting_d && !preload_zeros && d_row_is_not_all_zeros // && !im2col_wire

// We do not support LUT and accumulator read for D matrices for now. The
// message makes a failing simulation point straight at the unsupported combination.
assert(!(compute_with_lut && read_d_from_acc),
  "LUT path does not support reading the D matrix from the accumulator")
461
480
462
481
Seq ((read_a_from_acc, a_ready), (read_b_from_acc, b_ready), (read_d_from_acc, d_ready)).foreach { case (rd, r) =>
463
482
when(rd && ! io.acc.read_req(i).ready) {
@@ -559,6 +578,8 @@ class ExecuteController[T <: Data, U <: Data, V <: Data](xLen: Int, tagWidth: In
559
578
// current_dataflow is only taken from the config command when this instance is
// built with Dataflow.BOTH.
if (dataflow == Dataflow .BOTH ) {
560
579
current_dataflow := config_ex_rs1.dataflow
561
580
}
581
+ // Latch the use_lut field of the config command; it selects the dequantization LUT path.
582
+ compute_with_lut := config_ex_rs1.use_lut
562
583
}
563
584
564
585
a_addr_stride := config_ex_rs1.a_stride // TODO this needs to be kept in sync with ROB.scala
@@ -581,6 +602,39 @@ class ExecuteController[T <: Data, U <: Data, V <: Data](xLen: Int, tagWidth: In
581
602
cmd.pop := 1 .U
582
603
}
583
604
605
// LUT preload, part 1: write LUT addresses 0-7, one entry per cycle, from the
// lut_data field of rs1.
.elsewhen(DoLutPreload1) {
  val preload_lut1 = rs1s(0).asTypeOf(new PreloadLutRs1())

  quant_lut.io.wr := true.B

  // lut_wrdone is asserted in the same cycle in which lut_wrcount holds 7.
  val (lut_wrcount, lut_wrdone) = Counter(quant_lut.io.wr, 8)

  // Drive address and data on every write cycle, including the final one.
  // Gating these with !lut_wrdone skipped entry 7 and left the write strobe
  // high with a DontCare address/data on the last cycle.
  quant_lut.io.wraddr(0) := Cat(0.U, lut_wrcount)
  quant_lut.io.wrdata(0) := preload_lut1.lut_data(lut_wrcount)

  // Retire the command together with the last write.
  when(lut_wrdone) {
    io.completed := cmd.bits(0).rob_id
    cmd.pop := 1.U
  }
}
621
// LUT preload, part 2: write LUT addresses 8-15, one entry per cycle, from the
// lut_data field of rs2.
.elsewhen(DoLutPreload2) {
  val preload_lut2 = rs2s(0).asTypeOf(new PreloadLutRs2())

  quant_lut.io.wr := true.B

  // lut_wrdone is asserted in the same cycle in which lut_wrcount holds 7.
  val (lut_wrcount, lut_wrdone) = Counter(quant_lut.io.wr, 8)

  // Drive address and data on every write cycle, including the final one.
  // Gating these with !lut_wrdone skipped the last entry and left the write
  // strobe high with a DontCare address/data on the last cycle.
  quant_lut.io.wraddr(0) := Cat(1.U, lut_wrcount)
  quant_lut.io.wrdata(0) := preload_lut2.lut_data(lut_wrcount)

  // Retire the command together with the last write.
  when(lut_wrdone) {
    io.completed := cmd.bits(0).rob_id
    cmd.pop := 1.U
  }
}
637
+
584
638
// Preload
585
639
.elsewhen(DoPreloads (0 ) && cmd.valid(1 ) && (raw_hazards_are_impossible.B || ! raw_hazard_pre)) {
586
640
perform_single_preload := true .B
@@ -835,7 +889,44 @@ class ExecuteController[T <: Data, U <: Data, V <: Data](xLen: Int, tagWidth: In
835
889
836
890
// A operand row: lanes at or beyond cntl.a_unpadded_cols are replaced with zero.
val dataA = {
  val lanes = dataA_unpadded.asTypeOf(Vec(block_size, inputType))
  VecInit(Seq.tabulate(block_size) { i =>
    Mux(i.U < cntl.a_unpadded_cols, lanes(i), inputType.zero)
  })
}
// B operand row: lanes at or beyond cntl.b_unpadded_cols are replaced with zero.
val dataB = {
  val lanes = dataB_unpadded.asTypeOf(Vec(block_size, inputType))
  VecInit(Seq.tabulate(block_size) { i =>
    Mux(i.U < cntl.b_unpadded_cols, lanes(i), inputType.zero)
  })
}
838
- val dataD = VecInit (dataD_unpadded.asTypeOf(Vec (block_size, inputType)).zipWithIndex.map { case (d, i) => Mux (i.U < cntl.d_unpadded_cols, d, inputType.zero)})
892
+
893
// D operand row after optional dequantization: 8 lanes of 16 bits.
// NOTE(review): the 8 x 16-bit shape and the 128-bit input slice are
// hard-coded; this presumably assumes block_size * inputType width == 128 —
// confirm before using other configurations.
val dataD_unpadded_dequant = Wire(Vec(8, UInt(16.W)))
dontTouch(dataD_unpadded_dequant)

// dcount advances each time a D row is accepted (lut_d_to_mesh) and wraps
// after 4, selecting which 32-bit group of the scratchpad row is dequantized.
val lut_d_to_mesh = cntl.d_fire && dataD_valid && cntl_valid
val (dcount, ddone) = Counter(lut_d_to_mesh, 4)

when(compute_with_lut) {
  // View the low 128 bits as 4 groups of 8 four-bit LUT indices and pick the
  // group for the current dcount. Group k covers bits (32k + 31, 32k); the
  // hand-written slice for group 1 was (63, 31), which is 33 bits wide and
  // off by one bit. dcount is 2 bits wide, so all four values are covered.
  val lut_indices = dataD_unpadded(127, 0).asTypeOf(Vec(4, Vec(8, UInt(4.W))))(dcount)
  for (i <- 0 until 8) {
    quant_lut.io.rdaddr(i) := lut_indices(i)
    dataD_unpadded_dequant(i) := quant_lut.io.rddata(i)
  }
}.otherwise {
  // LUT path disabled: pass the scratchpad row through unchanged.
  dataD_unpadded_dequant := dataD_unpadded.asTypeOf(dataD_unpadded_dequant)
}

// Lanes at or beyond cntl.d_unpadded_cols are replaced with zero.
val dataD = VecInit(dataD_unpadded_dequant.asTypeOf(Vec(block_size, inputType)).zipWithIndex.map { case (d, i) => Mux(i.U < cntl.d_unpadded_cols, d, inputType.zero) })
839
930
840
931
// Pop responses off the scratchpad io ports
841
932
when (mesh_cntl_signals_q.io.deq.fire) {
0 commit comments