Skip to content

Commit 2b4fcbd

Browse files
committed
LUT for squeezellm integrated
2 parents 7676382 + 97798dc commit 2b4fcbd

File tree

8 files changed

+179
-9
lines changed

8 files changed

+179
-9
lines changed

src/main/scala/gemmini/Configs.scala

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -33,6 +33,7 @@ object GemminiConfigs {
3333
tileColumns = 1,
3434
meshRows = 16,
3535
meshColumns = 16,
36+
quantWidth = 4,
3637

3738
// Spatial array PE options
3839
dataflow = Dataflow.BOTH,
@@ -177,6 +178,7 @@ object GemminiConfigs {
177178
tileColumns = defaultConfig.tileColumns,
178179
meshRows = defaultConfig.meshRows,
179180
meshColumns = defaultConfig.meshColumns,
181+
quantWidth = defaultConfig.quantWidth,
180182
dataflow = defaultConfig.dataflow,
181183
sp_capacity = CapacityInKilobytes(128),
182184
acc_capacity = CapacityInKilobytes(128),
@@ -245,6 +247,8 @@ object GemminiConfigs {
245247

246248
val leanPrintfConfig = defaultConfig.copy(dataflow=Dataflow.WS, max_in_flight_mem_reqs = 64, acc_read_full_width = false, ex_read_from_acc = false, ex_write_to_spad = false, hardcode_d_to_garbage_addr = true, use_firesim_simulation_counters=true)
247249

250+
val lutLeanConfig = defaultConfig.copy(dataflow=Dataflow.WS, inputType=SInt(16.W), meshRows=8, meshColumns=8, max_in_flight_mem_reqs = 64, acc_read_full_width = false, ex_read_from_acc = false, ex_write_to_spad = false, hardcode_d_to_garbage_addr = true)
251+
248252
}
249253

250254
/**
@@ -374,3 +378,18 @@ class DualGemminiConfig extends Config((site, here, up) => {
374378
up(BuildRoCC) ++ Seq(int_fn, fp_fn)
375379
}
376380
})
381+
382+
/**
383+
* Mixin which sets the default lut lean parameters for a systolic array accelerator.
384+
*/
385+
class LutLeanGemminiConfig[T <: Data : Arithmetic, U <: Data, V <: Data](
386+
gemminiConfig: GemminiArrayConfig[T,U,V] = GemminiConfigs.lutLeanConfig
387+
) extends Config((site, here, up) => {
388+
case BuildRoCC => up(BuildRoCC) ++ Seq(
389+
(p: Parameters) => {
390+
implicit val q = p
391+
val gemmini = LazyModule(new Gemmini(gemminiConfig))
392+
gemmini
393+
}
394+
)
395+
})

src/main/scala/gemmini/ConfigsFP.scala

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@ object GemminiFPConfigs {
1919
tileColumns = 1,
2020
meshRows = 4,
2121
meshColumns = 4,
22+
quantWidth = 4,
2223

2324
ld_queue_length = 8,
2425
st_queue_length = 2,
@@ -113,6 +114,13 @@ object GemminiFPConfigs {
113114
mvin_scale_acc_args = Some(ScaleArguments((t: Float, u: Float) => t * u, 4, Float(8, 24), -1, identity = "1.0", c_str="((x) * (scale))")),
114115
)
115116

117+
//FP16 Half Precision Configuration for LUT with 8x8 array
118+
val LutFP16DefaultConfig = defaultFPConfig.copy(inputType = Float(5, 11), weightType = Float(5, 11), accType = Float(8, 24), spatialArrayInputType = Float(5, 11), spatialArrayWeightType = Float(5, 11), spatialArrayOutputType = Float(5, 11),
119+
meshRows = 8, meshColumns = 8,
120+
tile_latency = 2,
121+
mvin_scale_args = Some(ScaleArguments((t: Float, u: Float) => t * u, 4, Float(5, 11), -1, identity = "1.0", c_str="((x) * (scale))")),
122+
mvin_scale_acc_args = Some(ScaleArguments((t: Float, u: Float) => t * u, 4, Float(5, 11), -1, identity = "1.0", c_str="((x) * (scale))")),
123+
)
116124

117125
val chipFP32Config = FP32DefaultConfig.copy(sp_capacity=CapacityInKilobytes(32), acc_capacity=CapacityInKilobytes(8), dataflow=Dataflow.WS,
118126
acc_scale_args = Some(ScaleArguments((t: Float, u: Float) => {t}, 1, Float(8, 24), -1, identity = "1.0",
@@ -141,6 +149,7 @@ object GemminiFPConfigs {
141149
meshRows = 16,
142150
meshColumns = 16,
143151
accType = Float(5, 11),
152+
quantWidth = 4,
144153
spatialArrayInputType = Float(5, 11, isRecoded = true),
145154
spatialArrayWeightType = Float(5, 11, isRecoded = true),
146155
spatialArrayOutputType = Float(5, 11, isRecoded = true),
@@ -255,3 +264,14 @@ class GemminiBF16Default8Config extends Config((site, here, up) => {
255264
)
256265
})
257266

267+
//===========FP16 LUT-based Default Config=========
268+
class GemminiLutFP16DefaultConfig extends Config((site, here, up) => {
269+
case BuildRoCC => Seq(
270+
(p: Parameters) => {
271+
implicit val q = p
272+
implicit val v = implicitly[ValName]
273+
LazyModule(new Gemmini(GemminiFPConfigs.LutFP16DefaultConfig))
274+
}
275+
)
276+
})
277+

src/main/scala/gemmini/ExecuteController.scala

Lines changed: 91 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -61,6 +61,19 @@ class ExecuteController[T <: Data, U <: Data, V <: Data](xLen: Int, tagWidth: In
6161
}
6262
}
6363

64+
val quant_lut = Module(new QuantLut(1, meshColumns, quantWidth, weightType.getWidth))
65+
quant_lut.io.wraddr := DontCare
66+
quant_lut.io.rdaddr := DontCare
67+
quant_lut.io.wrdata := DontCare
68+
quant_lut.io.wr := false.B
69+
70+
val compute_with_lut = RegInit(false.B)
71+
72+
val shift_for_lut = RegInit(0.U(log2Up(quantWidth).W))
73+
when(compute_with_lut) {
74+
shift_for_lut := log2Up(quantWidth).U
75+
}.otherwise {shift_for_lut := 0.U}
76+
6477
val unrolled_cmd = TransposePreloadUnroller(io.cmd, config, io.counter)
6578

6679
val cmd_q_heads = 3
@@ -83,6 +96,8 @@ class ExecuteController[T <: Data, U <: Data, V <: Data](xLen: Int, tagWidth: In
8396
val DoConfig = functs(0) === CONFIG_CMD
8497
val DoComputes = functs.map(f => f === COMPUTE_AND_FLIP_CMD || f === COMPUTE_AND_STAY_CMD)
8598
val DoPreloads = functs.map(_ === PRELOAD_CMD)
99+
val DoLutPreload1 = functs(0) === PRELOAD_LUT1
100+
val DoLutPreload2 = functs(0) === PRELOAD_LUT2
86101

87102
val preload_cmd_place = Mux(DoPreloads(0), 0.U, 1.U)
88103
// val a_address_place = Mux(current_dataflow === Dataflow.WS.id.U, 0.U, Mux(preload_cmd_place === 0.U, 1.U, 2.U))
@@ -436,13 +451,13 @@ class ExecuteController[T <: Data, U <: Data, V <: Data](xLen: Int, tagWidth: In
436451
io.srams.read(i).req.bits.fromDMA := false.B
437452
io.srams.read(i).req.bits.addr := MuxCase(a_address_rs1.sp_row() + a_fire_counter,
438453
Seq(read_b -> (b_address_rs2.sp_row() + b_fire_counter),
439-
read_d -> (d_address_rs1.sp_row() + block_size.U - 1.U - d_fire_counter_mulpre)))
454+
read_d -> ((d_address_rs1.sp_row() + block_size.U - 1.U - d_fire_counter_mulpre) >> shift_for_lut)))
440455

441456
// TODO this just overrides the previous line. Should we erase the previous line?
442457
when(im2col_en === false.B) {
443458
io.srams.read(i).req.bits.addr := MuxCase(a_address.sp_row(),
444459
Seq(read_b -> b_address.sp_row(),
445-
read_d -> d_address.sp_row()))
460+
read_d -> (d_address.sp_row() >> shift_for_lut)))
446461
}
447462
} else {
448463
io.srams.read(i).req.valid := false.B
@@ -458,6 +473,8 @@ class ExecuteController[T <: Data, U <: Data, V <: Data](xLen: Int, tagWidth: In
458473
val read_a_from_acc = a_valid && a_read_from_acc && dataABankAcc === i.U && start_inputting_a && !multiply_garbage && a_row_is_not_all_zeros && !(im2col_wire&&im2col_en)
459474
val read_b_from_acc = b_valid && b_read_from_acc && dataBBankAcc === i.U && start_inputting_b && !accumulate_zeros && b_row_is_not_all_zeros //&& !im2col_wire
460475
val read_d_from_acc = d_valid && d_read_from_acc && dataDBankAcc === i.U && start_inputting_d && !preload_zeros && d_row_is_not_all_zeros //&& !im2col_wire
476+
// we do not support LUT and accumulator read for D matrices for now
477+
assert(!(compute_with_lut && read_d_from_acc))
461478

462479
Seq((read_a_from_acc, a_ready), (read_b_from_acc, b_ready), (read_d_from_acc, d_ready)).foreach { case (rd, r) =>
463480
when(rd && !io.acc.read_req(i).ready) {
@@ -559,6 +576,8 @@ class ExecuteController[T <: Data, U <: Data, V <: Data](xLen: Int, tagWidth: In
559576
if (dataflow == Dataflow.BOTH) {
560577
current_dataflow := config_ex_rs1.dataflow
561578
}
579+
// use dequantization lut path
580+
compute_with_lut := config_ex_rs1.use_lut
562581
}
563582

564583
a_addr_stride := config_ex_rs1.a_stride // TODO this needs to be kept in sync with ROB.scala
@@ -581,6 +600,39 @@ class ExecuteController[T <: Data, U <: Data, V <: Data](xLen: Int, tagWidth: In
581600
cmd.pop := 1.U
582601
}
583602

603+
.elsewhen(DoLutPreload1) {
604+
val preload_lut1 = rs1s(0).asTypeOf(new PreloadLutRs1())
605+
val preload_lut2 = rs2s(0).asTypeOf(new PreloadLutRs2())
606+
607+
quant_lut.io.wr := true.B
608+
609+
val (lut_wrcount, lut_wrdone) = Counter(quant_lut.io.wr, 8)
610+
611+
when(!lut_wrdone) {
612+
quant_lut.io.wraddr(0) := Cat(0.U, lut_wrcount)
613+
quant_lut.io.wrdata(0) := preload_lut1.lut_data(lut_wrcount)
614+
}.otherwise {
615+
io.completed := cmd.bits(0).rob_id
616+
cmd.pop := 1.U
617+
}
618+
}
619+
.elsewhen(DoLutPreload2) {
620+
val preload_lut1 = rs1s(0).asTypeOf(new PreloadLutRs1())
621+
val preload_lut2 = rs2s(0).asTypeOf(new PreloadLutRs2())
622+
623+
quant_lut.io.wr := true.B
624+
625+
val (lut_wrcount, lut_wrdone) = Counter(quant_lut.io.wr, 8)
626+
627+
when(!lut_wrdone) {
628+
quant_lut.io.wraddr(0) := Cat(1.U, lut_wrcount)
629+
quant_lut.io.wrdata(0) := preload_lut2.lut_data(lut_wrcount)
630+
}.otherwise {
631+
io.completed := cmd.bits(0).rob_id
632+
cmd.pop := 1.U
633+
}
634+
}
635+
584636
// Preload
585637
.elsewhen(DoPreloads(0) && cmd.valid(1) && (raw_hazards_are_impossible.B || !raw_hazard_pre)) {
586638
perform_single_preload := true.B
@@ -834,8 +886,43 @@ class ExecuteController[T <: Data, U <: Data, V <: Data](xLen: Int, tagWidth: In
834886
val dataD_unpadded = MuxCase(readData(cntl.d_bank), Seq(cntl.preload_zeros -> 0.U, cntl.d_read_from_acc -> accReadData(cntl.d_bank_acc)))
835887

836888
val dataA = VecInit(dataA_unpadded.asTypeOf(Vec(block_size, inputType)).zipWithIndex.map { case (d, i) => Mux(i.U < cntl.a_unpadded_cols, d, inputType.zero)}.map(d => d.asTypeOf(inputType).withWidthOf(spatialArrayInputType)))
837-
val dataB = VecInit(dataB_unpadded.asTypeOf(Vec(block_size, inputType)).zipWithIndex.map { case (d, i) => Mux(i.U < cntl.b_unpadded_cols, d, inputType.zero)}.map(d => d.asTypeOf(inputType).withWidthOf(spatialArrayWeightType)))
838-
val dataD = VecInit(dataD_unpadded.asTypeOf(Vec(block_size, inputType)).zipWithIndex.map { case (d, i) => Mux(i.U < cntl.d_unpadded_cols, d, inputType.zero)}.map(d => d.asTypeOf(inputType).withWidthOf(spatialArrayWeightType)))
889+
val dataB = VecInit(dataB_unpadded.asTypeOf(Vec(block_size, inputType)).zipWithIndex.map { case (d, i) => Mux(i.U < cntl.b_unpadded_cols, d, inputType.zero)}.map(d => d.asTypeOf(inputType).withWidthOf(spatialArrayWeightType)))
890+
891+
val dataD_unpadded_dequant = Wire(Vec(meshColumns, UInt((weightType.getWidth).W)))
892+
val lut_d_to_mesh = cntl.d_fire && dataD_valid && cntl_valid
893+
val (dcount, ddone) = Counter(lut_d_to_mesh, (weightType.getWidth/quantWidth))
894+
895+
when(compute_with_lut) {
896+
when(dcount === 0.U) {
897+
for (i<-0 until meshColumns) {
898+
quant_lut.io.rdaddr(i) := (dataD_unpadded((meshColumns*quantWidth-1), 0).asTypeOf(Vec(meshColumns, UInt(quantWidth.W))))(i)
899+
dataD_unpadded_dequant(i) := quant_lut.io.rddata(i)
900+
}
901+
}.elsewhen(dcount === 1.U) {
902+
for (i<-0 until meshColumns) {
903+
quant_lut.io.rdaddr(i) := (dataD_unpadded((2*meshColumns*quantWidth-1), (meshColumns*quantWidth-1)).asTypeOf(Vec(meshColumns, UInt(quantWidth.W))))(i)
904+
dataD_unpadded_dequant(i) := quant_lut.io.rddata(i)
905+
}
906+
}.elsewhen(dcount === 2.U) {
907+
for (i<-0 until meshColumns) {
908+
quant_lut.io.rdaddr(i) := (dataD_unpadded((3*meshColumns*quantWidth-1), (2*meshColumns*quantWidth-1)).asTypeOf(Vec(meshColumns, UInt(quantWidth.W))))(i)
909+
dataD_unpadded_dequant(i) := quant_lut.io.rddata(i)
910+
}
911+
}.elsewhen(dcount === 3.U) {
912+
for (i<-0 until meshColumns) {
913+
quant_lut.io.rdaddr(i) := (dataD_unpadded((4*meshColumns*quantWidth-1), (3*meshColumns*quantWidth-1)).asTypeOf(Vec(meshColumns, UInt(quantWidth.W))))(i)
914+
dataD_unpadded_dequant(i) := quant_lut.io.rddata(i)
915+
}
916+
}.otherwise {
917+
for (i<-0 until meshColumns) {
918+
dataD_unpadded_dequant(i) := 0.U
919+
}
920+
}
921+
}.otherwise{
922+
dataD_unpadded_dequant := dataD_unpadded.asTypeOf(dataD_unpadded_dequant)
923+
}
924+
925+
val dataD = VecInit(dataD_unpadded_dequant.asTypeOf(Vec(block_size, inputType)).zipWithIndex.map { case (d, i) => Mux(i.U < cntl.d_unpadded_cols, d, inputType.zero)}.map(d => d.asTypeOf(inputType).withWidthOf(spatialArrayWeightType)))
839926

840927
// Pop responses off the scratchpad io ports
841928
when (mesh_cntl_signals_q.io.deq.fire) {

src/main/scala/gemmini/GemminiConfigs.scala

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,7 @@ case class GemminiArrayConfig[T <: Data : Arithmetic, U <: Data, V <: Data](
1818
opcodes: OpcodeSet = OpcodeSet.custom3,
1919

2020
inputType: T,
21-
weightType: T,
21+
weightType: T,
2222
accType: T,
2323
spatialArrayInputType: T,
2424
spatialArrayWeightType: T,
@@ -30,6 +30,7 @@ case class GemminiArrayConfig[T <: Data : Arithmetic, U <: Data, V <: Data](
3030
tileColumns: Int = 1,
3131
meshRows: Int = 16,
3232
meshColumns: Int = 16,
33+
quantWidth: Int = 4,
3334

3435
ld_queue_length: Int = 8,
3536
st_queue_length: Int = 2,

src/main/scala/gemmini/GemminiISA.scala

Lines changed: 14 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,9 @@ object GemminiISA {
3434

3535
val CLKGATE_EN = 22.U
3636

37+
val PRELOAD_LUT1 = 23.U
38+
val PRELOAD_LUT2 = 24.U
39+
3740
// rs1[2:0] values
3841
val CONFIG_EX = 0.U
3942
val CONFIG_LOAD = 1.U
@@ -178,7 +181,8 @@ object GemminiISA {
178181
val CONFIG_EX_RS1_CMD_TYPE_WIDTH = 2
179182
val CONFIG_EX_RS1_DATAFLOW_WIDTH = 1
180183
val CONFIG_EX_RS1_ACTIVATION_WIDTH = 2
181-
val CONFIG_EX_RS1_SPACER0_WIDTH = (7 - 2 - 1 - 2)
184+
val CONFIG_EX_RS1_USE_LUT = 1
185+
val CONFIG_EX_RS1_SPACER0_WIDTH = (7 - 2 - 1 - 2 - 1)
182186
val CONFIG_EX_RS1_SET_ONLY_STRIDES_WIDTH = 1
183187
val CONFIG_EX_RS1_A_TRANSPOSE_WIDTH = 1
184188
val CONFIG_EX_RS1_B_TRANSPOSE_WIDTH = 1
@@ -195,6 +199,7 @@ object GemminiISA {
195199
val a_transpose = UInt(CONFIG_EX_RS1_A_TRANSPOSE_WIDTH.W)
196200
val set_only_strides = UInt(CONFIG_EX_RS1_SET_ONLY_STRIDES_WIDTH.W)
197201
val _spacer0 = UInt(CONFIG_EX_RS1_SPACER0_WIDTH.W)
202+
val use_lut = UInt(CONFIG_EX_RS1_USE_LUT.W)
198203
val activation = UInt(CONFIG_EX_RS1_ACTIVATION_WIDTH.W)
199204
val dataflow = UInt(CONFIG_EX_RS1_DATAFLOW_WIDTH.W)
200205
val cmd_type = UInt(CONFIG_EX_RS1_CMD_TYPE_WIDTH.W)
@@ -235,5 +240,13 @@ object GemminiISA {
235240
val _spacer0 = UInt((COMPUTED_RS_ADDR_WIDTH - local_addr_t.getWidth).W)
236241
val local_addr = local_addr_t.cloneType
237242
}
243+
244+
class PreloadLutRs1(lut_t_bits: Int = 16) extends Bundle {
245+
val lut_data = Vec(4,UInt(lut_t_bits.W))
246+
}
247+
248+
class PreloadLutRs2(lut_t_bits: Int = 16) extends Bundle {
249+
val lut_data = Vec(4,UInt(lut_t_bits.W))
250+
}
238251
}
239252

src/main/scala/gemmini/QuantLut.scala

Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,30 @@
1+
package gemmini
2+
3+
import chisel3._
4+
import chisel3.util._
5+
import chisel3.experimental._
6+
import org.chipsalliance.cde.config.Parameters
7+
8+
import scala.math.{pow}
9+
10+
class QuantLut (wPorts: Int, rPorts: Int, lutaddr: Int, lutdata: Int)(implicit p: Parameters) extends Module {
11+
val io = IO(new Bundle{
12+
val rdaddr = Input(Vec(rPorts, UInt(lutaddr.W)))
13+
val wraddr = Input(Vec(wPorts, UInt(lutaddr.W)))
14+
val rddata = Output(Vec(rPorts, UInt(lutdata.W)))
15+
val wrdata = Input(Vec(wPorts, UInt(lutdata.W)))
16+
val wr = Input(Bool())
17+
})
18+
19+
val lut = RegInit(VecInit(Seq.fill(pow(2,lutaddr).toInt) (0.U(lutdata.W))))
20+
21+
for (i <- 0 until rPorts) {
22+
io.rddata(i) := lut(io.rdaddr(i))
23+
}
24+
25+
when(io.wr) {
26+
for (j <- 0 until wPorts) {
27+
lut(io.wraddr(j)) := io.wrdata(j)
28+
}
29+
}
30+
}

src/main/scala/gemmini/ReservationStation.scala

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -191,7 +191,7 @@ class ReservationStation[T <: Data : Arithmetic, U <: Data, V <: Data](config: G
191191
new_entry.issued := false.B
192192
new_entry.cmd := io.alloc.bits
193193

194-
new_entry.is_config := funct === CONFIG_CMD
194+
new_entry.is_config := funct === CONFIG_CMD || funct === PRELOAD_LUT1 || funct === PRELOAD_LUT2
195195

196196
val op1 = Wire(UDValid(new OpT))
197197
op1.valid := false.B
@@ -293,7 +293,7 @@ class ReservationStation[T <: Data : Arithmetic, U <: Data, V <: Data](config: G
293293
}
294294

295295
val is_load = funct === LOAD_CMD || funct === LOAD2_CMD || funct === LOAD3_CMD || (funct === CONFIG_CMD && config_cmd_type === CONFIG_LOAD)
296-
val is_ex = funct === PRELOAD_CMD || funct_is_compute || (funct === CONFIG_CMD && config_cmd_type === CONFIG_EX)
296+
val is_ex = funct === PRELOAD_CMD || funct_is_compute || (funct === CONFIG_CMD && config_cmd_type === CONFIG_EX) || funct === PRELOAD_LUT1 || funct === PRELOAD_LUT2
297297
val is_store = funct === STORE_CMD || (funct === CONFIG_CMD && (config_cmd_type === CONFIG_STORE || config_cmd_type === CONFIG_NORM))
298298
val is_norm = funct === CONFIG_CMD && config_cmd_type === CONFIG_NORM // normalization commands are a subset of store commands, so they still go in the store queue
299299

0 commit comments

Comments
 (0)