Skip to content

Commit 565cc84

Browse files
Amanda ShiAmanda Shi
authored andcommitted
fix lut write logic
1 parent ddced60 commit 565cc84

File tree

2 files changed

+57
-23
lines changed

2 files changed

+57
-23
lines changed

src/main/scala/gemmini/QuantLut.scala

Lines changed: 30 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -44,15 +44,15 @@ class QuantLut(
4444
val io = IO(new QuantLutIO(lutConfig, outputnumLanes, sp_bank_entries, sp_banks, sp_width, sp_width_projected, iterator_bitwidth))
4545
val rdataWidth = lutConfig.rdataWidth
4646
val raddrWidth = lutConfig.raddrWidth
47-
val lutCache_weight_0 = Seq.fill(lutConfig.numEntries)(RegInit(VecInit(Seq.fill(16)(0.U(rdataWidth.W)))))
48-
val lutCache_weight_1 = Seq.fill(lutConfig.numEntries)(RegInit(VecInit(Seq.fill(16)(0.U(rdataWidth.W)))))
49-
val lutCache_act_in_0 = Seq.fill(lutConfig.numEntries)(RegInit(VecInit(Seq.fill(16)(0.U(rdataWidth.W)))))
50-
val lutCache_act_in_1 = Seq.fill(lutConfig.numEntries)(RegInit(VecInit(Seq.fill(16)(0.U(rdataWidth.W)))))
51-
val lutCache_act_out_0 = Seq.fill(outputnumLanes)(RegInit(VecInit(Seq.fill(16)(0.U(rdataWidth.W)))))
52-
val lutCache_act_out_1 = Seq.fill(outputnumLanes)(RegInit(VecInit(Seq.fill(16)(0.U(rdataWidth.W)))))
47+
val lutCache_weight_0 = Seq.fill(lutConfig(0)._1)(RegInit(VecInit(Seq.fill(16)(0.U(rdataWidth.W)))))
48+
val lutCache_weight_1 = Seq.fill(lutConfig(0)._1)(RegInit(VecInit(Seq.fill(16)(0.U(rdataWidth.W)))))
49+
val lutCache_act_in_0 = Seq.fill(lutConfig(1)._1)(RegInit(VecInit(Seq.fill(16)(0.U(rdataWidth.W)))))
50+
val lutCache_act_in_1 = Seq.fill(lutConfig(1)._1)(RegInit(VecInit(Seq.fill(16)(0.U(rdataWidth.W)))))
51+
val lutCache_act_out_0 = Seq.fill(lutConfig(2)._1)(RegInit(VecInit(Seq.fill(16)(0.U(rdataWidth.W)))))
52+
val lutCache_act_out_1 = Seq.fill(lutConfig(2)._1)(RegInit(VecInit(Seq.fill(16)(0.U(rdataWidth.W)))))
5353
//io.lut_write.ready := !io.lutReadEnable
5454

55-
val lutCache_act_in = WireInit(VecInit.fill(lutConfig.numEntries)(VecInit.fill(16)(0.U(rdataWidth.W))))
55+
val lutCache_act_in = WireInit(VecInit.fill(lutConfig(0)._1)(VecInit.fill(16)(0.U(rdataWidth.W))))
5656
val lutCache_act_in_flag = RegInit(false.B)
5757
val lutCache_act_in_buffer_0_read_enable = RegInit(false.B)
5858
val lutCache_act_in_buffer_1_read_enable = RegInit(false.B)
@@ -63,7 +63,7 @@ class QuantLut(
6363

6464
when(io.lut_write_act_in.fire){
6565
when(lutCache_act_in_flag === false.B){
66-
for (lane <- 0 until lutConfig.numEntries) {
66+
for (lane <- 0 until lutConfig(0)._1) {
6767
for (entry <- 0 until 16) {
6868
lutCache_act_in_0(lane)(entry) := io.lut_write_act_in.bits.data(lane)((entry+1)*rdataWidth-1, entry*rdataWidth)
6969
//lutCache_act_in_0(lane)(entry) := io.lut_write_act_in.bits.data(0)((entry+1)*rdataWidth-1, entry*rdataWidth)
@@ -72,7 +72,7 @@ class QuantLut(
7272
lutCache_act_in_flag := ~lutCache_act_in_flag
7373
lutCache_act_in_buffer_0_read_enable := true.B
7474
}.otherwise {
75-
for (lane <- 0 until lutConfig.numEntries) {
75+
for (lane <- 0 until lutConfig(0)._1) {
7676
for (entry <- 0 until 16) {
7777
lutCache_act_in_1(lane)(entry) := io.lut_write_act_in.bits.data(lane)((entry+1)*rdataWidth-1, entry*rdataWidth)
7878
//lutCache_act_in_1(lane)(entry) := io.lut_write_act_in.bits.data(0)((entry+1)*rdataWidth-1, entry*rdataWidth)
@@ -108,7 +108,7 @@ class QuantLut(
108108

109109
io.lut_write_act_in.ready := !lutCache_act_in_buffer_0_read_enable || !lutCache_act_in_buffer_1_read_enable
110110

111-
val lutCache_weight = WireInit(VecInit.fill(lutConfig.numEntries)(VecInit.fill(16)(0.U(rdataWidth.W))))
111+
val lutCache_weight = WireInit(VecInit.fill(lutConfig(1)._1)(VecInit.fill(16)(0.U(rdataWidth.W))))
112112
val lutCache_weight_flag = RegInit(false.B)
113113
val lutCache_weight_buffer_0_read_enable = RegInit(false.B)
114114
val lutCache_weight_buffer_1_read_enable = RegInit(false.B)
@@ -117,7 +117,7 @@ class QuantLut(
117117

118118
when(io.lut_write_weight.fire){
119119
when(lutCache_weight_flag === false.B){
120-
for (lane <- 0 until lutConfig.numEntries) {
120+
for (lane <- 0 until lutConfig(1)._1) {
121121
for (entry <- 0 until 16) {
122122
lutCache_weight_0(lane)(entry) := io.lut_write_weight.bits.data(lane)((entry+1)*rdataWidth-1, entry*rdataWidth)
123123
//lutCache_weight_0(lane)(entry) := io.lut_write_weight.bits.data(0)((entry+1)*rdataWidth-1, entry*rdataWidth)
@@ -126,7 +126,7 @@ class QuantLut(
126126
lutCache_weight_flag := ~lutCache_weight_flag
127127
lutCache_weight_buffer_0_read_enable := true.B
128128
}.otherwise {
129-
for (lane <- 0 until lutConfig.numEntries) {
129+
for (lane <- 0 until lutConfig(1)._1) {
130130
for (entry <- 0 until 16) {
131131
lutCache_weight_1(lane)(entry) := io.lut_write_weight.bits.data(lane)((entry+1)*rdataWidth-1, entry*rdataWidth)
132132
//lutCache_weight_1(lane)(entry) := io.lut_write_weight.bits.data(0)((entry+1)*rdataWidth-1, entry*rdataWidth)
@@ -169,7 +169,7 @@ class QuantLut(
169169

170170
when(io.lut_write_act_out.fire){
171171
when(lutCache_act_out_flag === false.B){
172-
for (lane <- 0 until lutConfig.numEntries) {
172+
for (lane <- 0 until lutConfig(2)._1) {
173173
for (entry <- 0 until 16) {
174174
lutCache_act_out_0(lane)(entry) := io.lut_write_act_out.bits.data(lane)((entry+1)*rdataWidth-1, entry*rdataWidth)
175175
//lutCache_act_out_0(lane)(entry) := io.lut_write_act_out.bits.data(0)((entry+1)*rdataWidth-1, entry*rdataWidth)
@@ -178,7 +178,7 @@ class QuantLut(
178178
lutCache_act_out_flag := ~lutCache_act_out_flag
179179
lutCache_act_out_buffer_0_read_enable := true.B
180180
}.otherwise {
181-
for (lane <- 0 until lutConfig.numEntries) {
181+
for (lane <- 0 until lutConfig(2)._1) {
182182
for (entry <- 0 until 16) {
183183
lutCache_act_out_1(lane)(entry) := io.lut_write_act_out.bits.data(lane)((entry+1)*rdataWidth-1, entry*rdataWidth)
184184
//lutCache_act_out_1(lane)(entry) := io.lut_write_act_out.bits.data(0)((entry+1)*rdataWidth-1, entry*rdataWidth)
@@ -213,17 +213,18 @@ class QuantLut(
213213

214214
io.lut_write_act_out.ready := !lutCache_act_out_buffer_0_read_enable || !lutCache_act_out_buffer_1_read_enable
215215

216-
val projectedIndices = RegInit(VecInit(Seq.fill(outputnumLanes)(0.U(raddrWidth.W))))
216+
val projectedIndices = RegInit(VecInit(Seq.fill(lutConfig(2)._1)(0.U(raddrWidth.W))))
217217
val projectedDataValid = RegInit(false.B)
218-
218+
val counter_act_out = RegInit(0.U(log2Ceil(lutConfig(2)._1).W))
219219
//TODO: double check if this is the correct way to do nearest neighbor search, aligning with the algorithm implementation
220220
when((lutCache_act_out_buffer_0_read_enable || lutCache_act_out_buffer_1_read_enable) && io.quant_fp6.valid) {
221-
for (i <- 0 until outputnumLanes) {
221+
222+
for (i <- 0 until lutConfig(2)._1) {
222223
val inputFp6 = io.quant_fp6.bits(i)
223224
val distances = VecInit((0 until 16).map { j =>
224-
val diff = Mux(inputFp6 > lutCache_act_out(i)(j),
225-
inputFp6 - lutCache_act_out(i)(j),
226-
lutCache_act_out(i)(j) - inputFp6)
225+
val diff = Mux(inputFp6 > lutCache_act_out(counter_act_out)(j),
226+
inputFp6 - lutCache_act_out(counter_act_out)(j),
227+
lutCache_act_out(counter_act_out)(j) - inputFp6)
227228
diff
228229
})
229230

@@ -237,8 +238,14 @@ class QuantLut(
237238
projectedIndices(i) := minIdx
238239
}
239240
projectedDataValid := true.B
241+
when (counter_act_out === (outputnumLanes - 1).U){
242+
counter_act_out := 0.U
243+
}.otherwise{
244+
counter_act_out := counter_act_out + 1.U
245+
}
240246
}.otherwise {
241247
projectedDataValid := false.B
248+
242249
}
243250

244251
io.projected_data.valid := projectedDataValid
@@ -252,7 +259,7 @@ class QuantLut(
252259
io.spad_projected_data(i).resp.ready := io.spad_deprojected_data(i).resp.ready
253260

254261
when(io.spad_projected_data(i).resp.valid) {
255-
val deprojected_bits = Wire(Vec(32, UInt(6.W)))
262+
val deprojected_bits = Wire(Vec(outputnumLanes, UInt(6.W)))
256263
for (k <- 0 until 32) {
257264
deprojected_bits(k) := 0.U
258265
}
@@ -261,7 +268,7 @@ class QuantLut(
261268
val chunk_4bit = io.spad_projected_data(i).resp.bits.data((k+1)*4-1, k*4)
262269
deprojected_bits(k) := lutCache_act_in(counter_act)(chunk_4bit)
263270
}
264-
when(counter_act === (lutConfig.numEntries -1).U){
271+
when(counter_act === (lutConfig(0)._1 -1).U){
265272
counter_act := 0.U
266273
}.otherwise{
267274
counter_act := counter_act + 1.U
@@ -271,7 +278,7 @@ class QuantLut(
271278
val chunk_4bit = io.spad_projected_data(i).resp.bits.data((k+1)*4-1, k*4)
272279
deprojected_bits(k) := lutCache_weight(counter_w)(chunk_4bit)
273280
}
274-
when(counter_w === (lutConfig.numEntries -1).U){
281+
when(counter_w === (lutConfig(1)._1 -1).U){
275282
counter_w := 0.U
276283
}.otherwise{
277284
counter_w := counter_w + 1.U

sync_missing_files.sh

Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,27 @@
1+
#!/bin/bash
2+
# Copies files that exist in the external gemmini-rocc-tests but not in the
3+
# in-tree software copy. Source this or run it directly.
4+
5+
SRC_BASE="/bwrcq/B/mshi/chipyard/cy_gpu/gemmini-rocc-tests"
6+
DST_BASE="/bwrcq/B/mshi/chipyard/cy_gpu/chipyard/generators/gemmini/software/gemmini-rocc-tests"
7+
8+
for subdir in bareMetalC include; do
9+
src="$SRC_BASE/$subdir"
10+
dst="$DST_BASE/$subdir"
11+
12+
if [ ! -d "$src" ] || [ ! -d "$dst" ]; then
13+
echo "Skipping $subdir: directory not found"
14+
continue
15+
fi
16+
17+
missing=$(comm -23 <(ls "$src" | sort) <(ls "$dst" | sort))
18+
19+
if [ -z "$missing" ]; then
20+
echo "[$subdir] No missing files."
21+
else
22+
for f in $missing; do
23+
cp "$src/$f" "$dst/$f"
24+
echo "[$subdir] Copied: $f"
25+
done
26+
fi
27+
done

0 commit comments

Comments
 (0)