Skip to content

Commit db099a4

Browse files
Add Draw API (#7628)
FEATURE * add draw * add tests * add comments * lint * upd * typo * cmt * rename contextOption to canvasOption * Update tfjs-core/src/ops/browser.ts Co-authored-by: Matthew Soulanille <[email protected]> * Update tfjs-core/src/ops/draw_test.ts Co-authored-by: Matthew Soulanille <[email protected]> * warn once * lint * unify options * typo * tune * tune * tune * typo --------- Co-authored-by: Matthew Soulanille <[email protected]>
1 parent cb9a98b commit db099a4

File tree

9 files changed

+340
-38
lines changed

9 files changed

+340
-38
lines changed

tfjs-backend-cpu/BUILD.bazel

+1-1
Original file line numberDiff line numberDiff line change
@@ -41,7 +41,7 @@ tfjs_web_test(
4141
],
4242
headless = False,
4343
presubmit_browsers = [
44-
"bs_safari_mac",
44+
"bs_chrome_mac",
4545
],
4646
)
4747

tfjs-backend-cpu/src/kernels/Draw.ts

+96
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,96 @@
1+
/**
2+
* @license
3+
* Copyright 2023 Google LLC.
4+
* Licensed under the Apache License, Version 2.0 (the "License");
5+
* you may not use this file except in compliance with the License.
6+
* You may obtain a copy of the License at
7+
*
8+
* http://www.apache.org/licenses/LICENSE-2.0
9+
*
10+
* Unless required by applicable law or agreed to in writing, software
11+
* distributed under the License is distributed on an "AS IS" BASIS,
12+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
* See the License for the specific language governing permissions and
14+
* limitations under the License.
15+
* =============================================================================
16+
*/
17+
18+
import {Draw, DrawAttrs, DrawInputs, KernelConfig, KernelFunc, TypedArray} from '@tensorflow/tfjs-core';
19+
import {TensorInfo} from '@tensorflow/tfjs-core';
20+
21+
import {MathBackendCPU} from '../backend_cpu';
22+
23+
export function draw(
24+
args: {inputs: DrawInputs, backend: MathBackendCPU, attrs: DrawAttrs}):
25+
TensorInfo {
26+
const {inputs, backend, attrs} = args;
27+
const {image} = inputs;
28+
const {canvas, options} = attrs;
29+
const {contextOptions, imageOptions} = options || {};
30+
const alpha = imageOptions ?.alpha || 1;
31+
32+
const contextType = contextOptions ?.contextType || '2d';
33+
if (contextType !== '2d') {
34+
throw new Error(`Context type ${
35+
contextOptions.contextType} is not supported by the CPU backend.`);
36+
}
37+
const ctx = canvas.getContext(contextType,
38+
contextOptions?.contextAttributes || {}) as CanvasRenderingContext2D ;
39+
if (ctx == null) {
40+
throw new Error(`Could not get the context with ${contextType} type.`);
41+
}
42+
43+
const [height, width] = image.shape.slice(0, 2);
44+
const depth = image.shape.length === 2 ? 1 : image.shape[2];
45+
const data = backend.data.get(image.dataId).values as TypedArray;
46+
const multiplier = image.dtype === 'float32' ? 255 : 1;
47+
const bytes = new Uint8ClampedArray(width * height * 4);
48+
49+
for (let i = 0; i < height * width; ++i) {
50+
const rgba = [0, 0, 0, 255 * alpha];
51+
52+
for (let d = 0; d < depth; d++) {
53+
const value = data[i * depth + d];
54+
55+
if (image.dtype === 'float32') {
56+
if (value < 0 || value > 1) {
57+
throw new Error(
58+
`Tensor values for a float32 Tensor must be in the ` +
59+
`range [0 - 1] but encountered ${value}.`);
60+
}
61+
} else if (image.dtype === 'int32') {
62+
if (value < 0 || value > 255) {
63+
throw new Error(
64+
`Tensor values for a int32 Tensor must be in the ` +
65+
`range [0 - 255] but encountered ${value}.`);
66+
}
67+
}
68+
69+
if (depth === 1) {
70+
rgba[0] = value * multiplier;
71+
rgba[1] = value * multiplier;
72+
rgba[2] = value * multiplier;
73+
} else {
74+
rgba[d] = value * multiplier;
75+
}
76+
}
77+
78+
const j = i * 4;
79+
bytes[j + 0] = Math.round(rgba[0]);
80+
bytes[j + 1] = Math.round(rgba[1]);
81+
bytes[j + 2] = Math.round(rgba[2]);
82+
bytes[j + 3] = Math.round(rgba[3]);
83+
}
84+
85+
canvas.width = width;
86+
canvas.height = height;
87+
const imageData = new ImageData(bytes, width, height);
88+
ctx.putImageData(imageData, 0, 0);
89+
return image;
90+
}
91+
92+
export const drawConfig: KernelConfig = {
93+
kernelName: Draw,
94+
backendName: 'cpu',
95+
kernelFunc: draw as unknown as KernelFunc
96+
};

tfjs-backend-cpu/src/register_all_kernels.ts

