Skip to content

Commit b4a286a

Browse files
authored
[AMD] hipify torchaudio
Differential Revision: D64184710. Pull Request resolved: #3840.
1 parent 3f05699 commit b4a286a

File tree

2 files changed

+24
-0
lines changed

2 files changed

+24
-0
lines changed

src/libtorchaudio/rnnt/gpu/gpu_kernel_utils.cuh

+8
Original file line number | Diff line number | Diff line change
@@ -39,7 +39,11 @@ __global__ void ReduceMax2D(
39 39

40 40
CAST_DTYPE shf;
41 41
for (int stride = (WARP_SIZE >> 1); stride > 0; stride >>= 1) {
42 +
#ifndef USE_ROCM
42 43
shf = __shfl_down_sync(0xFFFFFFFF, val, stride);
44 +
#else
45 +
shf = __shfl_down(val, stride);
46 +
#endif
43 47
if (threadIdx.x < stride && threadIdx.x + stride < dim) {
44 48
if (shf > val) {
45 49
val = shf;
@@ -81,7 +85,11 @@ __global__ void ReduceLogSumExpGivenMax2D(
81 85

82 86
CAST_DTYPE shf;
83 87
for (int stride = (WARP_SIZE >> 1); stride > 0; stride >>= 1) {
88 +
#ifndef USE_ROCM
84 89
shf = __shfl_down_sync(0xFFFFFFFF, val, stride);
90 +
#else
91 +
shf = __shfl_down(val, stride);
92 +
#endif
85 93
if (threadIdx.x < stride && threadIdx.x + stride < dim) {
86 94
val = val + shf;
87 95
}

src/libtorchaudio/rnnt/gpu/gpu_kernels.cuh

+16
Original file line number | Diff line number | Diff line change
@@ -126,7 +126,11 @@ __device__ void ComputeAlphas(
126 126

127 127
#pragma unroll
128 128
for (int i = 1; i < warpSize; i <<= 1) {
129 +
#ifndef USE_ROCM
129 130
val = __shfl_up_sync(0xffffffff, skip_prob, i);
131 +
#else
132 +
val = __shfl_up(skip_prob, i);
133 +
#endif
130 134
if (i <= threadIdx.x) {
131 135
skip_prob = skip_prob + val;
132 136
}
@@ -150,7 +154,11 @@ __device__ void ComputeAlphas(
150 154
CAST_DTYPE out = val;
151 155

152 156
for (int i = 1; i < warpSize; ++i) {
157 +
#ifndef USE_ROCM
153 158
val = __shfl_up_sync(0xffffffff, val, 1);
159 +
#else
160 +
val = __shfl_up(val, 1);
161 +
#endif
154 162
if (i == threadIdx.x) {
155 163
val = math::lse(val + skip_prob, emit);
156 164
out = val;
@@ -225,7 +233,11 @@ __device__ void ComputeBetasCosts(
225 233

226 234
#pragma unroll
227 235
for (int i = 1; i < warpSize; i <<= 1) {
236 +
#ifndef USE_ROCM
228 237
val = __shfl_up_sync(0xffffffff, skip_prob, i);
238 +
#else
239 +
val = __shfl_up(skip_prob, i);
240 +
#endif
229 241
if (i <= threadIdx.x) {
230 242
skip_prob = skip_prob + val;
231 243
}
@@ -248,7 +260,11 @@ __device__ void ComputeBetasCosts(
248 260
CAST_DTYPE out = val;
249 261

250 262
for (int i = 1; i < warpSize; ++i) {
263 +
#ifndef USE_ROCM
251 264
val = __shfl_up_sync(0xffffffff, val, 1);
265 +
#else
266 +
val = __shfl_up(val, 1);
267 +
#endif
252 268
if (i == threadIdx.x) {
253 269
val = math::lse(val + skip_prob, emit);
254 270
out = val;

0 commit comments

Comments
 (0)