Skip to content

Commit 7ea2330

Browse files
Amanda ShiAmanda Shi
authored andcommitted
change the output lut numEntries
1 parent 9bc54e6 commit 7ea2330

File tree

3 files changed

+27
-27
lines changed

3 files changed

+27
-27
lines changed

src/main/scala/gemmini/MxConfigFragments.scala

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -38,7 +38,7 @@ case class GemminiRequantizerConfig(
3838

3939
case class GemminiLUTConfig(
4040
numBits: Int = 96,
41-
numEntries: Int = 32,
41+
numEntries: Int = 16,
4242
numTables: Int = 3,
4343
rdataWidth: Int = 6,
4444
raddrWidth: Int = 4,

src/main/scala/gemmini/QuantLut.scala

Lines changed: 23 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -44,15 +44,15 @@ class QuantLut(
4444
val io = IO(new QuantLutIO(lutConfig, outputnumLanes, sp_bank_entries, sp_banks, sp_width, sp_width_projected, iterator_bitwidth))
4545
val rdataWidth = lutConfig.rdataWidth
4646
val raddrWidth = lutConfig.raddrWidth
47-
val lutCache_weight_0 = Seq.fill(32)(RegInit(VecInit(Seq.fill(16)(0.U(rdataWidth.W)))))
48-
val lutCache_weight_1 = Seq.fill(32)(RegInit(VecInit(Seq.fill(16)(0.U(rdataWidth.W)))))
49-
val lutCache_act_in_0 = Seq.fill(32)(RegInit(VecInit(Seq.fill(16)(0.U(rdataWidth.W)))))
50-
val lutCache_act_in_1 = Seq.fill(32)(RegInit(VecInit(Seq.fill(16)(0.U(rdataWidth.W)))))
51-
val lutCache_act_out_0 = Seq.fill(32)(RegInit(VecInit(Seq.fill(16)(0.U(rdataWidth.W)))))
52-
val lutCache_act_out_1 = Seq.fill(32)(RegInit(VecInit(Seq.fill(16)(0.U(rdataWidth.W)))))
47+
val lutCache_weight_0 = Seq.fill(lutConfig.numEntries)(RegInit(VecInit(Seq.fill(16)(0.U(rdataWidth.W)))))
48+
val lutCache_weight_1 = Seq.fill(lutConfig.numEntries)(RegInit(VecInit(Seq.fill(16)(0.U(rdataWidth.W)))))
49+
val lutCache_act_in_0 = Seq.fill(lutConfig.numEntries)(RegInit(VecInit(Seq.fill(16)(0.U(rdataWidth.W)))))
50+
val lutCache_act_in_1 = Seq.fill(lutConfig.numEntries)(RegInit(VecInit(Seq.fill(16)(0.U(rdataWidth.W)))))
51+
val lutCache_act_out_0 = Seq.fill(outputnumLanes)(RegInit(VecInit(Seq.fill(16)(0.U(rdataWidth.W)))))
52+
val lutCache_act_out_1 = Seq.fill(outputnumLanes)(RegInit(VecInit(Seq.fill(16)(0.U(rdataWidth.W)))))
5353
//io.lut_write.ready := !io.lutReadEnable
5454

55-
val lutCache_act_in = WireInit(VecInit.fill(32)(VecInit.fill(16)(0.U(rdataWidth.W))))
55+
val lutCache_act_in = WireInit(VecInit.fill(lutConfig.numEntries)(VecInit.fill(16)(0.U(rdataWidth.W))))
5656
val lutCache_act_in_flag = RegInit(false.B)
5757
val lutCache_act_in_buffer_0_read_enable = RegInit(false.B)
5858
val lutCache_act_in_buffer_1_read_enable = RegInit(false.B)
@@ -63,7 +63,7 @@ class QuantLut(
6363

6464
when(io.lut_write_act_in.fire){
6565
when(lutCache_act_in_flag === false.B){
66-
for (lane <- 0 until 32) {
66+
for (lane <- 0 until lutConfig.numEntries) {
6767
for (entry <- 0 until 16) {
6868
lutCache_act_in_0(lane)(entry) := io.lut_write_act_in.bits.data(lane)((entry+1)*rdataWidth-1, entry*rdataWidth)
6969
//lutCache_act_in_0(lane)(entry) := io.lut_write_act_in.bits.data(0)((entry+1)*rdataWidth-1, entry*rdataWidth)
@@ -72,7 +72,7 @@ class QuantLut(
7272
lutCache_act_in_flag := ~lutCache_act_in_flag
7373
lutCache_act_in_buffer_0_read_enable := true.B
7474
}.otherwise {
75-
for (lane <- 0 until 32) {
75+
for (lane <- 0 until lutConfig.numEntries) {
7676
for (entry <- 0 until 16) {
7777
lutCache_act_in_1(lane)(entry) := io.lut_write_act_in.bits.data(lane)((entry+1)*rdataWidth-1, entry*rdataWidth)
7878
//lutCache_act_in_1(lane)(entry) := io.lut_write_act_in.bits.data(0)((entry+1)*rdataWidth-1, entry*rdataWidth)
@@ -86,9 +86,9 @@ class QuantLut(
8686

8787
val lutCache_update_enable_act_in = WireInit(0.U(1.W))
8888
// FP6 only: tile = 32 elements, period = regularity/32 tiles
89-
lutCache_update_enable_act_in := (counter_i(log2Ceil(lut_update_regularity_act_in/32)-1, 0) === 0.U) && (counter_i_reg(log2Ceil(lut_update_regularity_act_in/32)-1, 0) === (lut_update_regularity_act_in/32-1).U)
89+
lutCache_update_enable_act_in := (io.counter_i(log2Ceil(lut_update_regularity_act_in/32)-1, 0) === 0.U) && (counter_i_reg(log2Ceil(lut_update_regularity_act_in/32)-1, 0) === (lut_update_regularity_act_in/32-1).U)
9090

91-
when(lutCache_update_enable_act_in){ //32 is the maxblock under fp6
91+
when(lutCache_update_enable_act_in === 1.U){ //32 is the maxblock under fp6
9292
when(lutCache_act_in_buffer_0_read_enable && (lutCache_act_in_buffer_select === false.B)){
9393
lutCache_act_in_buffer_0_read_enable := false.B
9494
}
@@ -108,7 +108,7 @@ class QuantLut(
108108

109109
io.lut_write_act_in.ready := !lutCache_act_in_buffer_0_read_enable || !lutCache_act_in_buffer_1_read_enable
110110

111-
val lutCache_weight = WireInit(VecInit.fill(32)(VecInit.fill(16)(0.U(rdataWidth.W))))
111+
val lutCache_weight = WireInit(VecInit.fill(lutConfig.numEntries)(VecInit.fill(16)(0.U(rdataWidth.W))))
112112
val lutCache_weight_flag = RegInit(false.B)
113113
val lutCache_weight_buffer_0_read_enable = RegInit(false.B)
114114
val lutCache_weight_buffer_1_read_enable = RegInit(false.B)
@@ -117,7 +117,7 @@ class QuantLut(
117117

118118
when(io.lut_write_weight.fire){
119119
when(lutCache_weight_flag === false.B){
120-
for (lane <- 0 until 32) {
120+
for (lane <- 0 until lutConfig.numEntries) {
121121
for (entry <- 0 until 16) {
122122
lutCache_weight_0(lane)(entry) := io.lut_write_weight.bits.data(lane)((entry+1)*rdataWidth-1, entry*rdataWidth)
123123
//lutCache_weight_0(lane)(entry) := io.lut_write_weight.bits.data(0)((entry+1)*rdataWidth-1, entry*rdataWidth)
@@ -126,7 +126,7 @@ class QuantLut(
126126
lutCache_weight_flag := ~lutCache_weight_flag
127127
lutCache_weight_buffer_0_read_enable := true.B
128128
}.otherwise {
129-
for (lane <- 0 until 32) {
129+
for (lane <- 0 until lutConfig.numEntries) {
130130
for (entry <- 0 until 16) {
131131
lutCache_weight_1(lane)(entry) := io.lut_write_weight.bits.data(lane)((entry+1)*rdataWidth-1, entry*rdataWidth)
132132
//lutCache_weight_1(lane)(entry) := io.lut_write_weight.bits.data(0)((entry+1)*rdataWidth-1, entry*rdataWidth)
@@ -139,9 +139,9 @@ class QuantLut(
139139

140140
val lutCache_update_enable_w_in = WireInit(0.U(1.W))
141141
// FP6 only: tile = 32 elements, period = regularity_w/32 tiles
142-
lutCache_update_enable_w_in := (counter_j(log2Ceil(lut_update_regularity_w/32)-1, 0) === 0.U) && (counter_j_reg(log2Ceil(lut_update_regularity_w/32)-1, 0) === (lut_update_regularity_w/32-1).U)
142+
lutCache_update_enable_w_in := (io.counter_j(log2Ceil(lut_update_regularity_w/32)-1, 0) === 0.U) && (counter_j_reg(log2Ceil(lut_update_regularity_w/32)-1, 0) === (lut_update_regularity_w/32-1).U)
143143

144-
when(lutCache_update_enable_w_in){
144+
when(lutCache_update_enable_w_in === 1.U){
145145
when(lutCache_weight_buffer_0_read_enable && (lutCache_weight_buffer_select === false.B)){
146146
lutCache_weight_buffer_0_read_enable := false.B
147147
}
@@ -161,15 +161,15 @@ class QuantLut(
161161

162162
io.lut_write_weight.ready := !lutCache_weight_buffer_0_read_enable || !lutCache_weight_buffer_1_read_enable
163163

164-
val lutCache_act_out = WireInit(VecInit.fill(32)(VecInit.fill(16)(0.U(rdataWidth.W))))
164+
val lutCache_act_out = WireInit(VecInit.fill(outputnumLanes)(VecInit.fill(16)(0.U(rdataWidth.W))))
165165
val lutCache_act_out_flag = RegInit(false.B)
166166
val lutCache_act_out_buffer_0_read_enable = RegInit(false.B)
167167
val lutCache_act_out_buffer_1_read_enable = RegInit(false.B)
168168
val lutCache_act_out_buffer_select = RegInit(false.B)
169169

170170
when(io.lut_write_act_out.fire){
171171
when(lutCache_act_out_flag === false.B){
172-
for (lane <- 0 until 32) {
172+
for (lane <- 0 until lutConfig.numEntries) {
173173
for (entry <- 0 until 16) {
174174
lutCache_act_out_0(lane)(entry) := io.lut_write_act_out.bits.data(lane)((entry+1)*rdataWidth-1, entry*rdataWidth)
175175
//lutCache_act_out_0(lane)(entry) := io.lut_write_act_out.bits.data(0)((entry+1)*rdataWidth-1, entry*rdataWidth)
@@ -178,7 +178,7 @@ class QuantLut(
178178
lutCache_act_out_flag := ~lutCache_act_out_flag
179179
lutCache_act_out_buffer_0_read_enable := true.B
180180
}.otherwise {
181-
for (lane <- 0 until 32) {
181+
for (lane <- 0 until lutConfig.numEntries) {
182182
for (entry <- 0 until 16) {
183183
lutCache_act_out_1(lane)(entry) := io.lut_write_act_out.bits.data(lane)((entry+1)*rdataWidth-1, entry*rdataWidth)
184184
//lutCache_act_out_1(lane)(entry) := io.lut_write_act_out.bits.data(0)((entry+1)*rdataWidth-1, entry*rdataWidth)
@@ -191,9 +191,9 @@ class QuantLut(
191191

192192
val lutCache_update_enable_act_out = WireInit(0.U(1.W))
193193
// FP6 only: tile = 32 elements, period = regularity_act_out/32 tiles
194-
lutCache_update_enable_act_out := (counter_i(log2Ceil(lut_update_regularity_act_out/32)-1, 0) === 0.U) && (counter_i_reg(log2Ceil(lut_update_regularity_act_out/32)-1, 0) === (lut_update_regularity_act_out/32-1).U)
194+
lutCache_update_enable_act_out := (io.counter_i(log2Ceil(lut_update_regularity_act_out/32)-1, 0) === 0.U) && (counter_i_reg(log2Ceil(lut_update_regularity_act_out/32)-1, 0) === (lut_update_regularity_act_out/32-1).U)
195195

196-
when(lutCache_update_enable_act_out){
196+
when(lutCache_update_enable_act_out === 1.U){
197197
when(lutCache_act_out_buffer_0_read_enable && (lutCache_act_out_buffer_select === false.B)){
198198
lutCache_act_out_buffer_0_read_enable := false.B
199199
}
@@ -261,7 +261,7 @@ class QuantLut(
261261
val chunk_4bit = io.spad_projected_data(i).resp.bits.data((k+1)*4-1, k*4)
262262
deprojected_bits(k) := lutCache_act_in(counter_act)(chunk_4bit)
263263
}
264-
when(counter_act === 31.U){
264+
when(counter_act === (lutConfig.numEntries -1).U){
265265
counter_act := 0.U
266266
}.otherwise{
267267
counter_act := counter_act + 1.U
@@ -271,7 +271,7 @@ class QuantLut(
271271
val chunk_4bit = io.spad_projected_data(i).resp.bits.data((k+1)*4-1, k*4)
272272
deprojected_bits(k) := lutCache_weight(counter_w)(chunk_4bit)
273273
}
274-
when(counter_w === 31.U){
274+
when(counter_w === (lutConfig.numEntries -1).U){
275275
counter_w := 0.U
276276
}.otherwise{
277277
counter_w := counter_w + 1.U

src/main/scala/gemmini/ScaleFactorMem.scala

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -347,15 +347,15 @@ class ScalingFactorMem(
347347
weight_scales(i) := weight_bank_data_vec(i)
348348
// printf(p"[ScalingFactorMem] Read act_scales=${act_scales(i) }\n")
349349
// printf(p"[ScalingFactorMem] Read weight_scales=${weight_scales(i) }\n")
350+
io.read_resp.valid := combined_scales_valid
351+
io.read_resp.bits.combined_scales(i) := combined_scales_buffer(scale_counter)(i)
350352
}
351353
//printf(p"[ScalingFactorMem] Read scales from row=${read_addr_reg}\n")
352-
io.read_resp.valid := combined_scales_valid
353-
io.read_resp.bits.combined_scales(j) := combined_scales_buffer(scale_counter)(j)
354+
354355
when (((scale_counter === ((meshRows*tileRows-1).U) && fp8Mode) || (scale_counter === ((2*meshRows*tileRows-1).U) && !fp8Mode))) {
355356
scale_counter := 0.U
356357
}.otherwise{
357358
scale_counter := scale_counter +& 1.U
358359
}
359360
}
360-
361361
}

0 commit comments

Comments
 (0)