+2
Original file line numberDiff line numberDiff line change
@@ -70,6 +70,7 @@ import {diagConfig} from './kernels/Diag';
7070
import {dilation2DConfig} from './kernels/Dilation2D';
7171
import {dilation2DBackpropFilterConfig} from './kernels/Dilation2DBackpropFilter';
7272
import {dilation2DBackpropInputConfig} from './kernels/Dilation2DBackpropInput';
73+
import {drawConfig} from './kernels/Draw';
7374
import {einsumConfig} from './kernels/Einsum';
7475
import {eluConfig} from './kernels/Elu';
7576
import {eluGradConfig} from './kernels/EluGrad';
@@ -244,6 +245,7 @@ const kernelConfigs: KernelConfig[] = [
244245
dilation2DConfig,
245246
dilation2DBackpropFilterConfig,
246247
dilation2DBackpropInputConfig,
248+
drawConfig,
247249
einsumConfig,
248250
eluConfig,
249251
eluGradConfig,

tfjs-backend-webgl/src/setup_test.ts

+1-1
Original file line numberDiff line numberDiff line change
@@ -36,7 +36,7 @@ const customInclude = (testName: string) => {
3636
'isBrowser: false', 'dilation gradient',
3737
'throws when index is out of bound',
3838
// otsu tests for threshold op is failing on windows
39-
'method otsu'
39+
'method otsu', 'Draw on 2d context'
4040
];
4141
for (const subStr of toExclude) {
4242
if (testName.includes(subStr)) {

tfjs-backend-webgpu/src/setup_test.ts

+6
Original file line numberDiff line numberDiff line change
@@ -124,6 +124,12 @@ const TEST_FILTERS: TestFilter[] = [
124124
'canvas and image match', // Failing on Linux
125125
],
126126
},
127+
{
128+
startsWith: 'Draw',
129+
excludes: [
130+
'on 2d context',
131+
]
132+
},
127133
{
128134
startsWith: 'sign ',
129135
excludes: [

tfjs-core/src/kernel_names.ts

+8-1
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,7 @@ import {NamedTensorInfoMap} from './kernel_registry';
2222
import {ExplicitPadding} from './ops/conv_util';
2323
import {Activation} from './ops/fused_types';
2424
import {TensorInfo} from './tensor_info';
25-
import {DataType, PixelData} from './types';
25+
import {DataType, DrawOptions, PixelData} from './types';
2626

2727
export const Abs = 'Abs';
2828
export type AbsInputs = UnaryInputs;
@@ -335,6 +335,13 @@ export const Dilation2DBackpropFilter = 'Dilation2DBackpropFilter';
335335
export type Dilation2DBackpropFilterInputs =
336336
Pick<NamedTensorInfoMap, 'x'|'filter'|'dy'>;
337337

338+
export const Draw = 'Draw';
339+
export type DrawInputs = Pick<NamedTensorInfoMap, 'image'>;
340+
export interface DrawAttrs {
341+
canvas: HTMLCanvasElement;
342+
options?: DrawOptions;
343+
}
344+
338345
export const RealDiv = 'RealDiv';
339346
export type RealDivInputs = BinaryInputs;
340347

tfjs-core/src/ops/browser.ts

+81-22
Original file line numberDiff line numberDiff line change
@@ -17,18 +17,19 @@
1717

1818
import {ENGINE} from '../engine';
1919
import {env} from '../environment';
20-
import {FromPixels, FromPixelsAttrs, FromPixelsInputs} from '../kernel_names';
20+
import {Draw, DrawAttrs, DrawInputs, FromPixels, FromPixelsAttrs, FromPixelsInputs} from '../kernel_names';
2121
import {getKernel, NamedAttrMap} from '../kernel_registry';
2222
import {Tensor, Tensor2D, Tensor3D} from '../tensor';
2323
import {NamedTensorMap} from '../tensor_types';
2424
import {convertToTensor} from '../tensor_util_env';
25-
import {PixelData, TensorLike} from '../types';
25+
import {DrawOptions, ImageOptions, PixelData, TensorLike} from '../types';
2626

2727
import {cast} from './cast';
2828
import {op} from './operation';
2929
import {tensor3d} from './tensor3d';
3030

3131
let fromPixels2DContext: CanvasRenderingContext2D;
32+
let hasToPixelsWarned = false;
3233

3334
/**
3435
* Creates a `tf.Tensor` from an image.
@@ -145,9 +146,8 @@ function fromPixels_(
145146
'Reason: OffscreenCanvas Context2D rendering is not supported.');
146147
}
147148
} else {
148-
fromPixels2DContext =
149-
document.createElement('canvas').getContext(
150-
'2d', {willReadFrequently: true});
149+
fromPixels2DContext = document.createElement('canvas').getContext(
150+
'2d', {willReadFrequently: true});
151151
}
152152
}
153153
fromPixels2DContext.canvas.width = width;
@@ -269,6 +269,33 @@ export async function fromPixelsAsync(
269269
return fromPixels_(inputs, numChannels);
270270
}
271271

272+
function validateImgTensor(img: Tensor2D|Tensor3D) {
273+
if (img.rank !== 2 && img.rank !== 3) {
274+
throw new Error(
275+
`toPixels only supports rank 2 or 3 tensors, got rank ${img.rank}.`);
276+
}
277+
const depth = img.rank === 2 ? 1 : img.shape[2];
278+
279+
if (depth > 4 || depth === 2) {
280+
throw new Error(
281+
`toPixels only supports depth of size ` +
282+
`1, 3 or 4 but got ${depth}`);
283+
}
284+
285+
if (img.dtype !== 'float32' && img.dtype !== 'int32') {
286+
throw new Error(
287+
`Unsupported type for toPixels: ${img.dtype}.` +
288+
` Please use float32 or int32 tensors.`);
289+
}
290+
}
291+
292+
function validateImageOptions(imageOptions: ImageOptions) {
293+
const alpha = imageOptions ?.alpha || 1;
294+
if (alpha > 1 || alpha < 0) {
295+
throw new Error(`Alpha value ${alpha} is suppoed to be in range [0 - 1].`);
296+
}
297+
}
298+
272299
/**
273300
* Draws a `tf.Tensor` of pixel values to a byte array or optionally a
274301
* canvas.
@@ -299,25 +326,10 @@ export async function toPixels(
299326
$img = cast(originalImgTensor, 'int32');
300327
originalImgTensor.dispose();
301328
}
302-
if ($img.rank !== 2 && $img.rank !== 3) {
303-
throw new Error(
304-
`toPixels only supports rank 2 or 3 tensors, got rank ${$img.rank}.`);
305-
}
329+
validateImgTensor($img);
330+
306331
const [height, width] = $img.shape.slice(0, 2);
307332
const depth = $img.rank === 2 ? 1 : $img.shape[2];
308-
309-
if (depth > 4 || depth === 2) {
310-
throw new Error(
311-
`toPixels only supports depth of size ` +
312-
`1, 3 or 4 but got ${depth}`);
313-
}
314-
315-
if ($img.dtype !== 'float32' && $img.dtype !== 'int32') {
316-
throw new Error(
317-
`Unsupported type for toPixels: ${$img.dtype}.` +
318-
` Please use float32 or int32 tensors.`);
319-
}
320-
321333
const data = await $img.data();
322334
const multiplier = $img.dtype === 'float32' ? 255 : 1;
323335
const bytes = new Uint8ClampedArray(width * height * 4);
@@ -359,6 +371,13 @@ export async function toPixels(
359371
}
360372

361373
if (canvas != null) {
374+
if (!hasToPixelsWarned) {
375+
console.warn(
376+
'tf.browser.toPixels is not efficient to draw tensor on canvas. ' +
377+
'Please try tf.browser.draw instead.');
378+
hasToPixelsWarned = true;
379+
}
380+
362381
canvas.width = width;
363382
canvas.height = height;
364383
const ctx = canvas.getContext('2d');
@@ -371,4 +390,44 @@ export async function toPixels(
371390
return bytes;
372391
}
373392

393+
/**
394+
* Draws a `tf.Tensor` to a canvas.
395+
*
396+
* When the dtype of the input is 'float32', we assume values in the range
397+
* [0-1]. Otherwise, when input is 'int32', we assume values in the range
398+
* [0-255].
399+
*
400+
* @param image The tensor to draw on the canvas. Must match one of
401+
* these shapes:
402+
* - Rank-2 with shape `[height, width`]: Drawn as grayscale.
403+
* - Rank-3 with shape `[height, width, 1]`: Drawn as grayscale.
404+
* - Rank-3 with shape `[height, width, 3]`: Drawn as RGB with alpha set in
405+
* `imageOptions` (defaults to 1, which is opaque).
406+
* - Rank-3 with shape `[height, width, 4]`: Drawn as RGBA.
407+
* @param canvas The canvas to draw to.
408+
* @param options The configuration arguments for image to be drawn and the
409+
* canvas to draw to.
410+
*
411+
* @doc {heading: 'Browser', namespace: 'browser'}
412+
*/
413+
export function draw(
414+
image: Tensor2D|Tensor3D|TensorLike, canvas: HTMLCanvasElement,
415+
options?: DrawOptions): void {
416+
let $img = convertToTensor(image, 'img', 'draw');
417+
if (!(image instanceof Tensor)) {
418+
// Assume int32 if user passed a native array.
419+
const originalImgTensor = $img;
420+
$img = cast(originalImgTensor, 'int32');
421+
originalImgTensor.dispose();
422+
}
423+
validateImgTensor($img);
424+
validateImageOptions(options?.imageOptions);
425+
426+
const inputs: DrawInputs = {image: $img};
427+
const attrs: DrawAttrs = {canvas, options};
428+
ENGINE.runKernel(
429+
Draw, inputs as unknown as NamedTensorMap,
430+
attrs as unknown as NamedAttrMap);
431+
}
432+
374433
export const fromPixels = /* @__PURE__ */ op({fromPixels_});

0 commit comments

Comments
 (0)