@@ -24,6 +24,9 @@ class ScalingFactorMemIO(addrWidth: Int, dataWidth: Int, numRows: Int, numCols:
2424 val counter_i = Input (UInt (16 .W ))
2525 val counter_j = Input (UInt (16 .W ))
2626 val counter_k = Input (UInt (16 .W ))
27+ val i = Input (UInt (16 .W ))
28+ val j = Input (UInt (16 .W ))
29+ val k = Input (UInt (16 .W ))
2730}
2831
2932class ScalingFactorMem (
@@ -52,12 +55,6 @@ class ScalingFactorMem(
5255 tileRows,
5356 ))
5457
55- val counter_a_fire = io.scaleMemCntl.counter_a
56- val counter_b_fire = io.scaleMemCntl.counter_b
57- val fire_a = io.scaleMemCntl.fire_a
58- val fire_b = io.scaleMemCntl.fire_b
59- val write_baseAddr_act = io.scaleMemCntl.baseAddress_act
60- val write_baseAddr_w = io.scaleMemCntl.baseAddress_w
6158 val initByte = 0x7e .U (8 .W )
6259 val defaultRow = VecInit (Seq .fill(bytesPerBank)(initByte))
6360 val counter_i = io.counter_i
@@ -115,7 +112,6 @@ class ScalingFactorMem(
115112 val write_bytes_low = write_weight_full_row(bytesPerBank * 8 - 1 , 0 ).asTypeOf(bankDataT)
116113 val write_bytes_high = write_weight_full_row(bytesPerBank * 2 * 8 - 1 , bytesPerBank * 8 ).asTypeOf(bankDataT)
117114 when(weight_write_buffer_sel === false .B ) {
118- // weight_write_counter := weight_write_counter + 1.U
119115 banks(4 ).write(write_row_addr_w_reg, write_bytes_low)
120116 banks(5 ).write(write_row_addr_w_reg, write_bytes_high)
121117 weight_buffer_0_read_enable := true .B
@@ -183,9 +179,10 @@ class ScalingFactorMem(
183179
184180 val max_block_fp8 = meshRows * tileRows
185181 val max_block_non_fp8 = 2 * meshRows * tileRows
186- val read_row_addr = WireDefault (counter_k >> log2Ceil(max_block_non_fp8))
187- io.scale_mem_write_w.ready := (weight_write_counter === 0 .U || (weight_write_counter(log2Ceil(depth)- 1 ,0 ) =/= read_row_addr)) || (! weight_buffer_0_read_enable) || (! weight_buffer_1_read_enable)
188- io.scale_mem_write_act.ready := (act_write_counter === 0 .U || ((act_write_counter(log2Ceil(depth)- 1 ,0 ) =/= read_row_addr))) || (! act_buffer_0_read_enable) || (! act_buffer_1_read_enable)
182+ val read_row_addr_act = WireDefault (io.i * (counter_k >> log2Ceil(max_block_non_fp8)) + (counter_i >> (log2Ceil(bytesPerBank* numBanks/ 2 ))))
183+ val read_row_addr_w = WireDefault (io.j * (counter_k >> log2Ceil(max_block_non_fp8)) + (counter_j >> (log2Ceil(bytesPerBank* numBanks/ 2 ))))
184+ io.scale_mem_write_w.ready := (weight_write_counter === 0 .U || (weight_write_counter(log2Ceil(depth)- 1 ,0 ) =/= read_row_addr_w)) || (! weight_buffer_0_read_enable) || (! weight_buffer_1_read_enable)
185+ io.scale_mem_write_act.ready := (act_write_counter === 0 .U || ((act_write_counter(log2Ceil(depth)- 1 ,0 ) =/= read_row_addr_act))) || (! act_buffer_0_read_enable) || (! act_buffer_1_read_enable)
189186 val act_read_buffer_select = RegInit (false .B )
190187 val weight_read_buffer_select = RegInit (false .B )
191188 val act_read_counter = RegInit (0 .U (8 .W ))
@@ -194,17 +191,17 @@ class ScalingFactorMem(
194191 when(io.read_req.fire && io.read_req.bits.scaling_enable){
195192 act_read_buffer_select := ~ act_read_buffer_select
196193 weight_read_buffer_select := ~ weight_read_buffer_select
197- when(act_buffer_0_read_enable && ((act_write_counter(log2Ceil(depth)- 1 ,0 ) === read_row_addr ))){
194+ when(act_buffer_0_read_enable && ((act_write_counter(log2Ceil(depth)- 1 ,0 ) === read_row_addr_act ))){
198195 act_buffer_0_read_enable := false .B
199196 }
200- when(act_buffer_1_read_enable && ((act_write_counter(log2Ceil(depth)- 1 ,0 ) === read_row_addr ))){
197+ when(act_buffer_1_read_enable && ((act_write_counter(log2Ceil(depth)- 1 ,0 ) === read_row_addr_act ))){
201198 act_buffer_1_read_enable := false .B
202199 }
203200
204- when(weight_buffer_0_read_enable && ((weight_write_counter(log2Ceil(depth)- 1 ,0 ) === read_row_addr ))){
201+ when(weight_buffer_0_read_enable && ((weight_write_counter(log2Ceil(depth)- 1 ,0 ) === read_row_addr_w ))){
205202 weight_buffer_0_read_enable := false .B
206203 }
207- when(weight_buffer_1_read_enable && ((weight_write_counter(log2Ceil(depth)- 1 ,0 ) === read_row_addr ))){
204+ when(weight_buffer_1_read_enable && ((weight_write_counter(log2Ceil(depth)- 1 ,0 ) === read_row_addr_w ))){
208205 weight_buffer_1_read_enable := false .B
209206 }
210207 }
@@ -234,79 +231,79 @@ class ScalingFactorMem(
234231 weight_bank_sel := 0 .U
235232
236233 val read_fire_banks = VecInit (Seq (
237- read_fire_real && act_buffer_0_read_enable && weight_buffer_0_read_enable && (act_bank_sel === 0 .U ) , // bank 0
238- read_fire_real && act_buffer_0_read_enable && weight_buffer_0_read_enable && (act_bank_sel === 1 .U ), // bank 1
239- read_fire_real && act_buffer_1_read_enable && weight_buffer_1_read_enable && (act_bank_sel === 2 .U ), // bank 2
240- read_fire_real && act_buffer_1_read_enable && weight_buffer_1_read_enable && (act_bank_sel === 3 .U ), // bank 3
241- read_fire_real && weight_buffer_0_read_enable && act_buffer_0_read_enable && (weight_bank_sel === 0 .U ), // bank 4
242- read_fire_real && weight_buffer_0_read_enable && act_buffer_0_read_enable && (weight_bank_sel === 1 .U ), // bank 5
243- read_fire_real && weight_buffer_1_read_enable && weight_buffer_1_read_enable && (weight_bank_sel === 2 .U ), // bank 6
244- read_fire_real && weight_buffer_1_read_enable && weight_buffer_1_read_enable && (weight_bank_sel === 3 .U ) // bank 7
234+ read_fire_real && act_buffer_0_read_enable && (act_bank_sel === 0 .U ) , // bank 0
235+ read_fire_real && act_buffer_0_read_enable && (act_bank_sel === 1 .U ), // bank 1
236+ read_fire_real && act_buffer_1_read_enable && (act_bank_sel === 2 .U ), // bank 2
237+ read_fire_real && act_buffer_1_read_enable && (act_bank_sel === 3 .U ), // bank 3
238+ read_fire_real && act_buffer_0_read_enable && (weight_bank_sel === 0 .U ), // bank 4
239+ read_fire_real && act_buffer_0_read_enable && (weight_bank_sel === 1 .U ), // bank 5
240+ read_fire_real && weight_buffer_1_read_enable && (weight_bank_sel === 2 .U ), // bank 6
241+ read_fire_real && weight_buffer_1_read_enable && (weight_bank_sel === 3 .U ) // bank 7
245242 ))
246243
247244
248- val bank_data = VecInit ((0 until 8 ).map { i => if (testConfig) defaultRow else banks(i).read(read_row_addr , read_fire_banks(i))})
249-
245+ val bank_data_0 = VecInit ((0 until 4 ).map { i => if (testConfig) defaultRow else banks(i).read(read_row_addr_act , read_fire_banks(i))})
246+ val bank_data_1 = VecInit (( 0 until 4 ).map { i => if (testConfig) defaultRow else banks(i + 4 ).read(read_row_addr_w, read_fire_banks(i + 4 ))})
250247 when(fp8Mode){
251248 act_bank_sel := counter_i(1 + log2Ceil(max_block_fp8), log2Ceil(max_block_fp8))
252249 weight_bank_sel := counter_j(1 + log2Ceil(max_block_fp8), log2Ceil(max_block_fp8))
253250 when(act_bank_sel === 0 .U && (act_buffer_0_read_enable)) {
254251 for (i <- 0 until meshRows* tileRows) {
255- act_bank_data_vec(i) := bank_data (0 )(i)
252+ act_bank_data_vec(i) := bank_data_0 (0 )(i)
256253 }
257254 }.elsewhen( act_bank_sel === 1 .U && (act_buffer_0_read_enable)) {
258255 for (i <- 0 until meshRows* tileRows) {
259- act_bank_data_vec(i) := bank_data (1 )(i)
256+ act_bank_data_vec(i) := bank_data_0 (1 )(i)
260257 }
261258 }.elsewhen( act_bank_sel === 2 .U && (act_buffer_1_read_enable)) {
262259 for (i <- 0 until meshRows* tileRows) {
263- act_bank_data_vec(i) := bank_data (2 )(i)
260+ act_bank_data_vec(i) := bank_data_0 (2 )(i)
264261 }
265262 }.elsewhen( act_bank_sel === 3 .U && (act_buffer_1_read_enable)) {
266263 for (i <- 0 until meshRows* tileRows) {
267- act_bank_data_vec(i) := bank_data (3 )(i)
264+ act_bank_data_vec(i) := bank_data_0 (3 )(i)
268265 }
269266 }
270267 when(weight_bank_sel === 0 .U && (weight_buffer_0_read_enable)) {
271268 for (i <- 0 until meshRows* tileRows) {
272- weight_bank_data_vec(i) := bank_data( 4 )(i)
269+ weight_bank_data_vec(i) := bank_data_1( 0 )(i)
273270 }
274271 }.elsewhen( weight_bank_sel === 1 .U && (weight_buffer_0_read_enable)) {
275272 for (i <- 0 until meshRows* tileRows) {
276- weight_bank_data_vec(i) := bank_data( 5 )(i)
273+ weight_bank_data_vec(i) := bank_data_1( 1 )(i)
277274 }
278275 }.elsewhen( weight_bank_sel === 2 .U && (weight_buffer_1_read_enable)) {
279276 for (i <- 0 until meshRows* tileRows) {
280- weight_bank_data_vec(i) := bank_data( 6 )(i)
277+ weight_bank_data_vec(i) := bank_data_1( 2 )(i)
281278 }
282279 }.elsewhen( weight_bank_sel === 3 .U && (weight_buffer_1_read_enable)) {
283280 for (i <- 0 until meshRows* tileRows) {
284- weight_bank_data_vec(i) := bank_data( 7 )(i)
281+ weight_bank_data_vec(i) := bank_data_1( 3 )(i)
285282 }
286283 }
287284 }.otherwise{
288285 act_bank_sel := Cat (0 .U (1 .W ), counter_i(log2Ceil(max_block_non_fp8)))
289286 weight_bank_sel := Cat (0 .U (1 .W ),counter_j(log2Ceil(max_block_non_fp8)))
290- when(act_bank_sel(0 ) === 0 .U && (act_buffer_0_read_enable) && (weight_buffer_0_read_enable) ) {
287+ when(act_bank_sel(0 ) === 0 .U && (act_buffer_0_read_enable)) {
291288 for (i <- 0 until meshRows* tileRows) {
292- act_bank_data_vec(i) := bank_data (0 )(i)
293- act_bank_data_vec(meshRows* tileRows+ i) := bank_data (1 )(i)
289+ act_bank_data_vec(i) := bank_data_0 (0 )(i)
290+ act_bank_data_vec(meshRows* tileRows+ i) := bank_data_0 (1 )(i)
294291 }
295- }.elsewhen(act_bank_sel(0 ) === 1 .U && (act_buffer_0_read_enable) && (weight_buffer_0_read_enable) ) {
292+ }.elsewhen(act_bank_sel(0 ) === 1 .U && (act_buffer_0_read_enable) ) {
296293 for (i <- 0 until meshRows* tileRows) {
297- act_bank_data_vec(i) := bank_data (2 )(i)
298- act_bank_data_vec(meshRows* tileRows+ i) := bank_data (3 )(i)
294+ act_bank_data_vec(i) := bank_data_0 (2 )(i)
295+ act_bank_data_vec(meshRows* tileRows+ i) := bank_data_0 (3 )(i)
299296 }
300297 }
301- when(counter_j( 1 + log2Ceil(max_block_non_fp8), log2Ceil(max_block_non_fp8)) === 1 .U && (weight_buffer_0_read_enable)) {
298+ when(weight_bank_sel === 0 .U && (weight_buffer_0_read_enable)) {
302299 for (i <- 0 until meshRows* tileRows) {
303- weight_bank_data_vec(i) := bank_data( 4 )(i)
304- weight_bank_data_vec(meshRows* tileRows+ i) := bank_data( 5 )(i)
300+ weight_bank_data_vec(i) := bank_data_1( 0 )(i)
301+ weight_bank_data_vec(meshRows* tileRows+ i) := bank_data_1( 1 )(i)
305302 }
306- }.elsewhen( counter_j( 1 + log2Ceil(max_block_non_fp8), log2Ceil(max_block_non_fp8)) === 1 .U && (weight_buffer_1_read_enable)) {
303+ }.elsewhen( weight_bank_sel === 1 .U && (weight_buffer_1_read_enable)) {
307304 for (i <- 0 until meshRows* tileRows) {
308- weight_bank_data_vec(i) := bank_data( 6 )(i)
309- weight_bank_data_vec(meshRows* tileRows+ i) := bank_data( 7 )(i)
305+ weight_bank_data_vec(i) := bank_data_1( 2 )(i)
306+ weight_bank_data_vec(meshRows* tileRows+ i) := bank_data_1( 3 )(i)
310307 }
311308 }
312309 }
0 commit comments