@@ -833,9 +833,9 @@ class ExecuteController[T <: Data, U <: Data, V <: Data](xLen: Int, tagWidth: In
833
833
val dataB_unpadded = MuxCase (readData(cntl.b_bank), Seq (cntl.accumulate_zeros -> 0 .U , cntl.b_read_from_acc -> accReadData(cntl.b_bank_acc)))
834
834
val dataD_unpadded = MuxCase (readData(cntl.d_bank), Seq (cntl.preload_zeros -> 0 .U , cntl.d_read_from_acc -> accReadData(cntl.d_bank_acc)))
835
835
836
- val dataA = VecInit (dataA_unpadded.asTypeOf(Vec (block_size, inputType)).zipWithIndex.map { case (d, i) => Mux (i.U < cntl.a_unpadded_cols, d, inputType.zero)})
837
- val dataB = VecInit (dataB_unpadded.asTypeOf(Vec (block_size, inputType)).zipWithIndex.map { case (d, i) => Mux (i.U < cntl.b_unpadded_cols, d, inputType.zero)})
838
- val dataD = VecInit (dataD_unpadded.asTypeOf(Vec (block_size, inputType)).zipWithIndex.map { case (d, i) => Mux (i.U < cntl.d_unpadded_cols, d, inputType.zero)})
836
+ val dataA = VecInit (dataA_unpadded.asTypeOf(Vec (block_size, inputType)).zipWithIndex.map { case (d, i) => Mux (i.U < cntl.a_unpadded_cols, d, inputType.zero)}.map(d => d.asTypeOf(inputType).withWidthOf(spatialArrayInputType)) )
837
+ val dataB = VecInit (dataB_unpadded.asTypeOf(Vec (block_size, inputType)).zipWithIndex.map { case (d, i) => Mux (i.U < cntl.b_unpadded_cols, d, inputType.zero)}.map(d => d.asTypeOf(inputType).withWidthOf(spatialArrayWeightType)) )
838
+ val dataD = VecInit (dataD_unpadded.asTypeOf(Vec (block_size, inputType)).zipWithIndex.map { case (d, i) => Mux (i.U < cntl.d_unpadded_cols, d, inputType.zero)}.map(d => d.asTypeOf(inputType).withWidthOf(spatialArrayWeightType)) )
839
839
840
840
// Pop responses off the scratchpad io ports
841
841
when (mesh_cntl_signals_q.io.deq.fire) {
@@ -876,9 +876,9 @@ class ExecuteController[T <: Data, U <: Data, V <: Data](xLen: Int, tagWidth: In
876
876
mesh.io.b.valid := cntl.b_fire && dataB_valid
877
877
mesh.io.d.valid := cntl.d_fire && dataD_valid
878
878
879
- mesh.io.a.bits := dataA.asTypeOf(Vec (meshRows, Vec (tileRows, inputType )))
880
- mesh.io.b.bits := dataB.asTypeOf(Vec (meshColumns, Vec (tileColumns, inputType )))
881
- mesh.io.d.bits := dataD.asTypeOf(Vec (meshColumns, Vec (tileColumns, inputType )))
879
+ mesh.io.a.bits := dataA.asTypeOf(Vec (meshRows, Vec (tileRows, spatialArrayInputType )))
880
+ mesh.io.b.bits := dataB.asTypeOf(Vec (meshColumns, Vec (tileColumns, spatialArrayWeightType )))
881
+ mesh.io.d.bits := dataD.asTypeOf(Vec (meshColumns, Vec (tileColumns, spatialArrayWeightType )))
882
882
883
883
mesh.io.req.valid := mesh_cntl_signals_q.io.deq.fire && (cntl.a_fire || cntl.b_fire || cntl.d_fire)
884
884
@@ -888,13 +888,13 @@ class ExecuteController[T <: Data, U <: Data, V <: Data](xLen: Int, tagWidth: In
888
888
}
889
889
890
890
when (cntl_valid && cntl.perform_single_preload) {
891
- mesh.io.a.bits := Mux (a_should_be_fed_into_transposer, dataA.asUInt, 0 .U ).asTypeOf(Vec (meshRows, Vec (tileRows, inputType )))
892
- mesh.io.b.bits := Mux (b_should_be_fed_into_transposer, dataB.asUInt, 0 .U ).asTypeOf(Vec (meshColumns, Vec (tileColumns, inputType )))
891
+ mesh.io.a.bits := Mux (a_should_be_fed_into_transposer, dataA.asUInt, 0 .U ).asTypeOf(Vec (meshRows, Vec (tileRows, spatialArrayInputType )))
892
+ mesh.io.b.bits := Mux (b_should_be_fed_into_transposer, dataB.asUInt, 0 .U ).asTypeOf(Vec (meshColumns, Vec (tileColumns, spatialArrayWeightType )))
893
893
}
894
894
895
895
when (cntl_valid && cntl.perform_single_mul) {
896
- mesh.io.a.bits := Mux (a_should_be_fed_into_transposer, 0 .U , dataA.asUInt).asTypeOf(Vec (meshRows, Vec (tileRows, inputType )))
897
- mesh.io.b.bits := Mux (b_should_be_fed_into_transposer, 0 .U , dataB.asUInt).asTypeOf(Vec (meshColumns, Vec (tileColumns, inputType )))
896
+ mesh.io.a.bits := Mux (a_should_be_fed_into_transposer, 0 .U , dataA.asUInt).asTypeOf(Vec (meshRows, Vec (tileRows, spatialArrayInputType )))
897
+ mesh.io.b.bits := Mux (b_should_be_fed_into_transposer, 0 .U , dataB.asUInt).asTypeOf(Vec (meshColumns, Vec (tileColumns, spatialArrayWeightType )))
898
898
mesh.io.req.bits.tag.addr.make_this_garbage()
899
899
}
900
900
0 commit comments