@@ -126,7 +126,11 @@ __device__ void ComputeAlphas(
126
126
127
127
#pragma unroll
128
128
for (int i = 1 ; i < warpSize ; i <<= 1 ) {
129
+ #ifndef USE_ROCM
129
130
val = __shfl_up_sync (0xffffffff , skip_prob, i);
131
+ #else
132
+ val = __shfl_up (skip_prob, i);
133
+ #endif
130
134
if (i <= threadIdx .x ) {
131
135
skip_prob = skip_prob + val;
132
136
}
@@ -150,7 +154,11 @@ __device__ void ComputeAlphas(
150
154
CAST_DTYPE out = val;
151
155
152
156
for (int i = 1 ; i < warpSize ; ++i) {
157
+ #ifndef USE_ROCM
153
158
val = __shfl_up_sync (0xffffffff , val, 1 );
159
+ #else
160
+ val = __shfl_up (val, 1 );
161
+ #endif
154
162
if (i == threadIdx .x ) {
155
163
val = math::lse (val + skip_prob, emit);
156
164
out = val;
@@ -225,7 +233,11 @@ __device__ void ComputeBetasCosts(
225
233
226
234
#pragma unroll
227
235
for (int i = 1 ; i < warpSize ; i <<= 1 ) {
236
+ #ifndef USE_ROCM
228
237
val = __shfl_up_sync (0xffffffff , skip_prob, i);
238
+ #else
239
+ val = __shfl_up (skip_prob, i);
240
+ #endif
229
241
if (i <= threadIdx .x ) {
230
242
skip_prob = skip_prob + val;
231
243
}
@@ -248,7 +260,11 @@ __device__ void ComputeBetasCosts(
248
260
CAST_DTYPE out = val;
249
261
250
262
for (int i = 1 ; i < warpSize ; ++i) {
263
+ #ifndef USE_ROCM
251
264
val = __shfl_up_sync (0xffffffff , val, 1 );
265
+ #else
266
+ val = __shfl_up (val, 1 );
267
+ #endif
252
268
if (i == threadIdx .x ) {
253
269
val = math::lse (val + skip_prob, emit);
254
270
out = val;
0 commit comments