@@ -44,15 +44,15 @@ class QuantLut(
4444 val io = IO (new QuantLutIO (lutConfig, outputnumLanes, sp_bank_entries, sp_banks, sp_width, sp_width_projected, iterator_bitwidth))
4545 val rdataWidth = lutConfig.rdataWidth
4646 val raddrWidth = lutConfig.raddrWidth
47- val lutCache_weight_0 = Seq .fill(lutConfig.numEntries )(RegInit (VecInit (Seq .fill(16 )(0 .U (rdataWidth.W )))))
48- val lutCache_weight_1 = Seq .fill(lutConfig.numEntries )(RegInit (VecInit (Seq .fill(16 )(0 .U (rdataWidth.W )))))
49- val lutCache_act_in_0 = Seq .fill(lutConfig.numEntries )(RegInit (VecInit (Seq .fill(16 )(0 .U (rdataWidth.W )))))
50- val lutCache_act_in_1 = Seq .fill(lutConfig.numEntries )(RegInit (VecInit (Seq .fill(16 )(0 .U (rdataWidth.W )))))
51- val lutCache_act_out_0 = Seq .fill(outputnumLanes )(RegInit (VecInit (Seq .fill(16 )(0 .U (rdataWidth.W )))))
52- val lutCache_act_out_1 = Seq .fill(outputnumLanes )(RegInit (VecInit (Seq .fill(16 )(0 .U (rdataWidth.W )))))
47+ val lutCache_weight_0 = Seq .fill(lutConfig( 0 )._1 )(RegInit (VecInit (Seq .fill(16 )(0 .U (rdataWidth.W )))))
48+ val lutCache_weight_1 = Seq .fill(lutConfig( 0 )._1 )(RegInit (VecInit (Seq .fill(16 )(0 .U (rdataWidth.W )))))
49+ val lutCache_act_in_0 = Seq .fill(lutConfig( 1 )._1 )(RegInit (VecInit (Seq .fill(16 )(0 .U (rdataWidth.W )))))
50+ val lutCache_act_in_1 = Seq .fill(lutConfig( 1 )._1 )(RegInit (VecInit (Seq .fill(16 )(0 .U (rdataWidth.W )))))
51+ val lutCache_act_out_0 = Seq .fill(lutConfig( 2 )._1 )(RegInit (VecInit (Seq .fill(16 )(0 .U (rdataWidth.W )))))
52+ val lutCache_act_out_1 = Seq .fill(lutConfig( 2 )._1 )(RegInit (VecInit (Seq .fill(16 )(0 .U (rdataWidth.W )))))
5353 // io.lut_write.ready := !io.lutReadEnable
5454
55- val lutCache_act_in = WireInit (VecInit .fill(lutConfig.numEntries )(VecInit .fill(16 )(0 .U (rdataWidth.W ))))
55+ val lutCache_act_in = WireInit (VecInit .fill(lutConfig( 0 )._1 )(VecInit .fill(16 )(0 .U (rdataWidth.W ))))
5656 val lutCache_act_in_flag = RegInit (false .B )
5757 val lutCache_act_in_buffer_0_read_enable = RegInit (false .B )
5858 val lutCache_act_in_buffer_1_read_enable = RegInit (false .B )
@@ -63,7 +63,7 @@ class QuantLut(
6363
6464 when(io.lut_write_act_in.fire){
6565 when(lutCache_act_in_flag === false .B ){
66- for (lane <- 0 until lutConfig.numEntries ) {
66+ for (lane <- 0 until lutConfig( 0 )._1 ) {
6767 for (entry <- 0 until 16 ) {
6868 lutCache_act_in_0(lane)(entry) := io.lut_write_act_in.bits.data(lane)((entry+ 1 )* rdataWidth- 1 , entry* rdataWidth)
6969 // lutCache_act_in_0(lane)(entry) := io.lut_write_act_in.bits.data(0)((entry+1)*rdataWidth-1, entry*rdataWidth)
@@ -72,7 +72,7 @@ class QuantLut(
7272 lutCache_act_in_flag := ~ lutCache_act_in_flag
7373 lutCache_act_in_buffer_0_read_enable := true .B
7474 }.otherwise {
75- for (lane <- 0 until lutConfig.numEntries ) {
75+ for (lane <- 0 until lutConfig( 0 )._1 ) {
7676 for (entry <- 0 until 16 ) {
7777 lutCache_act_in_1(lane)(entry) := io.lut_write_act_in.bits.data(lane)((entry+ 1 )* rdataWidth- 1 , entry* rdataWidth)
7878 // lutCache_act_in_1(lane)(entry) := io.lut_write_act_in.bits.data(0)((entry+1)*rdataWidth-1, entry*rdataWidth)
@@ -108,7 +108,7 @@ class QuantLut(
108108
109109 io.lut_write_act_in.ready := ! lutCache_act_in_buffer_0_read_enable || ! lutCache_act_in_buffer_1_read_enable
110110
111- val lutCache_weight = WireInit (VecInit .fill(lutConfig.numEntries )(VecInit .fill(16 )(0 .U (rdataWidth.W ))))
111+ val lutCache_weight = WireInit (VecInit .fill(lutConfig( 1 )._1 )(VecInit .fill(16 )(0 .U (rdataWidth.W ))))
112112 val lutCache_weight_flag = RegInit (false .B )
113113 val lutCache_weight_buffer_0_read_enable = RegInit (false .B )
114114 val lutCache_weight_buffer_1_read_enable = RegInit (false .B )
@@ -117,7 +117,7 @@ class QuantLut(
117117
118118 when(io.lut_write_weight.fire){
119119 when(lutCache_weight_flag === false .B ){
120- for (lane <- 0 until lutConfig.numEntries ) {
120+ for (lane <- 0 until lutConfig( 1 )._1 ) {
121121 for (entry <- 0 until 16 ) {
122122 lutCache_weight_0(lane)(entry) := io.lut_write_weight.bits.data(lane)((entry+ 1 )* rdataWidth- 1 , entry* rdataWidth)
123123 // lutCache_weight_0(lane)(entry) := io.lut_write_weight.bits.data(0)((entry+1)*rdataWidth-1, entry*rdataWidth)
@@ -126,7 +126,7 @@ class QuantLut(
126126 lutCache_weight_flag := ~ lutCache_weight_flag
127127 lutCache_weight_buffer_0_read_enable := true .B
128128 }.otherwise {
129- for (lane <- 0 until lutConfig.numEntries ) {
129+ for (lane <- 0 until lutConfig( 1 )._1 ) {
130130 for (entry <- 0 until 16 ) {
131131 lutCache_weight_1(lane)(entry) := io.lut_write_weight.bits.data(lane)((entry+ 1 )* rdataWidth- 1 , entry* rdataWidth)
132132 // lutCache_weight_1(lane)(entry) := io.lut_write_weight.bits.data(0)((entry+1)*rdataWidth-1, entry*rdataWidth)
@@ -169,7 +169,7 @@ class QuantLut(
169169
170170 when(io.lut_write_act_out.fire){
171171 when(lutCache_act_out_flag === false .B ){
172- for (lane <- 0 until lutConfig.numEntries ) {
172+ for (lane <- 0 until lutConfig( 2 )._1 ) {
173173 for (entry <- 0 until 16 ) {
174174 lutCache_act_out_0(lane)(entry) := io.lut_write_act_out.bits.data(lane)((entry+ 1 )* rdataWidth- 1 , entry* rdataWidth)
175175 // lutCache_act_out_0(lane)(entry) := io.lut_write_act_out.bits.data(0)((entry+1)*rdataWidth-1, entry*rdataWidth)
@@ -178,7 +178,7 @@ class QuantLut(
178178 lutCache_act_out_flag := ~ lutCache_act_out_flag
179179 lutCache_act_out_buffer_0_read_enable := true .B
180180 }.otherwise {
181- for (lane <- 0 until lutConfig.numEntries ) {
181+ for (lane <- 0 until lutConfig( 2 )._1 ) {
182182 for (entry <- 0 until 16 ) {
183183 lutCache_act_out_1(lane)(entry) := io.lut_write_act_out.bits.data(lane)((entry+ 1 )* rdataWidth- 1 , entry* rdataWidth)
184184 // lutCache_act_out_1(lane)(entry) := io.lut_write_act_out.bits.data(0)((entry+1)*rdataWidth-1, entry*rdataWidth)
@@ -213,17 +213,18 @@ class QuantLut(
213213
214214 io.lut_write_act_out.ready := ! lutCache_act_out_buffer_0_read_enable || ! lutCache_act_out_buffer_1_read_enable
215215
216- val projectedIndices = RegInit (VecInit (Seq .fill(outputnumLanes )(0 .U (raddrWidth.W ))))
216+ val projectedIndices = RegInit (VecInit (Seq .fill(lutConfig( 2 )._1 )(0 .U (raddrWidth.W ))))
217217 val projectedDataValid = RegInit (false .B )
218-
218+ val counter_act_out = RegInit ( 0 . U (log2Ceil(lutConfig( 2 )._1). W ))
219219 // TODO: double check if this is the correct way to do nearest neighbor search, aligning with the algorithm implementation
220220 when((lutCache_act_out_buffer_0_read_enable || lutCache_act_out_buffer_1_read_enable) && io.quant_fp6.valid) {
221- for (i <- 0 until outputnumLanes) {
221+
222+ for (i <- 0 until lutConfig(2 )._1) {
222223 val inputFp6 = io.quant_fp6.bits(i)
223224 val distances = VecInit ((0 until 16 ).map { j =>
224- val diff = Mux (inputFp6 > lutCache_act_out(i )(j),
225- inputFp6 - lutCache_act_out(i )(j),
226- lutCache_act_out(i )(j) - inputFp6)
225+ val diff = Mux (inputFp6 > lutCache_act_out(counter_act_out )(j),
226+ inputFp6 - lutCache_act_out(counter_act_out )(j),
227+ lutCache_act_out(counter_act_out )(j) - inputFp6)
227228 diff
228229 })
229230
@@ -237,8 +238,14 @@ class QuantLut(
237238 projectedIndices(i) := minIdx
238239 }
239240 projectedDataValid := true .B
241+ when (counter_act_out === (outputnumLanes - 1 ).U ){
242+ counter_act_out := 0 .U
243+ }.otherwise{
244+ counter_act_out := counter_act_out + 1 .U
245+ }
240246 }.otherwise {
241247 projectedDataValid := false .B
248+
242249 }
243250
244251 io.projected_data.valid := projectedDataValid
@@ -252,7 +259,7 @@ class QuantLut(
252259 io.spad_projected_data(i).resp.ready := io.spad_deprojected_data(i).resp.ready
253260
254261 when(io.spad_projected_data(i).resp.valid) {
255- val deprojected_bits = Wire (Vec (32 , UInt (6 .W )))
262+ val deprojected_bits = Wire (Vec (outputnumLanes , UInt (6 .W )))
256263 for (k <- 0 until 32 ) {
257264 deprojected_bits(k) := 0 .U
258265 }
@@ -261,7 +268,7 @@ class QuantLut(
261268 val chunk_4bit = io.spad_projected_data(i).resp.bits.data((k+ 1 )* 4 - 1 , k* 4 )
262269 deprojected_bits(k) := lutCache_act_in(counter_act)(chunk_4bit)
263270 }
264- when(counter_act === (lutConfig.numEntries - 1 ).U ){
271+ when(counter_act === (lutConfig( 0 )._1 - 1 ).U ){
265272 counter_act := 0 .U
266273 }.otherwise{
267274 counter_act := counter_act + 1 .U
@@ -271,7 +278,7 @@ class QuantLut(
271278 val chunk_4bit = io.spad_projected_data(i).resp.bits.data((k+ 1 )* 4 - 1 , k* 4 )
272279 deprojected_bits(k) := lutCache_weight(counter_w)(chunk_4bit)
273280 }
274- when(counter_w === (lutConfig.numEntries - 1 ).U ){
281+ when(counter_w === (lutConfig( 1 )._1 - 1 ).U ){
275282 counter_w := 0 .U
276283 }.otherwise{
277284 counter_w := counter_w + 1 .U
0 commit comments