diff --git a/src/main/scala/gemmini/Controller.scala b/src/main/scala/gemmini/Controller.scala index c67f92df..66d4db9b 100644 --- a/src/main/scala/gemmini/Controller.scala +++ b/src/main/scala/gemmini/Controller.scala @@ -142,7 +142,7 @@ class GemminiModule[T <: Data: Arithmetic, U <: Data, V <: Data] val max_exs = reservation_station_entries_ex val max_sts = reservation_station_entries_st - val (conv_cmd, loop_conv_unroller_busy) = withClock (gated_clock) { LoopConv(raw_cmd, reservation_station.io.conv_ld_completed, reservation_station.io.conv_st_completed, reservation_station.io.conv_ex_completed, + val (conv_cmd, loop_conv_unroller_busy) = if (has_loop_conv) withClock (gated_clock) { LoopConv(raw_cmd, reservation_station.io.conv_ld_completed, reservation_station.io.conv_st_completed, reservation_station.io.conv_ex_completed, meshRows*tileRows, coreMaxAddrBits, reservation_station_entries, max_lds, max_exs, max_sts, sp_banks * sp_bank_entries, acc_banks * acc_bank_entries, inputType.getWidth, accType.getWidth, dma_maxbytes, new ConfigMvinRs1(mvin_scale_t_bits, block_stride_bits, pixel_repeats_bits), new MvinRs2(mvin_rows_bits, mvin_cols_bits, local_addr_t), @@ -151,8 +151,9 @@ class GemminiModule[T <: Data: Arithmetic, U <: Data, V <: Data] new PreloadRs(mvout_rows_bits, mvout_cols_bits, local_addr_t), new ComputeRs(mvin_rows_bits, mvin_cols_bits, local_addr_t), new ComputeRs(mvin_rows_bits, mvin_cols_bits, local_addr_t), has_training_convs, has_max_pool, has_first_layer_optimizations, has_dw_convs) } + else (raw_cmd, false.B) - val (loop_cmd, loop_matmul_unroller_busy) = withClock (gated_clock) { LoopMatmul(conv_cmd, reservation_station.io.matmul_ld_completed, reservation_station.io.matmul_st_completed, reservation_station.io.matmul_ex_completed, + val (loop_cmd, loop_matmul_unroller_busy) = withClock (gated_clock) { LoopMatmul(if (has_loop_conv) conv_cmd else raw_cmd, reservation_station.io.matmul_ld_completed, reservation_station.io.matmul_st_completed, reservation_station.io.matmul_ex_completed, meshRows*tileRows, coreMaxAddrBits, reservation_station_entries, max_lds, max_exs, max_sts, sp_banks * sp_bank_entries, acc_banks * acc_bank_entries, inputType.getWidth, accType.getWidth, dma_maxbytes, new MvinRs2(mvin_rows_bits, mvin_cols_bits, local_addr_t), new PreloadRs(mvin_rows_bits, mvin_cols_bits, local_addr_t), new PreloadRs(mvout_rows_bits, mvout_cols_bits, local_addr_t), diff --git a/src/main/scala/gemmini/GemminiConfigs.scala b/src/main/scala/gemmini/GemminiConfigs.scala index 98254299..40bc5303 100644 --- a/src/main/scala/gemmini/GemminiConfigs.scala +++ b/src/main/scala/gemmini/GemminiConfigs.scala @@ -88,6 +88,7 @@ case class GemminiArrayConfig[T <: Data : Arithmetic, U <: Data, V <: Data]( has_dw_convs: Boolean = true, has_normalizations: Boolean = false, has_first_layer_optimizations: Boolean = true, + has_loop_conv: Boolean = true, use_firesim_simulation_counters: Boolean = false,