Skip to content

Commit d45c6af

Browse files
authored
[webgpu] Update INT_DIV (#7792)
1 parent 4a11300 commit d45c6af

File tree

4 files changed

+16
-52
lines changed

4 files changed

+16
-52
lines changed

tfjs-backend-webgpu/src/binary_op_util.ts

+11-33
Original file line numberDiff line numberDiff line change
@@ -23,9 +23,9 @@ export enum BinaryOpType {
2323
DIV,
2424
ELU_DER,
2525
EQUAL,
26+
FLOOR_DIV,
2627
GREATER,
2728
GREATER_EQUAL,
28-
INT_DIV,
2929
LESS,
3030
LESS_EQUAL,
3131
LOGICAL_AND,
@@ -56,6 +56,13 @@ const EQUAL = `
5656
let one = sign(b) * 0 + 1;
5757
let resultTemp = select(zero, one, a == b);
5858
`;
59+
const FLOOR_DIV = `
60+
let remainder =
61+
select(a % b, round(a % b), (round(a) == a) & (round(b) == b));
62+
let quotient = (a - remainder) / b;
63+
let resultTemp =
64+
round(select(quotient, quotient - 1, sign(remainder) == -sign(b)));
65+
`;
5966
const GREATER = `
6067
let zero = sign(a) * 0 + 0;
6168
let one = sign(b) * 0 + 1;
@@ -66,36 +73,6 @@ const GREATER_EQUAL = `
6673
let one = sign(b) * 0 + 1;
6774
let resultTemp = select(zero, one, a >= b);
6875
`;
69-
70-
const INT_DIV = `
71-
let s = sign(a) * sign(b);
72-
let ia = i32(round(a));
73-
let ib = i32(round(b));
74-
return f32(idiv(ia, ib, s));
75-
`;
76-
const INT_DIV_VEC4 = `
77-
let ia = vec4<i32>(round(a));
78-
let ib = vec4<i32>(round(b));
79-
let cond = ib != vec4<i32>(0);
80-
var resultTemp = vec4<i32>(0);
81-
let s = sign(a) * sign(b);
82-
83-
// Windows (D3D) wants guaranteed non-zero int division at compile-time.
84-
if (cond[0]) {
85-
resultTemp[0] = idiv(ia[0], ib[0], s[0]);
86-
}
87-
if (cond[1]) {
88-
resultTemp[1] = idiv(ia[1], ib[1], s[1]);
89-
}
90-
if (cond[2]) {
91-
resultTemp[2] = idiv(ia[2], ib[2], s[2]);
92-
}
93-
if (cond[3]) {
94-
resultTemp[3] = idiv(ia[3], ib[3], s[3]);
95-
}
96-
return vec4<f32>(resultTemp);
97-
`;
98-
9976
const LESS = `
10077
let zero = sign(a) * 0 + 0;
10178
let one = sign(b) * 0 + 1;
@@ -265,14 +242,15 @@ export function getBinaryOpString(
265242
case BinaryOpType.EQUAL:
266243
doOpSnippet = EQUAL;
267244
break;
245+
case BinaryOpType.FLOOR_DIV:
246+
doOpSnippet = FLOOR_DIV;
247+
break;
268248
case BinaryOpType.GREATER:
269249
doOpSnippet = GREATER;
270250
break;
271251
case BinaryOpType.GREATER_EQUAL:
272252
doOpSnippet = GREATER_EQUAL;
273253
break;
274-
case BinaryOpType.INT_DIV:
275-
return useVec4 ? INT_DIV_VEC4 : INT_DIV;
276254
case BinaryOpType.LESS:
277255
doOpSnippet = LESS;
278256
break;

tfjs-backend-webgpu/src/kernels/FloorDiv.ts

+5-3
Original file line numberDiff line numberDiff line change
@@ -21,9 +21,11 @@ import {BinaryOpType} from '../binary_op_util';
2121
import {binaryKernelFunc} from '../kernel_utils/kernel_funcs_utils';
2222
import {floorDivImplCPU} from '../kernel_utils/shared';
2323

24-
export const floorDiv =
25-
binaryKernelFunc({opType: BinaryOpType.INT_DIV,
26-
cpuKernelImpl: floorDivImplCPU, dtype: 'int32'});
24+
export const floorDiv = binaryKernelFunc({
25+
opType: BinaryOpType.FLOOR_DIV,
26+
cpuKernelImpl: floorDivImplCPU,
27+
dtype: 'int32'
28+
});
2729

2830
export const floorDivConfig: KernelConfig = {
2931
kernelName: FloorDiv,

tfjs-backend-webgpu/src/setup_test.ts

-7
Original file line numberDiff line numberDiff line change
@@ -168,13 +168,6 @@ const TEST_FILTERS: TestFilter[] = [
168168
'indices invalid',
169169
],
170170
},
171-
{
172-
startsWith: 'floorDiv ',
173-
excludes: [
174-
// float32 inputs with nonzero fractional part should not be rounded
175-
'floorDiv float32',
176-
],
177-
},
178171

179172
// exclude unsupported kernels and to be fixed cases
180173
{

tfjs-backend-webgpu/src/webgpu_program.ts

-9
Original file line numberDiff line numberDiff line change
@@ -417,15 +417,6 @@ const commonSnippet = `
417417
return coords.x*shapeStrides.x + coords.y*shapeStrides.y + coords.z*shapeStrides.z + coords.w*shapeStrides.w + coords.u*shapeStrides.u + coords.v*shapeStrides.v;
418418
}
419419
420-
fn idiv(a: i32, b: i32, sign: f32) -> i32 {
421-
var res: i32 = a / b;
422-
let modulo: i32 = a % b;
423-
if (sign < 0. && modulo != 0) {
424-
res = res - 1;
425-
}
426-
return res;
427-
}
428-
429420
// NaN defination in IEEE 754-1985 is :
430421
// - sign = either 0 or 1.
431422
// - biased exponent = all 1 bits.

0 commit comments

Comments
 (0)