@@ -44,15 +44,15 @@ class QuantLut(
4444 val io = IO (new QuantLutIO (lutConfig, outputnumLanes, sp_bank_entries, sp_banks, sp_width, sp_width_projected, iterator_bitwidth))
4545 val rdataWidth = lutConfig.rdataWidth
4646 val raddrWidth = lutConfig.raddrWidth
47- val lutCache_weight_0 = Seq .fill(32 )(RegInit (VecInit (Seq .fill(16 )(0 .U (rdataWidth.W )))))
48- val lutCache_weight_1 = Seq .fill(32 )(RegInit (VecInit (Seq .fill(16 )(0 .U (rdataWidth.W )))))
49- val lutCache_act_in_0 = Seq .fill(32 )(RegInit (VecInit (Seq .fill(16 )(0 .U (rdataWidth.W )))))
50- val lutCache_act_in_1 = Seq .fill(32 )(RegInit (VecInit (Seq .fill(16 )(0 .U (rdataWidth.W )))))
51- val lutCache_act_out_0 = Seq .fill(32 )(RegInit (VecInit (Seq .fill(16 )(0 .U (rdataWidth.W )))))
52- val lutCache_act_out_1 = Seq .fill(32 )(RegInit (VecInit (Seq .fill(16 )(0 .U (rdataWidth.W )))))
47+ val lutCache_weight_0 = Seq .fill(lutConfig.numEntries )(RegInit (VecInit (Seq .fill(16 )(0 .U (rdataWidth.W )))))
48+ val lutCache_weight_1 = Seq .fill(lutConfig.numEntries )(RegInit (VecInit (Seq .fill(16 )(0 .U (rdataWidth.W )))))
49+ val lutCache_act_in_0 = Seq .fill(lutConfig.numEntries )(RegInit (VecInit (Seq .fill(16 )(0 .U (rdataWidth.W )))))
50+ val lutCache_act_in_1 = Seq .fill(lutConfig.numEntries )(RegInit (VecInit (Seq .fill(16 )(0 .U (rdataWidth.W )))))
51+ val lutCache_act_out_0 = Seq .fill(outputnumLanes )(RegInit (VecInit (Seq .fill(16 )(0 .U (rdataWidth.W )))))
52+ val lutCache_act_out_1 = Seq .fill(outputnumLanes )(RegInit (VecInit (Seq .fill(16 )(0 .U (rdataWidth.W )))))
5353 // io.lut_write.ready := !io.lutReadEnable
5454
55- val lutCache_act_in = WireInit (VecInit .fill(32 )(VecInit .fill(16 )(0 .U (rdataWidth.W ))))
55+ val lutCache_act_in = WireInit (VecInit .fill(lutConfig.numEntries )(VecInit .fill(16 )(0 .U (rdataWidth.W ))))
5656 val lutCache_act_in_flag = RegInit (false .B )
5757 val lutCache_act_in_buffer_0_read_enable = RegInit (false .B )
5858 val lutCache_act_in_buffer_1_read_enable = RegInit (false .B )
@@ -63,7 +63,7 @@ class QuantLut(
6363
6464 when(io.lut_write_act_in.fire){
6565 when(lutCache_act_in_flag === false .B ){
66- for (lane <- 0 until 32 ) {
66+ for (lane <- 0 until lutConfig.numEntries ) {
6767 for (entry <- 0 until 16 ) {
6868 lutCache_act_in_0(lane)(entry) := io.lut_write_act_in.bits.data(lane)((entry+ 1 )* rdataWidth- 1 , entry* rdataWidth)
6969 // lutCache_act_in_0(lane)(entry) := io.lut_write_act_in.bits.data(0)((entry+1)*rdataWidth-1, entry*rdataWidth)
@@ -72,7 +72,7 @@ class QuantLut(
7272 lutCache_act_in_flag := ~ lutCache_act_in_flag
7373 lutCache_act_in_buffer_0_read_enable := true .B
7474 }.otherwise {
75- for (lane <- 0 until 32 ) {
75+ for (lane <- 0 until lutConfig.numEntries ) {
7676 for (entry <- 0 until 16 ) {
7777 lutCache_act_in_1(lane)(entry) := io.lut_write_act_in.bits.data(lane)((entry+ 1 )* rdataWidth- 1 , entry* rdataWidth)
7878 // lutCache_act_in_1(lane)(entry) := io.lut_write_act_in.bits.data(0)((entry+1)*rdataWidth-1, entry*rdataWidth)
@@ -86,9 +86,9 @@ class QuantLut(
8686
8787 val lutCache_update_enable_act_in = WireInit (0 .U (1 .W ))
8888 // FP6 only: tile = 32 elements, period = regularity/32 tiles
89- lutCache_update_enable_act_in := (counter_i(log2Ceil(lut_update_regularity_act_in/ 32 )- 1 , 0 ) === 0 .U ) && (counter_i_reg(log2Ceil(lut_update_regularity_act_in/ 32 )- 1 , 0 ) === (lut_update_regularity_act_in/ 32 - 1 ).U )
89+ lutCache_update_enable_act_in := (io. counter_i(log2Ceil(lut_update_regularity_act_in/ 32 )- 1 , 0 ) === 0 .U ) && (counter_i_reg(log2Ceil(lut_update_regularity_act_in/ 32 )- 1 , 0 ) === (lut_update_regularity_act_in/ 32 - 1 ).U )
9090
91- when(lutCache_update_enable_act_in){ // 32 is the maxblock under fp6
91+ when(lutCache_update_enable_act_in === 1 . U ){ // 32 is the maxblock under fp6
9292 when(lutCache_act_in_buffer_0_read_enable && (lutCache_act_in_buffer_select === false .B )){
9393 lutCache_act_in_buffer_0_read_enable := false .B
9494 }
@@ -108,7 +108,7 @@ class QuantLut(
108108
109109 io.lut_write_act_in.ready := ! lutCache_act_in_buffer_0_read_enable || ! lutCache_act_in_buffer_1_read_enable
110110
111- val lutCache_weight = WireInit (VecInit .fill(32 )(VecInit .fill(16 )(0 .U (rdataWidth.W ))))
111+ val lutCache_weight = WireInit (VecInit .fill(lutConfig.numEntries )(VecInit .fill(16 )(0 .U (rdataWidth.W ))))
112112 val lutCache_weight_flag = RegInit (false .B )
113113 val lutCache_weight_buffer_0_read_enable = RegInit (false .B )
114114 val lutCache_weight_buffer_1_read_enable = RegInit (false .B )
@@ -117,7 +117,7 @@ class QuantLut(
117117
118118 when(io.lut_write_weight.fire){
119119 when(lutCache_weight_flag === false .B ){
120- for (lane <- 0 until 32 ) {
120+ for (lane <- 0 until lutConfig.numEntries ) {
121121 for (entry <- 0 until 16 ) {
122122 lutCache_weight_0(lane)(entry) := io.lut_write_weight.bits.data(lane)((entry+ 1 )* rdataWidth- 1 , entry* rdataWidth)
123123 // lutCache_weight_0(lane)(entry) := io.lut_write_weight.bits.data(0)((entry+1)*rdataWidth-1, entry*rdataWidth)
@@ -126,7 +126,7 @@ class QuantLut(
126126 lutCache_weight_flag := ~ lutCache_weight_flag
127127 lutCache_weight_buffer_0_read_enable := true .B
128128 }.otherwise {
129- for (lane <- 0 until 32 ) {
129+ for (lane <- 0 until lutConfig.numEntries ) {
130130 for (entry <- 0 until 16 ) {
131131 lutCache_weight_1(lane)(entry) := io.lut_write_weight.bits.data(lane)((entry+ 1 )* rdataWidth- 1 , entry* rdataWidth)
132132 // lutCache_weight_1(lane)(entry) := io.lut_write_weight.bits.data(0)((entry+1)*rdataWidth-1, entry*rdataWidth)
@@ -139,9 +139,9 @@ class QuantLut(
139139
140140 val lutCache_update_enable_w_in = WireInit (0 .U (1 .W ))
141141 // FP6 only: tile = 32 elements, period = regularity_w/32 tiles
142- lutCache_update_enable_w_in := (counter_j(log2Ceil(lut_update_regularity_w/ 32 )- 1 , 0 ) === 0 .U ) && (counter_j_reg(log2Ceil(lut_update_regularity_w/ 32 )- 1 , 0 ) === (lut_update_regularity_w/ 32 - 1 ).U )
142+ lutCache_update_enable_w_in := (io. counter_j(log2Ceil(lut_update_regularity_w/ 32 )- 1 , 0 ) === 0 .U ) && (counter_j_reg(log2Ceil(lut_update_regularity_w/ 32 )- 1 , 0 ) === (lut_update_regularity_w/ 32 - 1 ).U )
143143
144- when(lutCache_update_enable_w_in){
144+ when(lutCache_update_enable_w_in === 1 . U ){
145145 when(lutCache_weight_buffer_0_read_enable && (lutCache_weight_buffer_select === false .B )){
146146 lutCache_weight_buffer_0_read_enable := false .B
147147 }
@@ -161,15 +161,15 @@ class QuantLut(
161161
162162 io.lut_write_weight.ready := ! lutCache_weight_buffer_0_read_enable || ! lutCache_weight_buffer_1_read_enable
163163
164- val lutCache_act_out = WireInit (VecInit .fill(32 )(VecInit .fill(16 )(0 .U (rdataWidth.W ))))
164+ val lutCache_act_out = WireInit (VecInit .fill(outputnumLanes )(VecInit .fill(16 )(0 .U (rdataWidth.W ))))
165165 val lutCache_act_out_flag = RegInit (false .B )
166166 val lutCache_act_out_buffer_0_read_enable = RegInit (false .B )
167167 val lutCache_act_out_buffer_1_read_enable = RegInit (false .B )
168168 val lutCache_act_out_buffer_select = RegInit (false .B )
169169
170170 when(io.lut_write_act_out.fire){
171171 when(lutCache_act_out_flag === false .B ){
172- for (lane <- 0 until 32 ) {
172+ for (lane <- 0 until outputnumLanes ) { // bound must match lutCache_act_out_0, which is declared with outputnumLanes lanes (not lutConfig.numEntries)
173173 for (entry <- 0 until 16 ) {
174174 lutCache_act_out_0(lane)(entry) := io.lut_write_act_out.bits.data(lane)((entry+ 1 )* rdataWidth- 1 , entry* rdataWidth)
175175 // lutCache_act_out_0(lane)(entry) := io.lut_write_act_out.bits.data(0)((entry+1)*rdataWidth-1, entry*rdataWidth)
@@ -178,7 +178,7 @@ class QuantLut(
178178 lutCache_act_out_flag := ~ lutCache_act_out_flag
179179 lutCache_act_out_buffer_0_read_enable := true .B
180180 }.otherwise {
181- for (lane <- 0 until 32 ) {
181+ for (lane <- 0 until outputnumLanes ) { // bound must match lutCache_act_out_1, which is declared with outputnumLanes lanes (not lutConfig.numEntries)
182182 for (entry <- 0 until 16 ) {
183183 lutCache_act_out_1(lane)(entry) := io.lut_write_act_out.bits.data(lane)((entry+ 1 )* rdataWidth- 1 , entry* rdataWidth)
184184 // lutCache_act_out_1(lane)(entry) := io.lut_write_act_out.bits.data(0)((entry+1)*rdataWidth-1, entry*rdataWidth)
@@ -191,9 +191,9 @@ class QuantLut(
191191
192192 val lutCache_update_enable_act_out = WireInit (0 .U (1 .W ))
193193 // FP6 only: tile = 32 elements, period = regularity_act_out/32 tiles
194- lutCache_update_enable_act_out := (counter_i(log2Ceil(lut_update_regularity_act_out/ 32 )- 1 , 0 ) === 0 .U ) && (counter_i_reg(log2Ceil(lut_update_regularity_act_out/ 32 )- 1 , 0 ) === (lut_update_regularity_act_out/ 32 - 1 ).U )
194+ lutCache_update_enable_act_out := (io. counter_i(log2Ceil(lut_update_regularity_act_out/ 32 )- 1 , 0 ) === 0 .U ) && (counter_i_reg(log2Ceil(lut_update_regularity_act_out/ 32 )- 1 , 0 ) === (lut_update_regularity_act_out/ 32 - 1 ).U )
195195
196- when(lutCache_update_enable_act_out){
196+ when(lutCache_update_enable_act_out === 1 . U ){
197197 when(lutCache_act_out_buffer_0_read_enable && (lutCache_act_out_buffer_select === false .B )){
198198 lutCache_act_out_buffer_0_read_enable := false .B
199199 }
@@ -261,7 +261,7 @@ class QuantLut(
261261 val chunk_4bit = io.spad_projected_data(i).resp.bits.data((k+ 1 )* 4 - 1 , k* 4 )
262262 deprojected_bits(k) := lutCache_act_in(counter_act)(chunk_4bit)
263263 }
264- when(counter_act === 31 .U ){
264+ when(counter_act === (lutConfig.numEntries - 1 ) .U ){
265265 counter_act := 0 .U
266266 }.otherwise{
267267 counter_act := counter_act + 1 .U
@@ -271,7 +271,7 @@ class QuantLut(
271271 val chunk_4bit = io.spad_projected_data(i).resp.bits.data((k+ 1 )* 4 - 1 , k* 4 )
272272 deprojected_bits(k) := lutCache_weight(counter_w)(chunk_4bit)
273273 }
274- when(counter_w === 31 .U ){
274+ when(counter_w === (lutConfig.numEntries - 1 ) .U ){
275275 counter_w := 0 .U
276276 }.otherwise{
277277 counter_w := counter_w + 1 .U
0 commit comments