Skip to content

Commit a09edd1

Browse files
authored
[WebGPU] support ResizeNearestNeighborGrad kernel (#7354)
1 parent 963589e commit a09edd1

File tree

4 files changed

+187
-6
lines changed

4 files changed

+187
-6
lines changed
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,72 @@
1+
/**
2+
* @license
3+
* Copyright 2023 Google LLC.
4+
* Licensed under the Apache License, Version 2.0 (the "License");
5+
* you may not use this file except in compliance with the License.
6+
* You may obtain a copy of the License at
7+
*
8+
* http://www.apache.org/licenses/LICENSE-2.0
9+
*
10+
* Unless required by applicable law or agreed to in writing, software
11+
* distributed under the License is distributed on an "AS IS" BASIS,
12+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
* See the License for the specific language governing permissions and
14+
* limitations under the License.
15+
* =============================================================================
16+
*/
17+
18+
import {KernelConfig, KernelFunc, ResizeNearestNeighborGrad, ResizeNearestNeighborGradAttrs, ResizeNearestNeighborGradInputs, TensorInfo} from '@tensorflow/tfjs-core';
19+
20+
import {WebGPUBackend} from '../backend_webgpu';
21+
import {ResizeNearestNeigborBackpropProgram} from '../resize_nearest_neighbor_backprop_webgpu';
22+
23+
export function resizeNearestNeighborGrad(args: {
24+
inputs: ResizeNearestNeighborGradInputs,
25+
backend: WebGPUBackend,
26+
attrs: ResizeNearestNeighborGradAttrs
27+
}): TensorInfo {
28+
const {inputs, backend, attrs} = args;
29+
const {images, dy} = inputs;
30+
const {alignCorners} = attrs;
31+
32+
const [, xHeight, xWidth] = images.shape as [number, number, number, number];
33+
const [, yHeight, yWidth] = dy.shape as [number, number, number, number];
34+
35+
const effectiveXSize: [number, number] = [
36+
(alignCorners && yHeight > 1) ? xHeight - 1 : xHeight,
37+
(alignCorners && yWidth > 1) ? xWidth - 1 : xWidth
38+
];
39+
40+
const effectiveYSize: [number, number] = [
41+
(alignCorners && yHeight > 1) ? yHeight - 1 : yHeight,
42+
(alignCorners && yWidth > 1) ? yWidth - 1 : yWidth
43+
];
44+
45+
const heightScale = effectiveXSize[0] / effectiveYSize[0];
46+
const widthScale = effectiveXSize[1] / effectiveYSize[1];
47+
48+
const invHeightScale = 1 / heightScale;
49+
const invWidthScale = 1 / widthScale;
50+
51+
// This defines the size of the window of values around a particular
52+
// index in dy that we want to search for contributions to dx.
53+
const winHeight = (Math.ceil(invHeightScale) * 2) + 2;
54+
const winWidth = (Math.ceil(invWidthScale) * 2) + 2;
55+
56+
const program = new ResizeNearestNeigborBackpropProgram(
57+
images.shape as [number, number, number, number], alignCorners);
58+
const uniformData = [
59+
{type: 'int32', data: effectiveXSize},
60+
{type: 'int32', data: effectiveYSize},
61+
{type: 'float32', data: [invHeightScale]},
62+
{type: 'float32', data: [invWidthScale]},
63+
{type: 'int32', data: [winHeight]}, {type: 'int32', data: [winWidth]}
64+
];
65+
return backend.runWebGPUProgram(program, [dy], dy.dtype, uniformData);
66+
}
67+
68+
export const resizeNearestNeighborGradConfig: KernelConfig = {
69+
kernelName: ResizeNearestNeighborGrad,
70+
backendName: 'webgpu',
71+
kernelFunc: resizeNearestNeighborGrad as unknown as KernelFunc
72+
};

tfjs-backend-webgpu/src/register_all_kernels.ts

+2
Original file line numberDiff line numberDiff line change
@@ -128,6 +128,7 @@ import {relu6Config} from './kernels/Relu6';
128128
import {reshapeConfig} from './kernels/Reshape';
129129
import {resizeBilinearConfig} from './kernels/ResizeBilinear';
130130
import {resizeNearestNeighborConfig} from './kernels/ResizeNearestNeighbor';
131+
import {resizeNearestNeighborGradConfig} from './kernels/ResizeNearestNeighborGrad';
131132
import {reverseConfig} from './kernels/Reverse';
132133
import {rotateWithOffsetConfig} from './kernels/RotateWithOffset';
133134
import {roundConfig} from './kernels/Round';
@@ -277,6 +278,7 @@ const kernelConfigs: KernelConfig[] = [
277278
reshapeConfig,
278279
resizeBilinearConfig,
279280
resizeNearestNeighborConfig,
281+
resizeNearestNeighborGradConfig,
280282
reverseConfig,
281283
rotateWithOffsetConfig,
282284
roundConfig,
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,113 @@
1+
/**
2+
* @license
3+
* Copyright 2023 Google LLC.
4+
* Licensed under the Apache License, Version 2.0 (the "License");
5+
* you may not use this file except in compliance with the License.
6+
* You may obtain a copy of the License at
7+
*
8+
* http://www.apache.org/licenses/LICENSE-2.0
9+
*
10+
* Unless required by applicable law or agreed to in writing, software
11+
* distributed under the License is distributed on an "AS IS" BASIS,
12+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
* See the License for the specific language governing permissions and
14+
* limitations under the License.
15+
* =============================================================================
16+
*/
17+
18+
import {getMainHeaderString as main, WebGPUProgram} from './webgpu_program';
19+
import {computeDispatch, flatDispatchLayout} from './webgpu_util';
20+
21+
export class ResizeNearestNeigborBackpropProgram implements WebGPUProgram {
22+
outputShape: number[];
23+
shaderKey: string;
24+
dispatchLayout: {x: number[]};
25+
dispatch: [number, number, number];
26+
variableNames = ['dy'];
27+
uniforms =
28+
`effectiveXSize : vec2<i32>, effectiveYSize : vec2<i32>, invHeightScale : f32, invWidthScale : f32,
29+
winHeight : i32, winWidth : i32,`;
30+
workgroupSize: [number, number, number] = [64, 1, 1];
31+
alignCorners: boolean;
32+
size = true;
33+
34+
constructor(
35+
inputShape: [number, number, number, number], alignCorners: boolean) {
36+
this.outputShape = inputShape;
37+
38+
this.dispatchLayout = flatDispatchLayout(this.outputShape);
39+
this.dispatch = computeDispatch(
40+
this.dispatchLayout, this.outputShape, this.workgroupSize);
41+
42+
this.alignCorners = alignCorners;
43+
this.shaderKey = `resizeNearestNeigborBackprop_${alignCorners}`;
44+
}
45+
46+
getUserCode(): string {
47+
const userCode = `
48+
${main('index')} {
49+
if (index < uniforms.size) {
50+
let coords = getOutputCoords();
51+
let b = coords[0];
52+
let d = coords[3];
53+
let r = coords[1];
54+
let c = coords[2];
55+
56+
var accumulator = 0.0;
57+
58+
// Compute bounds for where in dy we will look
59+
let startRLerp = floor(f32(r) * uniforms.invHeightScale);
60+
let startDyR = i32(floor(startRLerp - f32(uniforms.winHeight / 2)));
61+
62+
let startCLerp = floor(f32(c) * uniforms.invWidthScale);
63+
let startDyC = i32(floor(startCLerp - f32(uniforms.winWidth / 2)));
64+
65+
// Loop over dy
66+
for (var dyROffset = 0; dyROffset < uniforms.winHeight; dyROffset++) {
67+
let dyR = startDyR + dyROffset;
68+
69+
// Guard against the window exceeding the bounds of dy
70+
if (dyR < 0 || dyR >= uniforms.dyShape[1]) {
71+
continue;
72+
}
73+
74+
for (var dyCOffset = 0; dyCOffset < uniforms.winWidth; dyCOffset++) {
75+
let dyC = startDyC + dyCOffset;
76+
77+
// Guard against the window exceeding the bounds of dy
78+
if (dyC < 0 || dyC >= uniforms.dyShape[2]) {
79+
continue;
80+
}
81+
82+
let sourceFracRow = f32(uniforms.effectiveXSize[0]) *
83+
(f32(dyR) / f32(uniforms.effectiveYSize[0]));
84+
85+
let sourceFracCol = f32(uniforms.effectiveXSize[1]) *
86+
(f32(dyC) / f32(uniforms.effectiveYSize[1]));
87+
88+
let sourceNearestRow =
89+
i32(min(f32(uniforms.outShape[1] - 1),
90+
${
91+
this.alignCorners ? 'floor(sourceFracRow + 0.5)' :
92+
'floor(sourceFracRow)'}));
93+
94+
let sourceNearestCol =
95+
i32(min(f32(uniforms.outShape[2] - 1),
96+
${
97+
this.alignCorners ? 'floor(sourceFracCol + 0.5)' :
98+
'floor(sourceFracCol)'}));
99+
100+
if (r == sourceNearestRow && c == sourceNearestCol) {
101+
accumulator += getDy(b, dyR, dyC, d);
102+
}
103+
}
104+
}
105+
// End loop over dy
106+
107+
setOutputAtIndex(index, accumulator);
108+
}
109+
}
110+
`;
111+
return userCode;
112+
}
113+
}

tfjs-backend-webgpu/src/setup_test.ts

-6
Original file line numberDiff line numberDiff line change
@@ -82,12 +82,6 @@ const TEST_FILTERS: TestFilter[] = [
8282
'gradients', // Not yet implemented
8383
]
8484
},
85-
{
86-
startsWith: 'resizeNearestNeighbor ',
87-
excludes: [
88-
'gradients', // Not yet implemented
89-
]
90-
},
9185

9286
// exclude unsupported kernels and to be fixed cases
9387
{

0 commit comments

Comments
 (0)