Skip to content

Commit b2e3849

Browse files
Amanda Shi authored and committed
fix some control bugs of scale mem
1 parent 0b289ee commit b2e3849

File tree

3 files changed

+115
-97
lines changed

3 files changed

+115
-97
lines changed

src/main/scala/gemmini/Controller.scala

Lines changed: 2 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -218,7 +218,8 @@ class GemminiModule[T <: Data: Arithmetic, U <: Data, V <: Data]
218218
})
219219

220220
if (!outer.config.testConfig) {
221-
mx_io.scale_mem <> spad.module.io.scale_mem.get
221+
mx_io.scale_mem_write_w <> spad.module.io.scale_mem_write_w.get
222+
mx_io.scale_mem_write_act <> spad.module.io.scale_mem_write_act.get
222223
}
223224

224225
mx_io.requant_out <> mx_requantizer.get.io.requant_data_out
@@ -308,13 +309,9 @@ class GemminiModule[T <: Data: Arithmetic, U <: Data, V <: Data]
308309
requantized_writes(i).mask := VecInit(Seq.fill(requantized_writes(i).mask.length)(true.B))
309310
}
310311

311-
<<<<<<< HEAD
312312
if (!outer.config.testConfig) {
313313
mx_io.get.requant_in_gpu.ready := false.B
314314
}
315-
=======
316-
// mx_io.get.requant_in_gpu.ready := false.B
317-
>>>>>>> 0df4355 (change scale Mem as double RF buffer)
318315
mx_requantizer.get.io.requant_data_in.valid := false.B
319316
mx_requantizer.get.io.requant_data_in.bits := DontCare
320317
mx_requantizer.get.io.scaleMem_write.ready := false.B

src/main/scala/gemmini/ScaleFactorMem.scala

