Skip to content

Commit af9b8b4

Browse files
Amanda ShiAmanda Shi
authored andcommitted
adapt the reading address for scaleme to support any tile size
1 parent 350d8d6 commit af9b8b4

File tree

7 files changed

+98
-71
lines changed

7 files changed

+98
-71
lines changed

src/main/scala/gemmini/AccumulatorMem.scala

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -63,7 +63,9 @@ class AccumulatorMemIO [T <: Data: Arithmetic, U <: Data](n: Int, t: Vec[Vec[T]]
6363
val counter_i = Input(UInt(16.W)) //for scaling factor memory control
6464
val counter_j = Input(UInt(16.W)) //for scaling factor memory control
6565
val counter_k = Input(UInt(16.W)) //for scaling factor memory control
66-
66+
val i = Input(UInt(16.W)) //for scaling factor memory control
67+
val j = Input(UInt(16.W)) //for scaling factor memory control
68+
val k = Input(UInt(16.W)) //for scaling factor memory control
6769
val dataType = Input(UInt(2.W)) //this is the input mxformat datatype
6870
val scale_mem_write_act = if (use_mx_scaling) {
6971
Some(Flipped(Decoupled(new ScalingFactorWriteReq(13, 64))))
@@ -254,7 +256,10 @@ class AccumulatorMem[T <: Data, U <: Data](
254256
scale_mem.io.counter_i := io.counter_i
255257
scale_mem.io.counter_j := io.counter_j
256258
scale_mem.io.counter_k := io.counter_k
257-
259+
scale_mem.io.i := io.i
260+
scale_mem.io.j := io.j
261+
scale_mem.io.k := io.k
262+
258263
scale_mem.io.scaleMemCntl <> io.scaleMemCntl.get
259264
scale_mem.io.read_req.valid := false.B
260265
scale_mem.io.read_req.bits.addr := DontCare

src/main/scala/gemmini/Controller.scala

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -544,7 +544,9 @@ class GemminiModule[T <: Data: Arithmetic, U <: Data, V <: Data]
544544
spad.module.io.counter_i := loop_matmul.io.counter_i
545545
spad.module.io.counter_j := loop_matmul.io.counter_j
546546
spad.module.io.counter_k := loop_matmul.io.counter_k
547-
547+
spad.module.io.i := loop_matmul.io.i
548+
spad.module.io.j := loop_matmul.io.j
549+
spad.module.io.k := loop_matmul.io.k
548550
val unrolled_cmd = Queue(loop_cmd)
549551
unrolled_cmd.ready := false.B
550552
counters.io.event_io.connectEventSignal(CounterEvent.LOOP_MATMUL_ACTIVE_CYCLES, loop_matmul_unroller_busy)

src/main/scala/gemmini/LoopMatmul.scala

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -909,6 +909,9 @@ class LoopMatmul(block_size: Int, coreMaxAddrBits: Int, reservation_station_size
909909
val counter_i = Output(UInt(16.W))
910910
val counter_j = Output(UInt(16.W))
911911
val counter_k = Output(UInt(16.W))
912+
val i = Output(UInt(16.W))
913+
val j = Output(UInt(16.W))
914+
val k = Output(UInt(16.W))
912915
})
913916

914917
// Create states
@@ -940,7 +943,9 @@ class LoopMatmul(block_size: Int, coreMaxAddrBits: Int, reservation_station_size
940943
io.counter_i := ex.io.i
941944
io.counter_j := ex.io.j
942945
io.counter_k := ex.io.k
943-
946+
io.i := loop_being_configured.max_i
947+
io.j := loop_being_configured.max_j
948+
io.k := loop_being_configured.max_k
944949
io.busy := cmd.valid || loop_configured
945950

946951
io.completed := 0.U.asTypeOf(io.completed.cloneType)

src/main/scala/gemmini/MxConfigFragments.scala

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -31,14 +31,14 @@ case class GemminiRequantizerConfig(
3131
minOutputBits: Int = 4,
3232
maxOutputBits: Int = 8,
3333
outputIdBits: Int = 3,
34-
lutUpdateRegularityW : Int = 128, // means how many elements update once the lut
34+
lutUpdateRegularityW : Int = 128, // means how many elements updatScalingFactorCntle once the lut
3535
lutUpdateRegularityActIn : Int = 128,
3636
lutUpdateRegularityActOut : Int = 128,
3737
)
3838

3939
case class GemminiLUTConfig(
4040
numBits: Int = 96,
41-
numEntries: Int = 1,
41+
numEntries: Int = 32,
4242
numTables: Int = 3,
4343
rdataWidth: Int = 6,
4444
raddrWidth: Int = 4,
@@ -68,8 +68,8 @@ class ScalingFactorCntl(max_block: Int) extends Bundle {
6868
val counter_b = UInt(log2Up(max_block).W)
6969
val fire_a = Bool()
7070
val fire_b = Bool()
71-
val baseAddress_act = UInt(33.W)
72-
val baseAddress_w = UInt(33.W)
71+
val baseAddress_act = UInt(32.W)
72+
val baseAddress_w = UInt(32.W)
7373
}
7474

7575
class RequantizerInBundle(numLanes: Int, dataWidth: Int = 16) extends Bundle {

src/main/scala/gemmini/QuantLut.scala

Lines changed: 31 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -58,25 +58,28 @@ class QuantLut(
5858
val lutCache_act_in_buffer_1_read_enable = RegInit(false.B)
5959
val lutCache_act_in_buffer_select = RegInit(false.B)
6060
val counter_i_reg = RegNext(io.counter_i)
61+
val counter_w = RegInit(0.U(5.W))
62+
val counter_act = RegInit(0.U(5.W))
63+
6164
when(io.lut_write_act_in.fire){
6265
when(lutCache_act_in_flag === false.B){
6366
for (lane <- 0 until 32) {
6467
for (entry <- 0 until 16) {
65-
//lutCache_act_in_0(lane)(entry) := io.lut_write_act_in.bits.data(lane)((entry+1)*rdataWidth-1, entry*rdataWidth)
66-
lutCache_act_in_0(lane)(entry) := io.lut_write_act_in.bits.data(0)((entry+1)*rdataWidth-1, entry*rdataWidth)
67-
}
68+
lutCache_act_in_0(lane)(entry) := io.lut_write_act_in.bits.data(lane)((entry+1)*rdataWidth-1, entry*rdataWidth)
69+
//lutCache_act_in_0(lane)(entry) := io.lut_write_act_in.bits.data(0)((entry+1)*rdataWidth-1, entry*rdataWidth)
70+
}
71+
}
6872
lutCache_act_in_flag := ~lutCache_act_in_flag
6973
lutCache_act_in_buffer_0_read_enable := true.B
70-
}
7174
}.otherwise {
7275
for (lane <- 0 until 32) {
7376
for (entry <- 0 until 16) {
74-
//lutCache_act_in_1(lane)(entry) := io.lut_write_act_in.bits.data(lane)((entry+1)*rdataWidth-1, entry*rdataWidth)
75-
lutCache_act_in_1(lane)(entry) := io.lut_write_act_in.bits.data(0)((entry+1)*rdataWidth-1, entry*rdataWidth)
77+
lutCache_act_in_1(lane)(entry) := io.lut_write_act_in.bits.data(lane)((entry+1)*rdataWidth-1, entry*rdataWidth)
78+
//lutCache_act_in_1(lane)(entry) := io.lut_write_act_in.bits.data(0)((entry+1)*rdataWidth-1, entry*rdataWidth)
7679
}
80+
}
7781
lutCache_act_in_flag := ~lutCache_act_in_flag
7882
lutCache_act_in_buffer_1_read_enable := true.B
79-
}
8083
}
8184
}
8285

@@ -111,17 +114,17 @@ class QuantLut(
111114
when(lutCache_weight_flag === false.B){
112115
for (lane <- 0 until 32) {
113116
for (entry <- 0 until 16) {
114-
//lutCache_weight_0(lane)(entry) := io.lut_write_weight.bits.data(lane)((entry+1)*rdataWidth-1, entry*rdataWidth)
115-
lutCache_weight_0(lane)(entry) := io.lut_write_weight.bits.data(0)((entry+1)*rdataWidth-1, entry*rdataWidth)
117+
lutCache_weight_0(lane)(entry) := io.lut_write_weight.bits.data(lane)((entry+1)*rdataWidth-1, entry*rdataWidth)
118+
//lutCache_weight_0(lane)(entry) := io.lut_write_weight.bits.data(0)((entry+1)*rdataWidth-1, entry*rdataWidth)
116119
}
117120
}
118121
lutCache_weight_flag := ~lutCache_weight_flag
119122
lutCache_weight_buffer_0_read_enable := true.B
120123
}.otherwise {
121124
for (lane <- 0 until 32) {
122125
for (entry <- 0 until 16) {
123-
//lutCache_weight_1(lane)(entry) := io.lut_write_weight.bits.data(lane)((entry+1)*rdataWidth-1, entry*rdataWidth)
124-
lutCache_weight_1(lane)(entry) := io.lut_write_weight.bits.data(0)((entry+1)*rdataWidth-1, entry*rdataWidth)
126+
lutCache_weight_1(lane)(entry) := io.lut_write_weight.bits.data(lane)((entry+1)*rdataWidth-1, entry*rdataWidth)
127+
//lutCache_weight_1(lane)(entry) := io.lut_write_weight.bits.data(0)((entry+1)*rdataWidth-1, entry*rdataWidth)
125128
}
126129
}
127130
lutCache_weight_flag := ~lutCache_weight_flag
@@ -159,17 +162,17 @@ class QuantLut(
159162
when(lutCache_act_out_flag === false.B){
160163
for (lane <- 0 until 32) {
161164
for (entry <- 0 until 16) {
162-
//lutCache_act_out_0(lane)(entry) := io.lut_write_act_out.bits.data(lane)((entry+1)*rdataWidth-1, entry*rdataWidth)
163-
lutCache_act_out_0(lane)(entry) := io.lut_write_act_out.bits.data(0)((entry+1)*rdataWidth-1, entry*rdataWidth)
165+
lutCache_act_out_0(lane)(entry) := io.lut_write_act_out.bits.data(lane)((entry+1)*rdataWidth-1, entry*rdataWidth)
166+
//lutCache_act_out_0(lane)(entry) := io.lut_write_act_out.bits.data(0)((entry+1)*rdataWidth-1, entry*rdataWidth)
164167
}
165168
}
166169
lutCache_act_out_flag := ~lutCache_act_out_flag
167170
lutCache_act_out_buffer_0_read_enable := true.B
168171
}.otherwise {
169172
for (lane <- 0 until 32) {
170173
for (entry <- 0 until 16) {
171-
//lutCache_act_out_1(lane)(entry) := io.lut_write_act_out.bits.data(lane)((entry+1)*rdataWidth-1, entry*rdataWidth)
172-
lutCache_act_out_1(lane)(entry) := io.lut_write_act_out.bits.data(0)((entry+1)*rdataWidth-1, entry*rdataWidth)
174+
lutCache_act_out_1(lane)(entry) := io.lut_write_act_out.bits.data(lane)((entry+1)*rdataWidth-1, entry*rdataWidth)
175+
//lutCache_act_out_1(lane)(entry) := io.lut_write_act_out.bits.data(0)((entry+1)*rdataWidth-1, entry*rdataWidth)
173176
}
174177
}
175178
lutCache_act_out_flag := ~lutCache_act_out_flag
@@ -212,7 +215,7 @@ class QuantLut(
212215
})
213216

214217
val minIdx = distances.zipWithIndex.map { case (dist, idx) =>
215-
(dist, idx.U(5.W))
218+
(dist, idx.U(raddrWidth.W))
216219
}.reduce { (a, b) =>
217220
val selectA = a._1 <= b._1
218221
(Mux(selectA, a._1, b._1), Mux(selectA, a._2, b._2))
@@ -243,12 +246,22 @@ class QuantLut(
243246
when(io.a_fire && (lutCache_act_in_buffer_0_read_enable || lutCache_act_in_buffer_1_read_enable)) {
244247
for (k <- 0 until 32) {
245248
val chunk_4bit = io.spad_projected_data(i).resp.bits.data((k+1)*4-1, k*4)
246-
deprojected_bits(k) := lutCache_act_in(k)(chunk_4bit)
249+
deprojected_bits(k) := lutCache_act_in(counter_act)(chunk_4bit)
250+
}
251+
when(counter_act === 31.U){
252+
counter_act := 0.U
253+
}.otherwise{
254+
counter_act := counter_act + 1.U
247255
}
248256
}.elsewhen(io.b_fire && (lutCache_weight_buffer_0_read_enable || lutCache_weight_buffer_1_read_enable)) {
249257
for (k <- 0 until 32) {
250258
val chunk_4bit = io.spad_projected_data(i).resp.bits.data((k+1)*4-1, k*4)
251-
deprojected_bits(k) := lutCache_weight(k)(chunk_4bit)
259+
deprojected_bits(k) := lutCache_weight(counter_w)(chunk_4bit)
260+
}
261+
when(counter_w === 31.U){
262+
counter_w := 0.U
263+
}.otherwise{
264+
counter_w := counter_w + 1.U
252265
}
253266
}
254267

src/main/scala/gemmini/ScaleFactorMem.scala

Lines changed: 41 additions & 44 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,9 @@ class ScalingFactorMemIO(addrWidth: Int, dataWidth: Int, numRows: Int, numCols:
2424
val counter_i = Input(UInt(16.W))
2525
val counter_j = Input(UInt(16.W))
2626
val counter_k = Input(UInt(16.W))
27+
val i = Input(UInt(16.W))
28+
val j = Input(UInt(16.W))
29+
val k = Input(UInt(16.W))
2730
}
2831

2932
class ScalingFactorMem(
@@ -52,12 +55,6 @@ class ScalingFactorMem(
5255
tileRows,
5356
))
5457

55-
val counter_a_fire = io.scaleMemCntl.counter_a
56-
val counter_b_fire = io.scaleMemCntl.counter_b
57-
val fire_a = io.scaleMemCntl.fire_a
58-
val fire_b = io.scaleMemCntl.fire_b
59-
val write_baseAddr_act = io.scaleMemCntl.baseAddress_act
60-
val write_baseAddr_w = io.scaleMemCntl.baseAddress_w
6158
val initByte = 0x7e.U(8.W)
6259
val defaultRow = VecInit(Seq.fill(bytesPerBank)(initByte))
6360
val counter_i = io.counter_i
@@ -115,7 +112,6 @@ class ScalingFactorMem(
115112
val write_bytes_low = write_weight_full_row(bytesPerBank * 8 - 1, 0).asTypeOf(bankDataT)
116113
val write_bytes_high = write_weight_full_row(bytesPerBank * 2 * 8 - 1, bytesPerBank * 8).asTypeOf(bankDataT)
117114
when(weight_write_buffer_sel === false.B) {
118-
//weight_write_counter := weight_write_counter + 1.U
119115
banks(4).write(write_row_addr_w_reg, write_bytes_low)
120116
banks(5).write(write_row_addr_w_reg, write_bytes_high)
121117
weight_buffer_0_read_enable := true.B
@@ -183,9 +179,10 @@ class ScalingFactorMem(
183179

184180
val max_block_fp8 = meshRows * tileRows
185181
val max_block_non_fp8 = 2*meshRows * tileRows
186-
val read_row_addr = WireDefault(counter_k >> log2Ceil(max_block_non_fp8))
187-
io.scale_mem_write_w.ready := (weight_write_counter ===0.U || (weight_write_counter(log2Ceil(depth)-1,0) =/= read_row_addr)) || (!weight_buffer_0_read_enable) || (!weight_buffer_1_read_enable)
188-
io.scale_mem_write_act.ready := (act_write_counter ===0.U || ((act_write_counter(log2Ceil(depth)-1,0) =/= read_row_addr))) || (!act_buffer_0_read_enable) || (!act_buffer_1_read_enable)
182+
val read_row_addr_act = WireDefault(io.i * (counter_k >> log2Ceil(max_block_non_fp8)) + (counter_i >> (log2Ceil(bytesPerBank*numBanks/2))))
183+
val read_row_addr_w = WireDefault(io.j * (counter_k >> log2Ceil(max_block_non_fp8)) + (counter_j >> (log2Ceil(bytesPerBank*numBanks/2))))
184+
io.scale_mem_write_w.ready := (weight_write_counter ===0.U || (weight_write_counter(log2Ceil(depth)-1,0) =/= read_row_addr_w)) || (!weight_buffer_0_read_enable) || (!weight_buffer_1_read_enable)
185+
io.scale_mem_write_act.ready := (act_write_counter ===0.U || ((act_write_counter(log2Ceil(depth)-1,0) =/= read_row_addr_act))) || (!act_buffer_0_read_enable) || (!act_buffer_1_read_enable)
189186
val act_read_buffer_select = RegInit(false.B)
190187
val weight_read_buffer_select = RegInit(false.B)
191188
val act_read_counter = RegInit(0.U(8.W))
@@ -194,17 +191,17 @@ class ScalingFactorMem(
194191
when(io.read_req.fire && io.read_req.bits.scaling_enable){
195192
act_read_buffer_select := ~act_read_buffer_select
196193
weight_read_buffer_select := ~weight_read_buffer_select
197-
when(act_buffer_0_read_enable && ((act_write_counter(log2Ceil(depth)-1,0) === read_row_addr))){
194+
when(act_buffer_0_read_enable && ((act_write_counter(log2Ceil(depth)-1,0) === read_row_addr_act))){
198195
act_buffer_0_read_enable := false.B
199196
}
200-
when(act_buffer_1_read_enable && ((act_write_counter(log2Ceil(depth)-1,0) === read_row_addr))){
197+
when(act_buffer_1_read_enable && ((act_write_counter(log2Ceil(depth)-1,0) === read_row_addr_act))){
201198
act_buffer_1_read_enable := false.B
202199
}
203200

204-
when(weight_buffer_0_read_enable && ((weight_write_counter(log2Ceil(depth)-1,0) === read_row_addr))){
201+
when(weight_buffer_0_read_enable && ((weight_write_counter(log2Ceil(depth)-1,0) === read_row_addr_w))){
205202
weight_buffer_0_read_enable := false.B
206203
}
207-
when(weight_buffer_1_read_enable && ((weight_write_counter(log2Ceil(depth)-1,0) === read_row_addr))){
204+
when(weight_buffer_1_read_enable && ((weight_write_counter(log2Ceil(depth)-1,0) === read_row_addr_w))){
208205
weight_buffer_1_read_enable := false.B
209206
}
210207
}
@@ -234,79 +231,79 @@ class ScalingFactorMem(
234231
weight_bank_sel := 0.U
235232

236233
val read_fire_banks = VecInit(Seq(
237-
read_fire_real && act_buffer_0_read_enable && weight_buffer_0_read_enable && (act_bank_sel === 0.U) , // bank 0
238-
read_fire_real && act_buffer_0_read_enable && weight_buffer_0_read_enable && (act_bank_sel === 1.U), // bank 1
239-
read_fire_real && act_buffer_1_read_enable && weight_buffer_1_read_enable && (act_bank_sel === 2.U), // bank 2
240-
read_fire_real && act_buffer_1_read_enable && weight_buffer_1_read_enable && (act_bank_sel === 3.U), // bank 3
241-
read_fire_real && weight_buffer_0_read_enable && act_buffer_0_read_enable && (weight_bank_sel === 0.U), // bank 4
242-
read_fire_real && weight_buffer_0_read_enable && act_buffer_0_read_enable && (weight_bank_sel === 1.U), // bank 5
243-
read_fire_real && weight_buffer_1_read_enable && weight_buffer_1_read_enable && (weight_bank_sel === 2.U), // bank 6
244-
read_fire_real && weight_buffer_1_read_enable && weight_buffer_1_read_enable && (weight_bank_sel === 3.U) // bank 7
234+
read_fire_real && act_buffer_0_read_enable && (act_bank_sel === 0.U) , // bank 0
235+
read_fire_real && act_buffer_0_read_enable && (act_bank_sel === 1.U), // bank 1
236+
read_fire_real && act_buffer_1_read_enable && (act_bank_sel === 2.U), // bank 2
237+
read_fire_real && act_buffer_1_read_enable && (act_bank_sel === 3.U), // bank 3
238+
read_fire_real && act_buffer_0_read_enable && (weight_bank_sel === 0.U), // bank 4
239+
read_fire_real && act_buffer_0_read_enable && (weight_bank_sel === 1.U), // bank 5
240+
read_fire_real && weight_buffer_1_read_enable && (weight_bank_sel === 2.U), // bank 6
241+
read_fire_real && weight_buffer_1_read_enable && (weight_bank_sel === 3.U) // bank 7
245242
))
246243

247244

248-
val bank_data = VecInit((0 until 8).map { i => if (testConfig) defaultRow else banks(i).read(read_row_addr, read_fire_banks(i))})
249-
245+
val bank_data_0 = VecInit((0 until 4).map { i => if (testConfig) defaultRow else banks(i).read(read_row_addr_act, read_fire_banks(i))})
246+
val bank_data_1 = VecInit((0 until 4).map { i => if (testConfig) defaultRow else banks(i+4).read(read_row_addr_w, read_fire_banks(i+4))})
250247
when(fp8Mode){
251248
act_bank_sel := counter_i(1+log2Ceil(max_block_fp8), log2Ceil(max_block_fp8))
252249
weight_bank_sel := counter_j(1+log2Ceil(max_block_fp8), log2Ceil(max_block_fp8))
253250
when(act_bank_sel === 0.U && (act_buffer_0_read_enable)) {
254251
for (i <- 0 until meshRows*tileRows) {
255-
act_bank_data_vec(i) := bank_data(0)(i)
252+
act_bank_data_vec(i) := bank_data_0(0)(i)
256253
}
257254
}.elsewhen( act_bank_sel === 1.U && (act_buffer_0_read_enable)) {
258255
for (i <- 0 until meshRows*tileRows) {
259-
act_bank_data_vec(i) := bank_data(1)(i)
256+
act_bank_data_vec(i) := bank_data_0(1)(i)
260257
}
261258
}.elsewhen( act_bank_sel === 2.U && (act_buffer_1_read_enable)) {
262259
for (i <- 0 until meshRows*tileRows) {
263-
act_bank_data_vec(i) := bank_data(2)(i)
260+
act_bank_data_vec(i) := bank_data_0(2)(i)
264261
}
265262
}.elsewhen( act_bank_sel === 3.U && (act_buffer_1_read_enable)) {
266263
for (i <- 0 until meshRows*tileRows) {
267-
act_bank_data_vec(i) := bank_data(3)(i)
264+
act_bank_data_vec(i) := bank_data_0(3)(i)
268265
}
269266
}
270267
when(weight_bank_sel === 0.U && (weight_buffer_0_read_enable)) {
271268
for (i <- 0 until meshRows*tileRows) {
272-
weight_bank_data_vec(i) := bank_data(4)(i)
269+
weight_bank_data_vec(i) := bank_data_1(0)(i)
273270
}
274271
}.elsewhen( weight_bank_sel === 1.U && (weight_buffer_0_read_enable)) {
275272
for (i <- 0 until meshRows*tileRows) {
276-
weight_bank_data_vec(i) := bank_data(5)(i)
273+
weight_bank_data_vec(i) := bank_data_1(1)(i)
277274
}
278275
}.elsewhen( weight_bank_sel === 2.U && (weight_buffer_1_read_enable)) {
279276
for (i <- 0 until meshRows*tileRows) {
280-
weight_bank_data_vec(i) := bank_data(6)(i)
277+
weight_bank_data_vec(i) := bank_data_1(2)(i)
281278
}
282279
}.elsewhen( weight_bank_sel === 3.U && (weight_buffer_1_read_enable)) {
283280
for (i <- 0 until meshRows*tileRows) {
284-
weight_bank_data_vec(i) := bank_data(7)(i)
281+
weight_bank_data_vec(i) := bank_data_1(3)(i)
285282
}
286283
}
287284
}.otherwise{
288285
act_bank_sel := Cat(0.U(1.W), counter_i(log2Ceil(max_block_non_fp8)))
289286
weight_bank_sel := Cat(0.U(1.W),counter_j(log2Ceil(max_block_non_fp8)))
290-
when(act_bank_sel(0) === 0.U && (act_buffer_0_read_enable) && (weight_buffer_0_read_enable)) {
287+
when(act_bank_sel(0) === 0.U && (act_buffer_0_read_enable)) {
291288
for (i <- 0 until meshRows*tileRows) {
292-
act_bank_data_vec(i) := bank_data(0)(i)
293-
act_bank_data_vec(meshRows*tileRows+i) := bank_data(1)(i)
289+
act_bank_data_vec(i) := bank_data_0(0)(i)
290+
act_bank_data_vec(meshRows*tileRows+i) := bank_data_0(1)(i)
294291
}
295-
}.elsewhen(act_bank_sel(0) === 1.U && (act_buffer_0_read_enable) && (weight_buffer_0_read_enable)) {
292+
}.elsewhen(act_bank_sel(0) === 1.U && (act_buffer_0_read_enable) ) {
296293
for (i <- 0 until meshRows*tileRows) {
297-
act_bank_data_vec(i) := bank_data(2)(i)
298-
act_bank_data_vec(meshRows*tileRows+i) := bank_data(3)(i)
294+
act_bank_data_vec(i) := bank_data_0(2)(i)
295+
act_bank_data_vec(meshRows*tileRows+i) := bank_data_0(3)(i)
299296
}
300297
}
301-
when(counter_j(1+log2Ceil(max_block_non_fp8), log2Ceil(max_block_non_fp8)) === 1.U && (weight_buffer_0_read_enable)) {
298+
when(weight_bank_sel === 0.U && (weight_buffer_0_read_enable)) {
302299
for (i <- 0 until meshRows*tileRows) {
303-
weight_bank_data_vec(i) := bank_data(4)(i)
304-
weight_bank_data_vec(meshRows*tileRows+i) := bank_data(5)(i)
300+
weight_bank_data_vec(i) := bank_data_1(0)(i)
301+
weight_bank_data_vec(meshRows*tileRows+i) := bank_data_1(1)(i)
305302
}
306-
}.elsewhen( counter_j(1+log2Ceil(max_block_non_fp8), log2Ceil(max_block_non_fp8)) === 1.U && (weight_buffer_1_read_enable)) {
303+
}.elsewhen( weight_bank_sel === 1.U && (weight_buffer_1_read_enable)) {
307304
for (i <- 0 until meshRows*tileRows) {
308-
weight_bank_data_vec(i) := bank_data(6)(i)
309-
weight_bank_data_vec(meshRows*tileRows+i) := bank_data(7)(i)
305+
weight_bank_data_vec(i) := bank_data_1(2)(i)
306+
weight_bank_data_vec(meshRows*tileRows+i) := bank_data_1(3)(i)
310307
}
311308
}
312309
}

0 commit comments

Comments
 (0)