Skip to content

Commit 79047bf

Browse files
authored
[AMD] Hipify torchaudio_decoder
Differential Revision: D64298970 Pull Request resolved: #3843
1 parent b4a286a commit 79047bf

8 files changed

+36
-30
lines changed

src/libtorchaudio/cuctc/include/ctc_prefix_decoder.h

-1
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,6 @@
2626
#ifndef __ctc_prefix_decoder_h_
2727
#define __ctc_prefix_decoder_h_
2828

29-
#include <cuda_runtime.h>
3029
#include <cstdint>
3130
#include <tuple>
3231
#include <vector>

src/libtorchaudio/cuctc/include/ctc_prefix_decoder_host.h

-18
Original file line numberDiff line numberDiff line change
@@ -26,24 +26,6 @@
2626
#ifndef __ctc_prefix_decoder_host_h_
2727
#define __ctc_prefix_decoder_host_h_
2828

29-
#include <cuda_runtime.h>
30-
31-
#define CUDA_CHECK(X) \
32-
do { \
33-
auto result = X; \
34-
if (result != cudaSuccess) { \
35-
const char* p_err_str = cudaGetErrorName(result); \
36-
fprintf( \
37-
stderr, \
38-
"File %s Line %d %s returned %s.\n", \
39-
__FILE__, \
40-
__LINE__, \
41-
#X, \
42-
p_err_str); \
43-
abort(); \
44-
} \
45-
} while (0)
46-
4729
#define CHECK(X, ERROR_INFO) \
4830
do { \
4931
auto result = (X); \

src/libtorchaudio/cuctc/src/bitonic_topk/bitonic_sort.cuh

+4
Original file line numberDiff line numberDiff line change
@@ -16,9 +16,13 @@ constexpr inline __host__ __device__ bool isPo2(IntType num) {
1616
}
1717

1818
inline __device__ int laneId() {
19+
#ifndef USE_ROCM
1920
int id;
2021
asm("mov.s32 %0, %%laneid;" : "=r"(id));
2122
return id;
23+
#else
24+
return __lane_id();
25+
#endif
2226
}
2327
/**
2428
* @brief Shuffle the data inside a warp

src/libtorchaudio/cuctc/src/bitonic_topk/pow2_utils.cuh

+2-1
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,8 @@ namespace cu_ctc {
1212
* @tparam IntType data type (checked only for integers)
1313
*/
1414
template <typename IntType>
15-
constexpr __device__ IntType log2(IntType num, IntType ret = IntType(0)) {
15+
constexpr __host__ __device__ IntType
16+
log2(IntType num, IntType ret = IntType(0)) {
1617
return num <= IntType(1) ? ret : log2(num >> IntType(1), ++ret);
1718
}
1819

src/libtorchaudio/cuctc/src/bitonic_topk/warpsort_topk.cuh

+4-4
Original file line numberDiff line numberDiff line change
@@ -313,7 +313,7 @@ class warp_sort_filtered : public warp_sort<Capacity, Ascending, T, IdxT> {
313313

314314
__device__ __forceinline__ void merge_buf_() {
315315
topk::bitonic<kMaxBufLen>(!Ascending, kWarpWidth).sort(val_buf_, idx_buf_);
316-
this->merge_in<kMaxBufLen>(val_buf_, idx_buf_);
316+
this->template merge_in<kMaxBufLen>(val_buf_, idx_buf_);
317317
buf_len_ = 0;
318318
set_k_th_(); // contains warp sync
319319
#pragma unroll
@@ -385,7 +385,7 @@ class warp_sort_immediate : public warp_sort<Capacity, Ascending, T, IdxT> {
385385
if (buf_len_ == kMaxArrLen) {
386386
topk::bitonic<kMaxArrLen>(!Ascending, kWarpWidth)
387387
.sort(val_buf_, idx_buf_);
388-
this->merge_in<kMaxArrLen>(val_buf_, idx_buf_);
388+
this->template merge_in<kMaxArrLen>(val_buf_, idx_buf_);
389389
#pragma unroll
390390
for (int i = 0; i < kMaxArrLen; i++) {
391391
val_buf_[i] = kDummy;
@@ -398,7 +398,7 @@ class warp_sort_immediate : public warp_sort<Capacity, Ascending, T, IdxT> {
398398
if (buf_len_ != 0) {
399399
topk::bitonic<kMaxArrLen>(!Ascending, kWarpWidth)
400400
.sort(val_buf_, idx_buf_);
401-
this->merge_in<kMaxArrLen>(val_buf_, idx_buf_);
401+
this->template merge_in<kMaxArrLen>(val_buf_, idx_buf_);
402402
}
403403
}
404404

@@ -421,7 +421,7 @@ constexpr inline __host__ __device__ IntType ceildiv(IntType a, IntType b) {
421421
return (a + b - 1) / b;
422422
}
423423
template <typename IntType>
424-
constexpr inline __device__ IntType roundUp256(IntType num) {
424+
constexpr inline __host__ __device__ IntType roundUp256(IntType num) {
425425
// return (num + 255) / 256 * 256;
426426
constexpr int MASK = 255;
427427
return (num + MASK) & (~MASK);

src/libtorchaudio/cuctc/src/ctc_prefix_decoder.cpp

+2-2
Original file line numberDiff line numberDiff line change
@@ -25,8 +25,8 @@
2525
// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
2626
#include <cuda_runtime.h>
2727

28-
#include "include/ctc_prefix_decoder.h"
29-
#include "include/ctc_prefix_decoder_host.h"
28+
#include "../include/ctc_prefix_decoder.h"
29+
#include "../include/ctc_prefix_decoder_host.h"
3030

3131
#include "device_data_wrap.h"
3232
#include "device_log_prob.cuh"

src/libtorchaudio/cuctc/src/ctc_prefix_decoder_kernel_v2.cu

+6-3
Original file line numberDiff line numberDiff line change
@@ -23,12 +23,13 @@
2323
// OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
2424
// (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
2525
// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
26+
#include <float.h>
2627
#include <algorithm>
28+
#include "../include/ctc_prefix_decoder_host.h"
2729
#include "ctc_fast_divmod.cuh"
2830
#include "cub/cub.cuh"
2931
#include "device_data_wrap.h"
3032
#include "device_log_prob.cuh"
31-
#include "include/ctc_prefix_decoder_host.h"
3233

3334
#include "bitonic_topk/warpsort_topk.cuh"
3435

@@ -630,7 +631,8 @@ int CTC_prob_first_step_V2(
630631
num_of_subwarp, beam));
631632
int smem_size =
632633
block_sort_smem_size + beam * sizeof(float) + beam * sizeof(int);
633-
FirstMatrixFuns[fun_idx]<<<grid, threads_per_block, smem_size, stream>>>(
634+
auto kernel = FirstMatrixFuns[fun_idx];
635+
kernel<<<grid, threads_per_block, smem_size, stream>>>(
634636
(*log_prob_struct),
635637
step,
636638
pprev,
@@ -766,7 +768,8 @@ int CTC_prob_topK_V2(
766768
int num_of_subwarp = threads_per_block0 / std::min<int>(32, actual_capacity);
767769
int smem_size = cu_ctc::topk::calc_smem_size_for_block_wide<float, int>(
768770
num_of_subwarp, beam);
769-
BitonicTopkFuns[fun_idx]<<<grid, block, smem_size, stream>>>(
771+
auto kernel = BitonicTopkFuns[fun_idx];
772+
kernel<<<grid, block, smem_size, stream>>>(
770773
(*log_prob_struct),
771774
step,
772775
ptable,

src/libtorchaudio/cuctc/src/device_data_wrap.h

+18-1
Original file line numberDiff line numberDiff line change
@@ -24,9 +24,26 @@
2424
// (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
2525
// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
2626
#pragma once
27+
#include <cuda_runtime.h>
2728
#include <iostream>
2829
#include <vector>
29-
#include "include/ctc_prefix_decoder_host.h"
30+
#include "../include/ctc_prefix_decoder_host.h"
31+
32+
#define CUDA_CHECK(X) \
33+
do { \
34+
auto result = X; \
35+
if (result != cudaSuccess) { \
36+
const char* p_err_str = cudaGetErrorName(result); \
37+
fprintf( \
38+
stderr, \
39+
"File %s Line %d %s returned %s.\n", \
40+
__FILE__, \
41+
__LINE__, \
42+
#X, \
43+
p_err_str); \
44+
abort(); \
45+
} \
46+
} while (0)
3047

3148
namespace cu_ctc {
3249
constexpr size_t ALIGN_BYTES = 128;

0 commit comments

Comments
 (0)