Skip to content

Commit 97798dc

Browse files
committed
Lut for nuq quantization
1 parent c16f815 commit 97798dc

File tree

6 files changed

+178
-6
lines changed

6 files changed

+178
-6
lines changed

src/main/scala/gemmini/Configs.scala

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -239,6 +239,8 @@ object GemminiConfigs {
239239

240240
val leanPrintfConfig = defaultConfig.copy(dataflow=Dataflow.WS, max_in_flight_mem_reqs = 64, acc_read_full_width = false, ex_read_from_acc = false, ex_write_to_spad = false, hardcode_d_to_garbage_addr = true, use_firesim_simulation_counters=true)
241241

242+
// Variant of defaultConfig with weight-stationary dataflow, SInt(16.W) inputs and an 8x8 mesh.
// Used as the default argument of LutLeanGemminiConfig.
val lutLeanConfig = defaultConfig.copy(dataflow=Dataflow.WS, inputType=SInt(16.W), meshRows=8, meshColumns=8, max_in_flight_mem_reqs = 64, acc_read_full_width = false, ex_read_from_acc = false, ex_write_to_spad = false, hardcode_d_to_garbage_addr = true)
243+
242244
}
243245

244246
/**
@@ -368,3 +370,18 @@ class DualGemminiConfig extends Config((site, here, up) => {
368370
up(BuildRoCC) ++ Seq(int_fn, fp_fn)
369371
}
370372
})
373+
374+
/**
 * Config fragment that appends one Gemmini accelerator to the existing BuildRoCC list.
 *
 * @param gemminiConfig array configuration used to build the accelerator;
 *                      defaults to GemminiConfigs.lutLeanConfig.
 */
class LutLeanGemminiConfig[T <: Data : Arithmetic, U <: Data, V <: Data](
  gemminiConfig: GemminiArrayConfig[T,U,V] = GemminiConfigs.lutLeanConfig
) extends Config((site, here, up) => {
  case BuildRoCC =>
    // Accelerators configured by earlier fragments are kept; ours is added at the end.
    val inherited = up(BuildRoCC)
    inherited ++ Seq(
      (params: Parameters) => {
        implicit val implicitParams = params
        // Keep the LazyModule bound to a val named `gemmini`: the enclosing val name
        // is what gets picked up as the module's name.
        val gemmini = LazyModule(new Gemmini(gemminiConfig))
        gemmini
      }
    )
})

src/main/scala/gemmini/ConfigsFP.scala

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -109,6 +109,14 @@ object GemminiFPConfigs {
109109
mvin_scale_acc_args = Some(ScaleArguments((t: Float, u: Float) => t * u, 4, Float(8, 24), -1, identity = "1.0", c_str="((x) * (scale))")),
110110
)
111111

112+
// FP16 (half precision) configuration for the LUT path on an 8x8 array.
// Inputs and spatial-array outputs are Float(5, 11); the accumulator type is Float(8, 24).
val LutFP16DefaultConfig = {
  // The mvin and mvin-to-accumulator scaling units use the same arguments.
  val fp16Scale = Some(ScaleArguments((t: Float, u: Float) => t * u, 4, Float(5, 11), -1, identity = "1.0", c_str="((x) * (scale))"))

  defaultFPConfig.copy(
    inputType = Float(5, 11),
    spatialArrayOutputType = Float(5, 11),
    accType = Float(8, 24),
    meshRows = 8,
    meshColumns = 8,
    tile_latency = 2,
    mvin_scale_args = fp16Scale,
    mvin_scale_acc_args = fp16Scale,
  )
}
119+
112120
}
113121

114122

@@ -171,3 +179,14 @@ class GemminiBF16Default8Config extends Config((site, here, up) => {
171179
)
172180
})
173181

182+
//===========FP16 LUT-based Default Config=========
// Config fragment that sets BuildRoCC to a single Gemmini built from
// GemminiFPConfigs.LutFP16DefaultConfig. The returned Seq does not include
// up(BuildRoCC), so accelerators configured by earlier fragments are not kept.
183+
class GemminiLutFP16DefaultConfig extends Config((site, here, up) => {
184+
case BuildRoCC => Seq(
185+
(p: Parameters) => {
186+
// Make the Parameters implicitly available to the Gemmini constructor.
implicit val q = p
187+
// NOTE(review): provides the implicit ValName that LazyModule requires; presumably
// this names the module after this val ("v") -- confirm the resulting instance name.
implicit val v = implicitly[ValName]
188+
LazyModule(new Gemmini(GemminiFPConfigs.LutFP16DefaultConfig))
189+
}
190+
)
191+
})
192+

