Skip to content

Commit 26a5b13

Browse files
committed
wip lut interface updates
1 parent 7ea2330 commit 26a5b13

File tree

4 files changed

+20
-15
lines changed

4 files changed

+20
-15
lines changed

src/main/scala/gemmini/Controller.scala

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -214,9 +214,9 @@ class GemminiModule[T <: Data: Arithmetic, U <: Data, V <: Data]
214214
val scale_mem_write_act = Flipped(Decoupled(new ScalingFactorWriteReq(s)))
215215
val requant_in_gpu = Flipped(Decoupled(new RequantizerInBundle(q.numGPUInputLanes, q.inputBits)))
216216
val requant_out = Decoupled(new RequantizerOutBundle(q.numOutputLanes, q.maxOutputBits))
217-
val lut0 = Flipped(Decoupled(new QuantLutWriteBundle(l)))
218-
val lut1 = Flipped(Decoupled(new QuantLutWriteBundle(l)))
219-
val lut2 = Flipped(Decoupled(new QuantLutWriteBundle(l)))
217+
val lut0 = Flipped(Decoupled(new QuantLutWriteBundle(l(0))))
218+
val lut1 = Flipped(Decoupled(new QuantLutWriteBundle(l(1))))
219+
val lut2 = Flipped(Decoupled(new QuantLutWriteBundle(l(2))))
220220
val scale_factor_out = Decoupled(new ScalingFactorWriteReq(s.ScaleMemWriteAddrWidth, s.ScaleMemWriteAddrWidth))
221221
})
222222

src/main/scala/gemmini/MxConfigFragments.scala

Lines changed: 11 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -37,12 +37,17 @@ case class GemminiRequantizerConfig(
3737
)
3838

3939
case class GemminiLUTConfig(
40-
numBits: Int = 96,
41-
numEntries: Int = 16,
42-
numTables: Int = 3,
40+
numBits: Seq[Int] = Seq(96, 96, 96),
41+
numEntries: Seq[Int] = Seq(16, 16, 32),
4342
rdataWidth: Int = 6,
4443
raddrWidth: Int = 4,
45-
)
44+
) {
45+
def apply(table: Int) = {
46+
(numEntries(table), numBits(table))
47+
}
48+
49+
def numTables = numBits.length
50+
}
4651

4752
object RequantizerDataType extends ChiselEnum {
4853
val FP8, FP6, FP4 = Value
@@ -86,7 +91,7 @@ class RequantizerOutBundle(numLanes: Int, dataWidth: Int = 8) extends Bundle {
8691

8792
class QuantLutWriteBundle(numEntries: Int, numBits: Int) extends Bundle {
8893
val data = Vec(numEntries, UInt(numBits.W))
89-
def this(config: GemminiLUTConfig) = {
90-
this(config.numEntries, config.numBits)
94+
def this(config: (Int, Int)) = {
95+
this(config._1, config._2)
9196
}
9297
}

src/main/scala/gemmini/MxRequantizer.scala

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -59,9 +59,9 @@ class MxRequantizerIO(
5959
val requant_data_in_gpu = Flipped(Decoupled(new RequantizerInBundle(config.numGPUInputLanes, inputdataWidth)))
6060
val scaleMem_write = Decoupled(new ScalingFactorWriteReq(scaleMem_addr_width, scaleMem_data_width))
6161
val requant_data_out = Decoupled(new RequantizerOutBundle(outputnumLanes))
62-
val lut0_write = Flipped(Decoupled(new QuantLutWriteBundle(lutConfig)))
63-
val lut1_write = Flipped(Decoupled(new QuantLutWriteBundle(lutConfig)))
64-
val lut2_write = Flipped(Decoupled(new QuantLutWriteBundle(lutConfig)))
62+
val lut0_write = Flipped(Decoupled(new QuantLutWriteBundle(lutConfig(0))))
63+
val lut1_write = Flipped(Decoupled(new QuantLutWriteBundle(lutConfig(1))))
64+
val lut2_write = Flipped(Decoupled(new QuantLutWriteBundle(lutConfig(2))))
6565
val spad_projected_data = Vec(sp_banks, Flipped(new ScratchpadReadIO(sp_bank_entries, sp_width_projected)))
6666
val spad_deprojected_data = Vec(sp_banks, new ScratchpadReadIO(sp_bank_entries, sp_width))
6767
val fp8_mode = Input(Bool()) // true for 64-lane mode, false for 16-lane mode

src/main/scala/gemmini/QuantLut.scala

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -16,9 +16,9 @@ class QuantLutIO(
1616
sp_width_projected: Int,
1717
iterator_bitwidth: Int,
1818
) extends Bundle {
19-
val lut_write_weight = Flipped(Decoupled(new QuantLutWriteBundle(lutConfig))) //input
20-
val lut_write_act_in = Flipped(Decoupled(new QuantLutWriteBundle(lutConfig))) //input
21-
val lut_write_act_out = Flipped(Decoupled(new QuantLutWriteBundle(lutConfig))) //input
19+
val lut_write_weight = Flipped(Decoupled(new QuantLutWriteBundle(lutConfig(0)))) //input
20+
val lut_write_act_in = Flipped(Decoupled(new QuantLutWriteBundle(lutConfig(1)))) //input
21+
val lut_write_act_out = Flipped(Decoupled(new QuantLutWriteBundle(lutConfig(2)))) //input
2222
val quant_fp6 = Flipped(Valid(Vec(outputnumLanes, UInt(lutConfig.rdataWidth.W)))) //input
2323
val projected_data = Valid(Vec(outputnumLanes, UInt(lutConfig.raddrWidth.W))) //output
2424
val spad_projected_data = Vec(sp_banks, new ScratchpadReadIO(sp_bank_entries, sp_width_projected))

0 commit comments

Comments
 (0)