Lines changed: 40 additions & 41 deletions
Original file line numberDiff line numberDiff line change
@@ -66,8 +66,8 @@ class ScalingFactorMem(
6666

6767
val combined_scales_buffer = WireInit(VecInit(Seq.fill(2*meshRows*tileRows)(
6868
VecInit(Seq.fill(2*meshRows*tileRows)(0.U(9.W))))))
69-
val combined_scales_buffer_r = RegInit(VecInit(Seq.fill(2*meshRows*tileRows)(
70-
VecInit(Seq.fill(2*meshRows*tileRows)(0.U(9.W))))))
69+
//val combined_scales_buffer_r = RegInit(VecInit(Seq.fill(2*meshRows*tileRows)(
70+
//VecInit(Seq.fill(2*meshRows*tileRows)(0.U(9.W))))))
7171
val combined_scales_valid = WireDefault(false.B)
7272
val bankDataT = Vec(bytesPerBank, UInt(8.W))
7373
val banks = Seq.fill(numBanks)(SyncReadMem(depth, bankDataT))
@@ -78,6 +78,7 @@ class ScalingFactorMem(
7878
val weight_buffer_1_read_enable = RegInit(false.B)
7979
val weight_write_counter = RegInit(0.U(8.W))
8080
val write_row_addr_w = io.scale_mem_write_w.bits.addr + (write_baseAddr_w >> (log2Ceil(2*meshRows*tileRows)))
81+
8182
val weight_buffer_write_full = RegInit(false.B)
8283
when(io.scale_mem_write_w.fire) {
8384
val write_bytes_low = io.scale_mem_write_w.bits.data(bytesPerBank * 8 - 1, 0).asTypeOf(bankDataT)
@@ -93,9 +94,10 @@ class ScalingFactorMem(
9394
banks(7).write(write_row_addr_w, write_bytes_high)
9495
weight_buffer_1_read_enable := true.B
9596
}
97+
weight_write_buffer_sel := ~weight_write_buffer_sel
9698
when((weight_write_counter === (depth - 1).U)){
9799
weight_write_counter := 0.U
98-
weight_write_buffer_sel := ~weight_write_buffer_sel
100+
99101
// when(weight_write_buffer_sel === false.B){
100102
// weight_buffer_0_read_enable := true.B
101103
// }.otherwise{
@@ -115,47 +117,42 @@ class ScalingFactorMem(
115117
when(act_write_buffer_sel === false.B) { //
116118
banks(0).write(write_row_addr_act, write_bytes_low)
117119
banks(1).write(write_row_addr_act, write_bytes_high)
118-
act_write_counter := act_write_counter + 1.U
119120
act_buffer_0_read_enable := true.B
120121
}.otherwise{
121122
act_write_counter := act_write_counter + 1.U
122123
banks(2).write(write_row_addr_act, write_bytes_low)
123124
banks(3).write(write_row_addr_act, write_bytes_high)
124125
act_buffer_1_read_enable := true.B
125126
}
127+
act_write_buffer_sel := ~act_write_buffer_sel
126128
when((act_write_counter === (depth - 1).U)){
127129
act_write_counter := 0.U
128-
act_write_buffer_sel := ~act_write_buffer_sel
129-
// when(act_write_buffer_sel === false.B){
130-
// act_buffer_0_read_enable := true.B
131-
// }.otherwise{
132-
// act_buffer_1_read_enable := true.B
133-
// }
134130
}
135131
}
136-
137-
138-
io.scale_mem_write_w.ready := (!weight_buffer_0_read_enable) || (!weight_buffer_1_read_enable)
139-
io.scale_mem_write_act.ready := (!act_buffer_0_read_enable) || (!act_buffer_1_read_enable)
132+
val max_block_fp8 = meshRows * tileRows
133+
val max_block_non_fp8 = 2*meshRows * tileRows
134+
val read_row_addr = WireDefault(counter_k >> log2Ceil(max_block_non_fp8))
135+
io.scale_mem_write_w.ready := ((!weight_buffer_0_read_enable) || (!weight_buffer_1_read_enable)) || (weight_write_counter ===0.U || (weight_write_counter > read_row_addr))
136+
io.scale_mem_write_act.ready := (!act_buffer_0_read_enable) || (!act_buffer_1_read_enable) || (act_write_counter ===0.U || ((act_write_counter === read_row_addr)))
140137
val act_read_buffer_select = RegInit(false.B)
141138
val weight_read_buffer_select = RegInit(false.B)
142139
val act_read_counter = RegInit(0.U(8.W))
143140
val weight_read_counter = RegInit(0.U(8.W))
144141

145-
when(counter_k(log2Ceil(depth)-1, 0) === (depth - 1).U && io.read_req.fire && io.read_req.bits.scaling_enable){
142+
when(io.read_req.fire && io.read_req.bits.scaling_enable){
146143
act_read_buffer_select := ~act_read_buffer_select
147144
weight_read_buffer_select := ~weight_read_buffer_select
148-
when(act_buffer_0_read_enable && (act_read_buffer_select === false.B)){
145+
when(act_buffer_0_read_enable && ((act_write_counter === read_row_addr))){
149146
act_buffer_0_read_enable := false.B
150147
}
151-
when(act_buffer_1_read_enable && (act_read_buffer_select === true.B)){
148+
when(act_buffer_1_read_enable && ((act_write_counter === read_row_addr))){
152149
act_buffer_1_read_enable := false.B
153150
}
154151

155-
when(weight_buffer_0_read_enable && (weight_read_buffer_select === false.B)){
152+
when(weight_buffer_0_read_enable && ((weight_write_counter === read_row_addr))){
156153
weight_buffer_0_read_enable := false.B
157154
}
158-
when(weight_buffer_1_read_enable && (weight_read_buffer_select === true.B)){
155+
when(weight_buffer_1_read_enable && ((weight_write_counter === read_row_addr))){
159156
weight_buffer_1_read_enable := false.B
160157
}
161158
}
@@ -169,11 +166,12 @@ class ScalingFactorMem(
169166
sum(8, 0)
170167
}
171168

172-
val read_fire = io.read_req.fire && io.read_req.bits.scaling_enable
173-
val read_fire_real = io.read_req.fire && io.read_req.bits.scaling_enable && (scale_counter === 0.U)
174-
val read_row_addr = counter_k
175-
val max_block_fp8 = meshRows * tileRows
176-
val max_block_non_fp8 = 2*meshRows * tileRows
169+
val read_fire = io.read_req.fire && io.read_req.bits.scaling_enable && ((act_buffer_0_read_enable && weight_buffer_0_read_enable) || (act_buffer_1_read_enable && weight_buffer_1_read_enable) )
170+
val read_fire_real = read_fire && (scale_counter === 0.U)
171+
172+
173+
174+
177175

178176
val act_bank_data_vec = WireInit(VecInit(Seq.fill(meshRows*tileRows*2)(0.U(8.W))))
179177
val weight_bank_data_vec = WireInit(VecInit(Seq.fill(meshRows*tileRows*2)(0.U(8.W))))
@@ -184,14 +182,14 @@ class ScalingFactorMem(
184182
weight_bank_sel := 0.U
185183

186184
val read_fire_banks = VecInit(Seq(
187-
read_fire_real && act_buffer_0_read_enable && (act_bank_sel === 0.U) , // bank 0
188-
read_fire_real && act_buffer_0_read_enable && (act_bank_sel === 1.U), // bank 1
189-
read_fire_real && act_buffer_1_read_enable && (act_bank_sel === 2.U), // bank 2
190-
read_fire_real && act_buffer_1_read_enable && (act_bank_sel === 3.U), // bank 3
191-
read_fire_real && weight_buffer_0_read_enable && (weight_bank_sel === 0.U), // bank 4
192-
read_fire_real && weight_buffer_0_read_enable && (weight_bank_sel === 1.U), // bank 5
193-
read_fire_real && weight_buffer_1_read_enable && (weight_bank_sel === 2.U), // bank 6
194-
read_fire_real && weight_buffer_1_read_enable && (weight_bank_sel === 3.U) // bank 7
185+
read_fire_real && act_buffer_0_read_enable && weight_buffer_0_read_enable && (act_bank_sel === 0.U) , // bank 0
186+
read_fire_real && act_buffer_0_read_enable && weight_buffer_0_read_enable && (act_bank_sel === 1.U), // bank 1
187+
read_fire_real && act_buffer_1_read_enable && weight_buffer_1_read_enable && (act_bank_sel === 2.U), // bank 2
188+
read_fire_real && act_buffer_1_read_enable && weight_buffer_1_read_enable && (act_bank_sel === 3.U), // bank 3
189+
read_fire_real && weight_buffer_0_read_enable && act_buffer_0_read_enable && (weight_bank_sel === 0.U), // bank 4
190+
read_fire_real && weight_buffer_0_read_enable && act_buffer_0_read_enable && (weight_bank_sel === 1.U), // bank 5
191+
read_fire_real && weight_buffer_1_read_enable && act_buffer_0_read_enable && (weight_bank_sel === 2.U), // bank 6
192+
read_fire_real && weight_buffer_1_read_enable && act_buffer_0_read_enable && (weight_bank_sel === 3.U) // bank 7
195193
))
196194

197195

@@ -267,7 +265,7 @@ class ScalingFactorMem(
267265
val read_fire_real_d1 = RegNext(read_fire_real, false.B)
268266
io.read_resp.bits.combined_scales.foreach(_ := 0.U)
269267
io.read_resp.valid := false.B
270-
io.read_req.ready := io.read_req.bits.scaling_enable && (scale_counter === 0.U)
268+
io.read_req.ready := io.read_req.bits.scaling_enable && ((act_buffer_0_read_enable && weight_buffer_0_read_enable) || (act_buffer_1_read_enable && weight_buffer_1_read_enable))
271269

272270
for(i <- 0 until 2*meshRows*tileRows) {
273271
for(j <- 0 until 2*meshRows*tileRows) {
@@ -282,30 +280,31 @@ class ScalingFactorMem(
282280
}
283281

284282

285-
when(read_fire_real_d1) {
286-
combined_scales_buffer_r := combined_scales_buffer
287-
}
283+
//when(read_fire_real_d1 && (scale_counter === 0.U) ) {
284+
val combined_scales_buffer_r = combined_scales_buffer
285+
288286

289287
when(read_fire_d1) {
290288
for(i <- 0 until 2*meshRows*tileRows) {
291289
act_scales(i) := act_bank_data_vec(i)
292290
weight_scales(i) := weight_bank_data_vec(i)
293-
//printf(p"[ScalingFactorMem] Read act_scales=${act_scales(i) }\n")
294-
//printf(p"[ScalingFactorMem] Read weight_scales=${weight_scales(i) }\n")
291+
// printf(p"[ScalingFactorMem] Read act_scales=${act_scales(i) }\n")
292+
// printf(p"[ScalingFactorMem] Read weight_scales=${weight_scales(i) }\n")
295293
}
296294
//printf(p"[ScalingFactorMem] Read scales from row=${read_addr_reg}\n")
295+
io.read_resp.valid := combined_scales_valid
297296
when (((scale_counter === ((meshRows*tileRows-1).U) && fp8Mode) || (scale_counter === ((2*meshRows*tileRows-1).U) && !fp8Mode))) {
298297
scale_counter := 0.U
299298
}.otherwise{
300-
io.read_resp.valid := combined_scales_valid
301299
scale_counter := scale_counter +& 1.U
302300
when(fp8Mode){
303301
for(j <- 0 until meshRows*tileRows){
304302
when(scale_counter === 0.U){
305303
io.read_resp.bits.combined_scales(j) := combined_scales_buffer(0)(j)
306-
//printf(p"[ScalingFactorMem] Read scale from row=${scale_counter}, and get the scale=${io.read_resp.bits.combined_scales(j)}\n")
304+
// printf(p"[ScalingFactorMem] Read scale from row=${scale_counter}, and get the scale=${io.read_resp.bits.combined_scales(j)}\n")
307305
}.otherwise{
308306
io.read_resp.bits.combined_scales(j) := combined_scales_buffer_r(scale_counter)(j)
307+
// printf(p"[ScalingFactorMem] Read scale from row=${scale_counter}, and get the scale=${io.read_resp.bits.combined_scales(j)}\n")
309308
}
310309
}
311310
}.otherwise{
@@ -314,7 +313,7 @@ class ScalingFactorMem(
314313
io.read_resp.bits.combined_scales(j) := combined_scales_buffer(0)(j)
315314
//printf(p"[ScalingFactorMem] Read scale from row=${scale_counter}, weight_row=${weight_row_counter}, and get the scale=${io.read_resp.bits.combined_scales(j)}\n")
316315
}.otherwise{
317-
io.read_resp.bits.combined_scales(j) := combined_scales_buffer_r(scale_counter + (meshRows*tileRows).U)(j)
316+
io.read_resp.bits.combined_scales(j) := combined_scales_buffer_r(scale_counter)(j)
318317
}
319318
}
320319
}

0 commit comments

Comments
 (0)