@@ -66,8 +66,8 @@ class ScalingFactorMem(
6666
// Square table of 9-bit entries, (2*meshRows*tileRows) rows by (2*meshRows*tileRows)
// columns, created as a wire whose default value is all zeros.
6767 val combined_scales_buffer = WireInit (VecInit (Seq .fill(2 * meshRows* tileRows)(
6868 VecInit (Seq .fill(2 * meshRows* tileRows)(0 .U (9 .W ))))))
69- val combined_scales_buffer_r = RegInit (VecInit (Seq .fill(2 * meshRows* tileRows)(
70- VecInit (Seq .fill(2 * meshRows* tileRows)(0 .U (9 .W ))))))
// The registered copy of the table is commented out in this revision.
69+ // val combined_scales_buffer_r = RegInit(VecInit(Seq.fill(2*meshRows*tileRows)(
70+ // VecInit(Seq.fill(2*meshRows*tileRows)(0.U(9.W))))))
// Validity flag for the combined table; defaults to false unless overridden later.
7171 val combined_scales_valid = WireDefault (false .B )
// Each bank row is a vector of bytesPerBank bytes; numBanks synchronous-read memories of `depth` rows.
7272 val bankDataT = Vec (bytesPerBank, UInt (8 .W ))
7373 val banks = Seq .fill(numBanks)(SyncReadMem (depth, bankDataT))
@@ -78,6 +78,7 @@ class ScalingFactorMem(
// Flag register for weight buffer 1, reset to false.
7878 val weight_buffer_1_read_enable = RegInit (false .B )
// 8-bit counter of weight writes, reset to 0.
7979 val weight_write_counter = RegInit (0 .U (8 .W ))
// Row address for a weight write: the request address plus write_baseAddr_w
// shifted right by log2Ceil(2*meshRows*tileRows).
8080 val write_row_addr_w = io.scale_mem_write_w.bits.addr + (write_baseAddr_w >> (log2Ceil(2 * meshRows* tileRows)))
81+
8182 val weight_buffer_write_full = RegInit (false .B )
8283 when(io.scale_mem_write_w.fire) {
8384 val write_bytes_low = io.scale_mem_write_w.bits.data(bytesPerBank * 8 - 1 , 0 ).asTypeOf(bankDataT)
@@ -93,9 +94,10 @@ class ScalingFactorMem(
9394 banks(7 ).write(write_row_addr_w, write_bytes_high)
9495 weight_buffer_1_read_enable := true .B
9596 }
97+ weight_write_buffer_sel := ~ weight_write_buffer_sel
9698 when((weight_write_counter === (depth - 1 ).U )){
9799 weight_write_counter := 0 .U
98- weight_write_buffer_sel := ~ weight_write_buffer_sel
100+
99101 // when(weight_write_buffer_sel === false.B){
100102 // weight_buffer_0_read_enable := true.B
101103 // }.otherwise{
@@ -115,47 +117,42 @@ class ScalingFactorMem(
115117 when(act_write_buffer_sel === false .B ) { // ✓
116118 banks(0 ).write(write_row_addr_act, write_bytes_low)
117119 banks(1 ).write(write_row_addr_act, write_bytes_high)
118- act_write_counter := act_write_counter + 1 .U
119120 act_buffer_0_read_enable := true .B
120121 }.otherwise{
121122 act_write_counter := act_write_counter + 1 .U
122123 banks(2 ).write(write_row_addr_act, write_bytes_low)
123124 banks(3 ).write(write_row_addr_act, write_bytes_high)
124125 act_buffer_1_read_enable := true .B
125126 }
127+ act_write_buffer_sel := ~ act_write_buffer_sel
126128 when((act_write_counter === (depth - 1 ).U )){
127129 act_write_counter := 0 .U
128- act_write_buffer_sel := ~ act_write_buffer_sel
129- // when(act_write_buffer_sel === false.B){
130- // act_buffer_0_read_enable := true.B
131- // }.otherwise{
132- // act_buffer_1_read_enable := true.B
133- // }
134130 }
135131 }
136-
137-
138- io.scale_mem_write_w.ready := (! weight_buffer_0_read_enable) || (! weight_buffer_1_read_enable)
139- io.scale_mem_write_act.ready := (! act_buffer_0_read_enable) || (! act_buffer_1_read_enable)
// Elaboration-time block sizes: meshRows*tileRows and twice that.
132+ val max_block_fp8 = meshRows * tileRows
133+ val max_block_non_fp8 = 2 * meshRows * tileRows
// Row currently addressed by the reader: counter_k shifted right by log2Ceil(max_block_non_fp8).
134+ val read_row_addr = WireDefault (counter_k >> log2Ceil(max_block_non_fp8))
// Weight writes are accepted when at least one weight buffer flag is clear, or when the
// weight write counter is 0 or strictly greater than read_row_addr.
135+ io.scale_mem_write_w.ready := ((! weight_buffer_0_read_enable) || (! weight_buffer_1_read_enable)) || (weight_write_counter === 0 .U || (weight_write_counter > read_row_addr))
// Activation writes are accepted when at least one activation buffer flag is clear, or when the
// activation write counter is 0 or equal to read_row_addr.
// NOTE(review): this side compares with === while the weight side uses > — confirm the asymmetry is intended.
136+ io.scale_mem_write_act.ready := (! act_buffer_0_read_enable) || (! act_buffer_1_read_enable) || (act_write_counter === 0 .U || ((act_write_counter === read_row_addr)))
// Read-side state: one select bit per operand (reset to false) and two 8-bit counters (reset to 0).
140137 val act_read_buffer_select = RegInit (false .B )
141138 val weight_read_buffer_select = RegInit (false .B )
142139 val act_read_counter = RegInit (0 .U (8 .W ))
143140 val weight_read_counter = RegInit (0 .U (8 .W ))
144141
145- when(counter_k(log2Ceil(depth) - 1 , 0 ) === (depth - 1 ). U && io.read_req.fire && io.read_req.bits.scaling_enable){
// On every accepted read request with scaling enabled, toggle both read selects and
// clear buffer flags whose write counter equals read_row_addr.
// (The previous revision only did this when the low bits of counter_k reached depth-1.)
142+ when(io.read_req.fire && io.read_req.bits.scaling_enable){
146143 act_read_buffer_select := ~ act_read_buffer_select
147144 weight_read_buffer_select := ~ weight_read_buffer_select
148- when(act_buffer_0_read_enable && (act_read_buffer_select === false . B )){
// NOTE(review): the buffer-0 and buffer-1 conditions below are identical apart from the flag
// itself, so both flags clear together when both are set — the read select no longer picks
// which buffer is released. Confirm this is intended.
145+ when(act_buffer_0_read_enable && ((act_write_counter === read_row_addr) )){
149146 act_buffer_0_read_enable := false .B
150147 }
151- when(act_buffer_1_read_enable && (act_read_buffer_select === true . B )){
148+ when(act_buffer_1_read_enable && ((act_write_counter === read_row_addr) )){
152149 act_buffer_1_read_enable := false .B
153150 }
154151
155- when(weight_buffer_0_read_enable && (weight_read_buffer_select === false . B )){
// Same pattern for the weight buffer flags, keyed on weight_write_counter.
152+ when(weight_buffer_0_read_enable && ((weight_write_counter === read_row_addr) )){
156153 weight_buffer_0_read_enable := false .B
157154 }
158- when(weight_buffer_1_read_enable && (weight_read_buffer_select === true . B )){
155+ when(weight_buffer_1_read_enable && ((weight_write_counter === read_row_addr) )){
159156 weight_buffer_1_read_enable := false .B
160157 }
161158 }
@@ -169,11 +166,12 @@ class ScalingFactorMem(
169166 sum(8 , 0 )
170167 }
171168
172- val read_fire = io.read_req.fire && io.read_req.bits.scaling_enable
173- val read_fire_real = io.read_req.fire && io.read_req.bits.scaling_enable && (scale_counter === 0 .U )
174- val read_row_addr = counter_k
175- val max_block_fp8 = meshRows * tileRows
176- val max_block_non_fp8 = 2 * meshRows * tileRows
// read_fire: an accepted, scaling-enabled read request while the activation and weight flags
// of the same buffer index (0 or 1) are both set.
169+ val read_fire = io.read_req.fire && io.read_req.bits.scaling_enable && ((act_buffer_0_read_enable && weight_buffer_0_read_enable) || (act_buffer_1_read_enable && weight_buffer_1_read_enable) )
// read_fire_real: read_fire restricted to cycles where scale_counter is 0.
170+ val read_fire_real = read_fire && (scale_counter === 0 .U )
171+
172+
173+
174+
177175
// Byte vectors of length meshRows*tileRows*2 for activation and weight bank data, defaulting to zero.
178176 val act_bank_data_vec = WireInit (VecInit (Seq .fill(meshRows* tileRows* 2 )(0 .U (8 .W ))))
179177 val weight_bank_data_vec = WireInit (VecInit (Seq .fill(meshRows* tileRows* 2 )(0 .U (8 .W ))))
@@ -184,14 +182,14 @@ class ScalingFactorMem(
184182 weight_bank_sel := 0 .U
185183
// Per-bank read strobes, one entry per bank (0-7).
// Entries 0-3 are selected by act_bank_sel, entries 4-7 by weight_bank_sel.
// A bank is strobed only when read_fire_real is asserted and BOTH the activation
// and weight flags of the same buffer index are set, so the two operands of a
// pair are always read together.
// Fix: entries 6 and 7 (weight buffer 1) were gated with act_buffer_0_read_enable;
// they now use act_buffer_1_read_enable, matching entries 2 and 3.
val read_fire_banks = {
  // Buffer pair 0 / pair 1 are readable when both operands of that pair are flagged.
  val pair0_ready = act_buffer_0_read_enable && weight_buffer_0_read_enable
  val pair1_ready = act_buffer_1_read_enable && weight_buffer_1_read_enable
  VecInit(Seq(
    read_fire_real && pair0_ready && (act_bank_sel === 0.U),    // bank 0
    read_fire_real && pair0_ready && (act_bank_sel === 1.U),    // bank 1
    read_fire_real && pair1_ready && (act_bank_sel === 2.U),    // bank 2
    read_fire_real && pair1_ready && (act_bank_sel === 3.U),    // bank 3
    read_fire_real && pair0_ready && (weight_bank_sel === 0.U), // bank 4
    read_fire_real && pair0_ready && (weight_bank_sel === 1.U), // bank 5
    read_fire_real && pair1_ready && (weight_bank_sel === 2.U), // bank 6
    read_fire_real && pair1_ready && (weight_bank_sel === 3.U)  // bank 7
  ))
}
196194
197195
@@ -267,7 +265,7 @@ class ScalingFactorMem(
// read_fire_real delayed by one cycle (reset value false).
267265 val read_fire_real_d1 = RegNext (read_fire_real, false .B )
// Response defaults: all combined scales zero and valid low unless overridden later.
268266 io.read_resp.bits.combined_scales.foreach(_ := 0 .U )
269267 io.read_resp.valid := false .B
270- io.read_req.ready := io.read_req.bits.scaling_enable && (scale_counter === 0 . U )
// A read request is accepted when scaling is enabled and the activation and weight flags
// of the same buffer index (0 or 1) are both set.
// NOTE(review): the previous revision also required scale_counter === 0.U here; without it,
// requests can be accepted while scale_counter is non-zero — confirm this is intended.
268+ io.read_req.ready := io.read_req.bits.scaling_enable && ((act_buffer_0_read_enable && weight_buffer_0_read_enable) || (act_buffer_1_read_enable && weight_buffer_1_read_enable) )
271269
272270 for (i <- 0 until 2 * meshRows* tileRows) {
273271 for (j <- 0 until 2 * meshRows* tileRows) {
@@ -282,30 +280,31 @@ class ScalingFactorMem(
282280 }
283281
284282
285- when(read_fire_real_d1) {
286- combined_scales_buffer_r : = combined_scales_buffer
287- }
283+ // when(read_fire_real_d1 && (scale_counter === 0.U) ) {
// combined_scales_buffer_r is now a plain Scala alias of the combined_scales_buffer wire.
// The previous revision captured it into a register when read_fire_real_d1 was asserted.
// NOTE(review): readers of combined_scales_buffer_r now see the current wire value rather
// than a held copy — confirm the values stay stable while scale_counter advances.
284+ val combined_scales_buffer_r = combined_scales_buffer
285+
288286
289287 when(read_fire_d1) {
290288 for (i <- 0 until 2 * meshRows* tileRows) {
291289 act_scales(i) := act_bank_data_vec(i)
292290 weight_scales(i) := weight_bank_data_vec(i)
293- // printf(p"[ScalingFactorMem] Read act_scales=${act_scales(i) }\n")
294- // printf(p"[ScalingFactorMem] Read weight_scales=${weight_scales(i) }\n")
291+ // printf(p"[ScalingFactorMem] Read act_scales=${act_scales(i) }\n")
292+ // printf(p"[ScalingFactorMem] Read weight_scales=${weight_scales(i) }\n")
295293 }
296294 // printf(p"[ScalingFactorMem] Read scales from row=${read_addr_reg}\n")
295+ io.read_resp.valid := combined_scales_valid
297296 when (((scale_counter === ((meshRows* tileRows- 1 ).U ) && fp8Mode) || (scale_counter === ((2 * meshRows* tileRows- 1 ).U ) && ! fp8Mode))) {
298297 scale_counter := 0 .U
299298 }.otherwise{
300- io.read_resp.valid := combined_scales_valid
301299 scale_counter := scale_counter +& 1 .U
302300 when(fp8Mode){
303301 for (j <- 0 until meshRows* tileRows){
304302 when(scale_counter === 0 .U ){
305303 io.read_resp.bits.combined_scales(j) := combined_scales_buffer(0 )(j)
306- // printf(p"[ScalingFactorMem] Read scale from row=${scale_counter}, and get the scale=${io.read_resp.bits.combined_scales(j)}\n")
304+ // printf(p"[ScalingFactorMem] Read scale from row=${scale_counter}, and get the scale=${io.read_resp.bits.combined_scales(j)}\n")
307305 }.otherwise{
308306 io.read_resp.bits.combined_scales(j) := combined_scales_buffer_r(scale_counter)(j)
307+ // printf(p"[ScalingFactorMem] Read scale from row=${scale_counter}, and get the scale=${io.read_resp.bits.combined_scales(j)}\n")
309308 }
310309 }
311310 }.otherwise{
@@ -314,7 +313,7 @@ class ScalingFactorMem(
314313 io.read_resp.bits.combined_scales(j) := combined_scales_buffer(0 )(j)
315314 // printf(p"[ScalingFactorMem] Read scale from row=${scale_counter}, weight_row=${weight_row_counter}, and get the scale=${io.read_resp.bits.combined_scales(j)}\n")
316315 }.otherwise{
317- io.read_resp.bits.combined_scales(j) := combined_scales_buffer_r(scale_counter + (meshRows * tileRows). U )(j)
316+ io.read_resp.bits.combined_scales(j) := combined_scales_buffer_r(scale_counter)(j)
318317 }
319318 }
320319 }
0 commit comments