Skip to content

Commit

Permalink
LUT for squeezellm integrated
Browse files Browse the repository at this point in the history
  • Loading branch information
vikramjain236 committed Jul 15, 2024
2 parents 7676382 + 97798dc commit 2b4fcbd
Show file tree
Hide file tree
Showing 8 changed files with 179 additions and 9 deletions.
19 changes: 19 additions & 0 deletions src/main/scala/gemmini/Configs.scala
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@ object GemminiConfigs {
tileColumns = 1,
meshRows = 16,
meshColumns = 16,
quantWidth = 4,

// Spatial array PE options
dataflow = Dataflow.BOTH,
Expand Down Expand Up @@ -177,6 +178,7 @@ object GemminiConfigs {
tileColumns = defaultConfig.tileColumns,
meshRows = defaultConfig.meshRows,
meshColumns = defaultConfig.meshColumns,
quantWidth = defaultConfig.quantWidth,
dataflow = defaultConfig.dataflow,
sp_capacity = CapacityInKilobytes(128),
acc_capacity = CapacityInKilobytes(128),
Expand Down Expand Up @@ -245,6 +247,8 @@ object GemminiConfigs {

val leanPrintfConfig = defaultConfig.copy(dataflow=Dataflow.WS, max_in_flight_mem_reqs = 64, acc_read_full_width = false, ex_read_from_acc = false, ex_write_to_spad = false, hardcode_d_to_garbage_addr = true, use_firesim_simulation_counters=true)

val lutLeanConfig = defaultConfig.copy(dataflow=Dataflow.WS, inputType=SInt(16.W), meshRows=8, meshColumns=8, max_in_flight_mem_reqs = 64, acc_read_full_width = false, ex_read_from_acc = false, ex_write_to_spad = false, hardcode_d_to_garbage_addr = true)

}

/**
Expand Down Expand Up @@ -374,3 +378,18 @@ class DualGemminiConfig extends Config((site, here, up) => {
up(BuildRoCC) ++ Seq(int_fn, fp_fn)
}
})

/**
* Mixin which sets the default lut lean parameters for a systolic array accelerator.
*/
class LutLeanGemminiConfig[T <: Data : Arithmetic, U <: Data, V <: Data](
gemminiConfig: GemminiArrayConfig[T,U,V] = GemminiConfigs.lutLeanConfig
) extends Config((site, here, up) => {
case BuildRoCC => up(BuildRoCC) ++ Seq(
(p: Parameters) => {
implicit val q = p
val gemmini = LazyModule(new Gemmini(gemminiConfig))
gemmini
}
)
})
20 changes: 20 additions & 0 deletions src/main/scala/gemmini/ConfigsFP.scala
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ object GemminiFPConfigs {
tileColumns = 1,
meshRows = 4,
meshColumns = 4,
quantWidth = 4,

ld_queue_length = 8,
st_queue_length = 2,
Expand Down Expand Up @@ -113,6 +114,13 @@ object GemminiFPConfigs {
mvin_scale_acc_args = Some(ScaleArguments((t: Float, u: Float) => t * u, 4, Float(8, 24), -1, identity = "1.0", c_str="((x) * (scale))")),
)

//FP16 Half Precision Configuration for LUT with 8x8 array
val LutFP16DefaultConfig = defaultFPConfig.copy(inputType = Float(5, 11), weightType = Float(5, 11), accType = Float(8, 24), spatialArrayInputType = Float(5, 11), spatialArrayWeightType = Float(5, 11), spatialArrayOutputType = Float(5, 11),
meshRows = 8, meshColumns = 8,
tile_latency = 2,
mvin_scale_args = Some(ScaleArguments((t: Float, u: Float) => t * u, 4, Float(5, 11), -1, identity = "1.0", c_str="((x) * (scale))")),
mvin_scale_acc_args = Some(ScaleArguments((t: Float, u: Float) => t * u, 4, Float(5, 11), -1, identity = "1.0", c_str="((x) * (scale))")),
)

val chipFP32Config = FP32DefaultConfig.copy(sp_capacity=CapacityInKilobytes(32), acc_capacity=CapacityInKilobytes(8), dataflow=Dataflow.WS,
acc_scale_args = Some(ScaleArguments((t: Float, u: Float) => {t}, 1, Float(8, 24), -1, identity = "1.0",
Expand Down Expand Up @@ -141,6 +149,7 @@ object GemminiFPConfigs {
meshRows = 16,
meshColumns = 16,
accType = Float(5, 11),
quantWidth = 4,
spatialArrayInputType = Float(5, 11, isRecoded = true),
spatialArrayWeightType = Float(5, 11, isRecoded = true),
spatialArrayOutputType = Float(5, 11, isRecoded = true),
Expand Down Expand Up @@ -255,3 +264,14 @@ class GemminiBF16Default8Config extends Config((site, here, up) => {
)
})

//===========FP16 LUT-based Default Config=========
class GemminiLutFP16DefaultConfig extends Config((site, here, up) => {
case BuildRoCC => Seq(
(p: Parameters) => {
implicit val q = p
implicit val v = implicitly[ValName]
LazyModule(new Gemmini(GemminiFPConfigs.LutFP16DefaultConfig))
}
)
})

95 changes: 91 additions & 4 deletions src/main/scala/gemmini/ExecuteController.scala
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,19 @@ class ExecuteController[T <: Data, U <: Data, V <: Data](xLen: Int, tagWidth: In
}
}

val quant_lut = Module(new QuantLut(1, meshColumns, quantWidth, weightType.getWidth))
quant_lut.io.wraddr := DontCare
quant_lut.io.rdaddr := DontCare
quant_lut.io.wrdata := DontCare
quant_lut.io.wr := false.B

val compute_with_lut = RegInit(false.B)

val shift_for_lut = RegInit(0.U(log2Up(quantWidth).W))
when(compute_with_lut) {
shift_for_lut := log2Up(quantWidth).U
}.otherwise {shift_for_lut := 0.U}

val unrolled_cmd = TransposePreloadUnroller(io.cmd, config, io.counter)

val cmd_q_heads = 3
Expand All @@ -83,6 +96,8 @@ class ExecuteController[T <: Data, U <: Data, V <: Data](xLen: Int, tagWidth: In
val DoConfig = functs(0) === CONFIG_CMD
val DoComputes = functs.map(f => f === COMPUTE_AND_FLIP_CMD || f === COMPUTE_AND_STAY_CMD)
val DoPreloads = functs.map(_ === PRELOAD_CMD)
val DoLutPreload1 = functs(0) === PRELOAD_LUT1
val DoLutPreload2 = functs(0) === PRELOAD_LUT2

val preload_cmd_place = Mux(DoPreloads(0), 0.U, 1.U)
// val a_address_place = Mux(current_dataflow === Dataflow.WS.id.U, 0.U, Mux(preload_cmd_place === 0.U, 1.U, 2.U))
Expand Down Expand Up @@ -436,13 +451,13 @@ class ExecuteController[T <: Data, U <: Data, V <: Data](xLen: Int, tagWidth: In
io.srams.read(i).req.bits.fromDMA := false.B
io.srams.read(i).req.bits.addr := MuxCase(a_address_rs1.sp_row() + a_fire_counter,
Seq(read_b -> (b_address_rs2.sp_row() + b_fire_counter),
read_d -> (d_address_rs1.sp_row() + block_size.U - 1.U - d_fire_counter_mulpre)))
read_d -> ((d_address_rs1.sp_row() + block_size.U - 1.U - d_fire_counter_mulpre) >> shift_for_lut)))

// TODO this just overrides the previous line. Should we erase the previous line?
when(im2col_en === false.B) {
io.srams.read(i).req.bits.addr := MuxCase(a_address.sp_row(),
Seq(read_b -> b_address.sp_row(),
read_d -> d_address.sp_row()))
read_d -> (d_address.sp_row() >> shift_for_lut)))
}
} else {
io.srams.read(i).req.valid := false.B
Expand All @@ -458,6 +473,8 @@ class ExecuteController[T <: Data, U <: Data, V <: Data](xLen: Int, tagWidth: In
val read_a_from_acc = a_valid && a_read_from_acc && dataABankAcc === i.U && start_inputting_a && !multiply_garbage && a_row_is_not_all_zeros && !(im2col_wire&&im2col_en)
val read_b_from_acc = b_valid && b_read_from_acc && dataBBankAcc === i.U && start_inputting_b && !accumulate_zeros && b_row_is_not_all_zeros //&& !im2col_wire
val read_d_from_acc = d_valid && d_read_from_acc && dataDBankAcc === i.U && start_inputting_d && !preload_zeros && d_row_is_not_all_zeros //&& !im2col_wire
// we do not support LUT and accumulator read for D matrices for now
assert(!(compute_with_lut && read_d_from_acc))

Seq((read_a_from_acc, a_ready), (read_b_from_acc, b_ready), (read_d_from_acc, d_ready)).foreach { case (rd, r) =>
when(rd && !io.acc.read_req(i).ready) {
Expand Down Expand Up @@ -559,6 +576,8 @@ class ExecuteController[T <: Data, U <: Data, V <: Data](xLen: Int, tagWidth: In
if (dataflow == Dataflow.BOTH) {
current_dataflow := config_ex_rs1.dataflow
}
// use dequantization lut path
compute_with_lut := config_ex_rs1.use_lut
}

a_addr_stride := config_ex_rs1.a_stride // TODO this needs to be kept in sync with ROB.scala
Expand All @@ -581,6 +600,39 @@ class ExecuteController[T <: Data, U <: Data, V <: Data](xLen: Int, tagWidth: In
cmd.pop := 1.U
}

.elsewhen(DoLutPreload1) {
val preload_lut1 = rs1s(0).asTypeOf(new PreloadLutRs1())
val preload_lut2 = rs2s(0).asTypeOf(new PreloadLutRs2())

quant_lut.io.wr := true.B

val (lut_wrcount, lut_wrdone) = Counter(quant_lut.io.wr, 8)

when(!lut_wrdone) {
quant_lut.io.wraddr(0) := Cat(0.U, lut_wrcount)
quant_lut.io.wrdata(0) := preload_lut1.lut_data(lut_wrcount)
}.otherwise {
io.completed := cmd.bits(0).rob_id
cmd.pop := 1.U
}
}
.elsewhen(DoLutPreload2) {
val preload_lut1 = rs1s(0).asTypeOf(new PreloadLutRs1())
val preload_lut2 = rs2s(0).asTypeOf(new PreloadLutRs2())

quant_lut.io.wr := true.B

val (lut_wrcount, lut_wrdone) = Counter(quant_lut.io.wr, 8)

when(!lut_wrdone) {
quant_lut.io.wraddr(0) := Cat(1.U, lut_wrcount)
quant_lut.io.wrdata(0) := preload_lut2.lut_data(lut_wrcount)
}.otherwise {
io.completed := cmd.bits(0).rob_id
cmd.pop := 1.U
}
}

// Preload
.elsewhen(DoPreloads(0) && cmd.valid(1) && (raw_hazards_are_impossible.B || !raw_hazard_pre)) {
perform_single_preload := true.B
Expand Down Expand Up @@ -834,8 +886,43 @@ class ExecuteController[T <: Data, U <: Data, V <: Data](xLen: Int, tagWidth: In
val dataD_unpadded = MuxCase(readData(cntl.d_bank), Seq(cntl.preload_zeros -> 0.U, cntl.d_read_from_acc -> accReadData(cntl.d_bank_acc)))

val dataA = VecInit(dataA_unpadded.asTypeOf(Vec(block_size, inputType)).zipWithIndex.map { case (d, i) => Mux(i.U < cntl.a_unpadded_cols, d, inputType.zero)}.map(d => d.asTypeOf(inputType).withWidthOf(spatialArrayInputType)))
val dataB = VecInit(dataB_unpadded.asTypeOf(Vec(block_size, inputType)).zipWithIndex.map { case (d, i) => Mux(i.U < cntl.b_unpadded_cols, d, inputType.zero)}.map(d => d.asTypeOf(inputType).withWidthOf(spatialArrayWeightType)))
val dataD = VecInit(dataD_unpadded.asTypeOf(Vec(block_size, inputType)).zipWithIndex.map { case (d, i) => Mux(i.U < cntl.d_unpadded_cols, d, inputType.zero)}.map(d => d.asTypeOf(inputType).withWidthOf(spatialArrayWeightType)))
val dataB = VecInit(dataB_unpadded.asTypeOf(Vec(block_size, inputType)).zipWithIndex.map { case (d, i) => Mux(i.U < cntl.b_unpadded_cols, d, inputType.zero)}.map(d => d.asTypeOf(inputType).withWidthOf(spatialArrayWeightType)))

val dataD_unpadded_dequant = Wire(Vec(meshColumns, UInt((weightType.getWidth).W)))
val lut_d_to_mesh = cntl.d_fire && dataD_valid && cntl_valid
val (dcount, ddone) = Counter(lut_d_to_mesh, (weightType.getWidth/quantWidth))

when(compute_with_lut) {
when(dcount === 0.U) {
for (i<-0 until meshColumns) {
quant_lut.io.rdaddr(i) := (dataD_unpadded((meshColumns*quantWidth-1), 0).asTypeOf(Vec(meshColumns, UInt(quantWidth.W))))(i)
dataD_unpadded_dequant(i) := quant_lut.io.rddata(i)
}
}.elsewhen(dcount === 1.U) {
for (i<-0 until meshColumns) {
quant_lut.io.rdaddr(i) := (dataD_unpadded((2*meshColumns*quantWidth-1), (meshColumns*quantWidth-1)).asTypeOf(Vec(meshColumns, UInt(quantWidth.W))))(i)
dataD_unpadded_dequant(i) := quant_lut.io.rddata(i)
}
}.elsewhen(dcount === 2.U) {
for (i<-0 until meshColumns) {
quant_lut.io.rdaddr(i) := (dataD_unpadded((3*meshColumns*quantWidth-1), (2*meshColumns*quantWidth-1)).asTypeOf(Vec(meshColumns, UInt(quantWidth.W))))(i)
dataD_unpadded_dequant(i) := quant_lut.io.rddata(i)
}
}.elsewhen(dcount === 3.U) {
for (i<-0 until meshColumns) {
quant_lut.io.rdaddr(i) := (dataD_unpadded((4*meshColumns*quantWidth-1), (3*meshColumns*quantWidth-1)).asTypeOf(Vec(meshColumns, UInt(quantWidth.W))))(i)
dataD_unpadded_dequant(i) := quant_lut.io.rddata(i)
}
}.otherwise {
for (i<-0 until meshColumns) {
dataD_unpadded_dequant(i) := 0.U
}
}
}.otherwise{
dataD_unpadded_dequant := dataD_unpadded.asTypeOf(dataD_unpadded_dequant)
}

val dataD = VecInit(dataD_unpadded_dequant.asTypeOf(Vec(block_size, inputType)).zipWithIndex.map { case (d, i) => Mux(i.U < cntl.d_unpadded_cols, d, inputType.zero)}.map(d => d.asTypeOf(inputType).withWidthOf(spatialArrayWeightType)))

// Pop responses off the scratchpad io ports
when (mesh_cntl_signals_q.io.deq.fire) {
Expand Down
3 changes: 2 additions & 1 deletion src/main/scala/gemmini/GemminiConfigs.scala
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ case class GemminiArrayConfig[T <: Data : Arithmetic, U <: Data, V <: Data](
opcodes: OpcodeSet = OpcodeSet.custom3,

inputType: T,
weightType: T,
weightType: T,
accType: T,
spatialArrayInputType: T,
spatialArrayWeightType: T,
Expand All @@ -30,6 +30,7 @@ case class GemminiArrayConfig[T <: Data : Arithmetic, U <: Data, V <: Data](
tileColumns: Int = 1,
meshRows: Int = 16,
meshColumns: Int = 16,
quantWidth: Int = 4,

ld_queue_length: Int = 8,
st_queue_length: Int = 2,
Expand Down
15 changes: 14 additions & 1 deletion src/main/scala/gemmini/GemminiISA.scala
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,9 @@ object GemminiISA {

val CLKGATE_EN = 22.U

val PRELOAD_LUT1 = 23.U
val PRELOAD_LUT2 = 24.U

// rs1[2:0] values
val CONFIG_EX = 0.U
val CONFIG_LOAD = 1.U
Expand Down Expand Up @@ -178,7 +181,8 @@ object GemminiISA {
val CONFIG_EX_RS1_CMD_TYPE_WIDTH = 2
val CONFIG_EX_RS1_DATAFLOW_WIDTH = 1
val CONFIG_EX_RS1_ACTIVATION_WIDTH = 2
val CONFIG_EX_RS1_SPACER0_WIDTH = (7 - 2 - 1 - 2)
val CONFIG_EX_RS1_USE_LUT = 1
val CONFIG_EX_RS1_SPACER0_WIDTH = (7 - 2 - 1 - 2 - 1)
val CONFIG_EX_RS1_SET_ONLY_STRIDES_WIDTH = 1
val CONFIG_EX_RS1_A_TRANSPOSE_WIDTH = 1
val CONFIG_EX_RS1_B_TRANSPOSE_WIDTH = 1
Expand All @@ -195,6 +199,7 @@ object GemminiISA {
val a_transpose = UInt(CONFIG_EX_RS1_A_TRANSPOSE_WIDTH.W)
val set_only_strides = UInt(CONFIG_EX_RS1_SET_ONLY_STRIDES_WIDTH.W)
val _spacer0 = UInt(CONFIG_EX_RS1_SPACER0_WIDTH.W)
val use_lut = UInt(CONFIG_EX_RS1_USE_LUT.W)
val activation = UInt(CONFIG_EX_RS1_ACTIVATION_WIDTH.W)
val dataflow = UInt(CONFIG_EX_RS1_DATAFLOW_WIDTH.W)
val cmd_type = UInt(CONFIG_EX_RS1_CMD_TYPE_WIDTH.W)
Expand Down Expand Up @@ -235,5 +240,13 @@ object GemminiISA {
val _spacer0 = UInt((COMPUTED_RS_ADDR_WIDTH - local_addr_t.getWidth).W)
val local_addr = local_addr_t.cloneType
}

class PreloadLutRs1(lut_t_bits: Int = 16) extends Bundle {
val lut_data = Vec(4,UInt(lut_t_bits.W))
}

class PreloadLutRs2(lut_t_bits: Int = 16) extends Bundle {
val lut_data = Vec(4,UInt(lut_t_bits.W))
}
}

30 changes: 30 additions & 0 deletions src/main/scala/gemmini/QuantLut.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,30 @@
package gemmini

import chisel3._
import chisel3.util._
import chisel3.experimental._
import org.chipsalliance.cde.config.Parameters

import scala.math.{pow}

class QuantLut (wPorts: Int, rPorts: Int, lutaddr: Int, lutdata: Int)(implicit p: Parameters) extends Module {
val io = IO(new Bundle{
val rdaddr = Input(Vec(rPorts, UInt(lutaddr.W)))
val wraddr = Input(Vec(wPorts, UInt(lutaddr.W)))
val rddata = Output(Vec(rPorts, UInt(lutdata.W)))
val wrdata = Input(Vec(wPorts, UInt(lutdata.W)))
val wr = Input(Bool())
})

val lut = RegInit(VecInit(Seq.fill(pow(2,lutaddr).toInt) (0.U(lutdata.W))))

for (i <- 0 until rPorts) {
io.rddata(i) := lut(io.rdaddr(i))
}

when(io.wr) {
for (j <- 0 until wPorts) {
lut(io.wraddr(j)) := io.wrdata(j)
}
}
}
4 changes: 2 additions & 2 deletions src/main/scala/gemmini/ReservationStation.scala
Original file line number Diff line number Diff line change
Expand Up @@ -191,7 +191,7 @@ class ReservationStation[T <: Data : Arithmetic, U <: Data, V <: Data](config: G
new_entry.issued := false.B
new_entry.cmd := io.alloc.bits

new_entry.is_config := funct === CONFIG_CMD
new_entry.is_config := funct === CONFIG_CMD || funct === PRELOAD_LUT1 || funct === PRELOAD_LUT2

val op1 = Wire(UDValid(new OpT))
op1.valid := false.B
Expand Down Expand Up @@ -293,7 +293,7 @@ class ReservationStation[T <: Data : Arithmetic, U <: Data, V <: Data](config: G
}

val is_load = funct === LOAD_CMD || funct === LOAD2_CMD || funct === LOAD3_CMD || (funct === CONFIG_CMD && config_cmd_type === CONFIG_LOAD)
val is_ex = funct === PRELOAD_CMD || funct_is_compute || (funct === CONFIG_CMD && config_cmd_type === CONFIG_EX)
val is_ex = funct === PRELOAD_CMD || funct_is_compute || (funct === CONFIG_CMD && config_cmd_type === CONFIG_EX) || funct === PRELOAD_LUT1 || funct === PRELOAD_LUT2
val is_store = funct === STORE_CMD || (funct === CONFIG_CMD && (config_cmd_type === CONFIG_STORE || config_cmd_type === CONFIG_NORM))
val is_norm = funct === CONFIG_CMD && config_cmd_type === CONFIG_NORM // normalization commands are a subset of store commands, so they still go in the store queue

Expand Down

0 comments on commit 2b4fcbd

Please sign in to comment.