src/main/scala/gemmini/ExecuteController.scala

Lines changed: 94 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -61,6 +61,21 @@ class ExecuteController[T <: Data, U <: Data, V <: Data](xLen: Int, tagWidth: In
6161
}
6262
}
6363

64+
// Dequantization lookup table: 1 write port, 8 read ports, 4-bit addresses, 16-bit entries.
val quant_lut = Module(new QuantLut(1,8,4,16))
//val rd_from_lut = quant_lut.io.rddata
dontTouch(quant_lut.io.rddata)
// Default (idle) port values; later connections in this controller override them
// while the LUT is being written or read.
quant_lut.io.wraddr := DontCare
quant_lut.io.rdaddr := DontCare
quant_lut.io.wrdata := DontCare
quant_lut.io.wr := false.B

// True while the D operand should be passed through the dequantization LUT.
val compute_with_lut = RegInit(false.B)

// Right-shift applied to D scratchpad row addresses: 2 in LUT mode, 0 otherwise.
// The register must follow compute_with_lut in both directions; if it were only ever
// set to 2, D addresses would stay shifted after LUT mode is switched off again.
val shift_for_lut = RegInit(0.U(2.W))
shift_for_lut := Mux(compute_with_lut, 2.U, 0.U)
78+
6479
val unrolled_cmd = TransposePreloadUnroller(io.cmd, config, io.counter)
6580

6681
val cmd_q_heads = 3
@@ -83,6 +98,8 @@ class ExecuteController[T <: Data, U <: Data, V <: Data](xLen: Int, tagWidth: In
8398
val DoConfig = functs(0) === CONFIG_CMD
8499
val DoComputes = functs.map(f => f === COMPUTE_AND_FLIP_CMD || f === COMPUTE_AND_STAY_CMD)
85100
val DoPreloads = functs.map(_ === PRELOAD_CMD)
101+
// True when functs(0) is one of the two LUT-preload functs.
val DoLutPreload1 = functs(0) === PRELOAD_LUT1
102+
val DoLutPreload2 = functs(0) === PRELOAD_LUT2
86103

87104
val preload_cmd_place = Mux(DoPreloads(0), 0.U, 1.U)
88105
// val a_address_place = Mux(current_dataflow === Dataflow.WS.id.U, 0.U, Mux(preload_cmd_place === 0.U, 1.U, 2.U))
@@ -436,13 +453,13 @@ class ExecuteController[T <: Data, U <: Data, V <: Data](xLen: Int, tagWidth: In
436453
io.srams.read(i).req.bits.fromDMA := false.B
437454
io.srams.read(i).req.bits.addr := MuxCase(a_address_rs1.sp_row() + a_fire_counter,
438455
Seq(read_b -> (b_address_rs2.sp_row() + b_fire_counter),
439-
read_d -> (d_address_rs1.sp_row() + block_size.U - 1.U - d_fire_counter_mulpre)))
456+
read_d -> ((d_address_rs1.sp_row() + block_size.U - 1.U - d_fire_counter_mulpre) >> shift_for_lut)))
440457

441458
// TODO this just overrides the previous line. Should we erase the previous line?
442459
when(im2col_en === false.B) {
443460
io.srams.read(i).req.bits.addr := MuxCase(a_address.sp_row(),
444461
Seq(read_b -> b_address.sp_row(),
445-
read_d -> d_address.sp_row()))
462+
read_d -> (d_address.sp_row() >> shift_for_lut)))
446463
}
447464
} else {
448465
io.srams.read(i).req.valid := false.B
@@ -458,6 +475,8 @@ class ExecuteController[T <: Data, U <: Data, V <: Data](xLen: Int, tagWidth: In
458475
val read_a_from_acc = a_valid && a_read_from_acc && dataABankAcc === i.U && start_inputting_a && !multiply_garbage && a_row_is_not_all_zeros && !(im2col_wire&&im2col_en)
459476
val read_b_from_acc = b_valid && b_read_from_acc && dataBBankAcc === i.U && start_inputting_b && !accumulate_zeros && b_row_is_not_all_zeros //&& !im2col_wire
460477
val read_d_from_acc = d_valid && d_read_from_acc && dataDBankAcc === i.U && start_inputting_d && !preload_zeros && d_row_is_not_all_zeros //&& !im2col_wire
478+
// we do not support LUT and accumulator read for D matrices for now
479+
assert(!(compute_with_lut && read_d_from_acc))
461480

462481
Seq((read_a_from_acc, a_ready), (read_b_from_acc, b_ready), (read_d_from_acc, d_ready)).foreach { case (rd, r) =>
463482
when(rd && !io.acc.read_req(i).ready) {
@@ -559,6 +578,8 @@ class ExecuteController[T <: Data, U <: Data, V <: Data](xLen: Int, tagWidth: In
559578
if (dataflow == Dataflow.BOTH) {
560579
current_dataflow := config_ex_rs1.dataflow
561580
}
581+
// use dequantization lut path
582+
compute_with_lut := config_ex_rs1.use_lut
562583
}
563584

564585
a_addr_stride := config_ex_rs1.a_stride // TODO this needs to be kept in sync with ROB.scala
@@ -581,6 +602,39 @@ class ExecuteController[T <: Data, U <: Data, V <: Data](xLen: Int, tagWidth: In
581602
cmd.pop := 1.U
582603
}
583604

605+
// LUT preload, lower half: writes LUT entries 0-7, one entry per cycle.
// NOTE(review): assumes rs1 carries entries 0-3 and rs2 entries 4-7 of this half
// (lowest bits first) -- confirm against the software that issues PRELOAD_LUT1.
.elsewhen(DoLutPreload1) {
  val preload_lut1 = rs1s(0).asTypeOf(new PreloadLutRs1())
  val preload_lut2 = rs2s(0).asTypeOf(new PreloadLutRs2())
  // All 8 entries carried by this command, so the 0..7 counter never indexes
  // past the end of a 4-element Vec.
  val lut_entries = VecInit(preload_lut1.lut_data ++ preload_lut2.lut_data)

  quant_lut.io.wr := true.B

  val (lut_wrcount, lut_wrdone) = Counter(quant_lut.io.wr, 8)

  // Address and data are driven on every cycle that wr is high, including the last
  // one, so no write is ever issued with don't-care address/data.
  quant_lut.io.wraddr(0) := Cat(0.U, lut_wrcount)
  quant_lut.io.wrdata(0) := lut_entries(lut_wrcount)

  // The command retires in the same cycle that the 8th entry is written.
  when(lut_wrdone) {
    io.completed := cmd.bits(0).rob_id
    cmd.pop := 1.U
  }
}
// LUT preload, upper half: writes LUT entries 8-15, one entry per cycle.
// NOTE(review): same operand layout assumption as above (rs1 -> 8-11, rs2 -> 12-15).
.elsewhen(DoLutPreload2) {
  val preload_lut1 = rs1s(0).asTypeOf(new PreloadLutRs1())
  val preload_lut2 = rs2s(0).asTypeOf(new PreloadLutRs2())
  val lut_entries = VecInit(preload_lut1.lut_data ++ preload_lut2.lut_data)

  quant_lut.io.wr := true.B

  val (lut_wrcount, lut_wrdone) = Counter(quant_lut.io.wr, 8)

  quant_lut.io.wraddr(0) := Cat(1.U, lut_wrcount)
  quant_lut.io.wrdata(0) := lut_entries(lut_wrcount)

  when(lut_wrdone) {
    io.completed := cmd.bits(0).rob_id
    cmd.pop := 1.U
  }
}
637+
584638
// Preload
585639
.elsewhen(DoPreloads(0) && cmd.valid(1) && (raw_hazards_are_impossible.B || !raw_hazard_pre)) {
586640
perform_single_preload := true.B
@@ -835,7 +889,44 @@ class ExecuteController[T <: Data, U <: Data, V <: Data](xLen: Int, tagWidth: In
835889

836890
val dataA = VecInit(dataA_unpadded.asTypeOf(Vec(block_size, inputType)).zipWithIndex.map { case (d, i) => Mux(i.U < cntl.a_unpadded_cols, d, inputType.zero)})
837891
val dataB = VecInit(dataB_unpadded.asTypeOf(Vec(block_size, inputType)).zipWithIndex.map { case (d, i) => Mux(i.U < cntl.b_unpadded_cols, d, inputType.zero)})
838-
val dataD = VecInit(dataD_unpadded.asTypeOf(Vec(block_size, inputType)).zipWithIndex.map { case (d, i) => Mux(i.U < cntl.d_unpadded_cols, d, inputType.zero)})
892+
893+
// Dequantized view of the D row: 8 lanes of 16 bits. In LUT mode each lane is the LUT
// entry selected by one 4-bit code; otherwise it is a reinterpretation of the raw row.
val dataD_unpadded_dequant = Wire(Vec(8, UInt(16.W)))
dontTouch(dataD_unpadded_dequant)

// Counts D rows sent to the mesh; selects which 32-bit group of packed codes is expanded.
val lut_d_to_mesh = cntl.d_fire && dataD_valid && cntl_valid
val (dcount, _) = Counter(lut_d_to_mesh, 4)

when(compute_with_lut) {
  // Default so the wire is fully initialized; one of the four cases below always matches.
  dataD_unpadded_dequant.foreach(_ := 0.U)

  for (group <- 0 until 4) {
    when(dcount === group.U) {
      // Exactly 32 bits per group: bits [32*group + 31 : 32*group] hold eight 4-bit codes.
      val codes = dataD_unpadded(32 * group + 31, 32 * group).asTypeOf(Vec(8, UInt(4.W)))
      for (i <- 0 until 8) {
        quant_lut.io.rdaddr(i) := codes(i)
        dataD_unpadded_dequant(i) := quant_lut.io.rddata(i)
      }
    }
  }
}.otherwise{
  dataD_unpadded_dequant := dataD_unpadded.asTypeOf(dataD_unpadded_dequant)
}

// Outside LUT mode the raw row is used directly, so rows wider than the fixed
// 8 x 16-bit dequantization wire are not truncated.
val dataD_source = Mux(compute_with_lut, dataD_unpadded_dequant.asUInt, dataD_unpadded.asUInt)

val dataD = VecInit(dataD_source.asTypeOf(Vec(block_size, inputType)).zipWithIndex.map { case (d, i) => Mux(i.U < cntl.d_unpadded_cols, d, inputType.zero)})
839930

840931
// Pop responses off the scratchpad io ports
841932
when (mesh_cntl_signals_q.io.deq.fire) {

src/main/scala/gemmini/GemminiISA.scala

Lines changed: 14 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,9 @@ object GemminiISA {
3434

3535
val CLKGATE_EN = 22.U
3636

37+
// funct values of the two LUT-preload commands.
val PRELOAD_LUT1 = 23.U
38+
val PRELOAD_LUT2 = 24.U
39+
3740
// rs1[2:0] values
3841
val CONFIG_EX = 0.U
3942
val CONFIG_LOAD = 1.U
@@ -178,7 +181,8 @@ object GemminiISA {
178181
val CONFIG_EX_RS1_CMD_TYPE_WIDTH = 2
179182
val CONFIG_EX_RS1_DATAFLOW_WIDTH = 1
180183
val CONFIG_EX_RS1_ACTIVATION_WIDTH = 2
181-
val CONFIG_EX_RS1_SPACER0_WIDTH = (7 - 2 - 1 - 2)
184+
// Width in bits of the use_lut field.
// NOTE(review): the neighbouring width constants end in _WIDTH; this one does not.
val CONFIG_EX_RS1_USE_LUT = 1
185+
// Spacer shrinks by the one bit now taken by use_lut.
val CONFIG_EX_RS1_SPACER0_WIDTH = (7 - 2 - 1 - 2 - 1)
182186
val CONFIG_EX_RS1_SET_ONLY_STRIDES_WIDTH = 1
183187
val CONFIG_EX_RS1_A_TRANSPOSE_WIDTH = 1
184188
val CONFIG_EX_RS1_B_TRANSPOSE_WIDTH = 1
@@ -195,6 +199,7 @@ object GemminiISA {
195199
val a_transpose = UInt(CONFIG_EX_RS1_A_TRANSPOSE_WIDTH.W)
196200
val set_only_strides = UInt(CONFIG_EX_RS1_SET_ONLY_STRIDES_WIDTH.W)
197201
val _spacer0 = UInt(CONFIG_EX_RS1_SPACER0_WIDTH.W)
202+
// Flag of width CONFIG_EX_RS1_USE_LUT, declared between _spacer0 and activation.
val use_lut = UInt(CONFIG_EX_RS1_USE_LUT.W)
198203
val activation = UInt(CONFIG_EX_RS1_ACTIVATION_WIDTH.W)
199204
val dataflow = UInt(CONFIG_EX_RS1_DATAFLOW_WIDTH.W)
200205
val cmd_type = UInt(CONFIG_EX_RS1_CMD_TYPE_WIDTH.W)
@@ -235,5 +240,13 @@ object GemminiISA {
235240
val _spacer0 = UInt((COMPUTED_RS_ADDR_WIDTH - local_addr_t.getWidth).W)
236241
val local_addr = local_addr_t.cloneType
237242
}
243+
244+
// Bundle of four LUT entries, each lut_t_bits wide (default 16, i.e. 64 bits in total).
// NOTE(review): presumably the layout of rs1 for the LUT-preload commands -- confirm.
class PreloadLutRs1(lut_t_bits: Int = 16) extends Bundle {
245+
val lut_data = Vec(4,UInt(lut_t_bits.W))
246+
}
247+
248+
// Bundle of four LUT entries, each lut_t_bits wide (default 16, i.e. 64 bits in total).
// NOTE(review): presumably the layout of rs2 for the LUT-preload commands -- confirm.
class PreloadLutRs2(lut_t_bits: Int = 16) extends Bundle {
249+
val lut_data = Vec(4,UInt(lut_t_bits.W))
250+
}
238251
}
239252

src/main/scala/gemmini/QuantLut.scala

Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,32 @@
1+
package gemmini
2+
3+
import chisel3._
4+
import chisel3.util._
5+
import chisel3.experimental._
6+
import org.chipsalliance.cde.config.Parameters
7+
8+
import scala.math.{pow}
9+
10+
/**
 * Register-based lookup table with combinational reads.
 *
 * @param wPorts  number of write ports (all share the single `wr` enable)
 * @param rPorts  number of read ports
 * @param lutaddr address width in bits; the table holds 2^lutaddr entries
 * @param lutdata width in bits of each entry
 */
class QuantLut (wPorts: Int, rPorts: Int, lutaddr: Int, lutdata: Int)(implicit p: Parameters) extends Module {
  val io = IO(new Bundle{
    val rdaddr = Input(Vec(rPorts, UInt(lutaddr.W)))
    val wraddr = Input(Vec(wPorts, UInt(lutaddr.W)))
    val rddata = Output(Vec(rPorts, UInt(lutdata.W)))
    val wrdata = Input(Vec(wPorts, UInt(lutdata.W)))
    val wr = Input(Bool())
  })

  // One register per entry, all reset to zero.
  val lut = RegInit(VecInit(Seq.fill(1 << lutaddr)(0.U(lutdata.W))))

  dontTouch(lut)

  // Every read port returns the entry at its address in the same cycle.
  io.rddata.zip(io.rdaddr).foreach { case (data, addr) =>
    data := lut(addr)
  }

  // While wr is high every write port stores its data. Ports are connected in index
  // order, so the highest-numbered port wins if two ports target the same address.
  when(io.wr) {
    io.wraddr.zip(io.wrdata).foreach { case (addr, data) =>
      lut(addr) := data
    }
  }
}

src/main/scala/gemmini/ReservationStation.scala

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -191,7 +191,7 @@ class ReservationStation[T <: Data : Arithmetic, U <: Data, V <: Data](config: G
191191
new_entry.issued := false.B
192192
new_entry.cmd := io.alloc.bits
193193

194-
new_entry.is_config := funct === CONFIG_CMD
194+
// LUT-preload commands are flagged as config entries, like CONFIG_CMD.
new_entry.is_config := funct === CONFIG_CMD || funct === PRELOAD_LUT1 || funct === PRELOAD_LUT2
195195

196196
val op1 = Wire(UDValid(new OpT))
197197
op1.valid := false.B
@@ -293,7 +293,7 @@ class ReservationStation[T <: Data : Arithmetic, U <: Data, V <: Data](config: G
293293
}
294294

295295
val is_load = funct === LOAD_CMD || funct === LOAD2_CMD || funct === LOAD3_CMD || (funct === CONFIG_CMD && config_cmd_type === CONFIG_LOAD)
296-
val is_ex = funct === PRELOAD_CMD || funct_is_compute || (funct === CONFIG_CMD && config_cmd_type === CONFIG_EX)
296+
// Execute-type commands: preloads, computes, CONFIG_EX configs and the two LUT preloads.
val is_ex = funct === PRELOAD_CMD || funct_is_compute || (funct === CONFIG_CMD && config_cmd_type === CONFIG_EX) || funct === PRELOAD_LUT1 || funct === PRELOAD_LUT2
297297
val is_store = funct === STORE_CMD || (funct === CONFIG_CMD && (config_cmd_type === CONFIG_STORE || config_cmd_type === CONFIG_NORM))
298298
val is_norm = funct === CONFIG_CMD && config_cmd_type === CONFIG_NORM // normalization commands are a subset of store commands, so they still go in the store queue
299299

0 commit comments

Comments
 (0)