17
17
18
18
import { ENGINE } from '../engine' ;
19
19
import { env } from '../environment' ;
20
- import { FromPixels , FromPixelsAttrs , FromPixelsInputs } from '../kernel_names' ;
20
+ import { Draw , DrawAttrs , DrawInputs , FromPixels , FromPixelsAttrs , FromPixelsInputs } from '../kernel_names' ;
21
21
import { getKernel , NamedAttrMap } from '../kernel_registry' ;
22
22
import { Tensor , Tensor2D , Tensor3D } from '../tensor' ;
23
23
import { NamedTensorMap } from '../tensor_types' ;
24
24
import { convertToTensor } from '../tensor_util_env' ;
25
- import { PixelData , TensorLike } from '../types' ;
25
+ import { DrawOptions , ImageOptions , PixelData , TensorLike } from '../types' ;
26
26
27
27
import { cast } from './cast' ;
28
28
import { op } from './operation' ;
29
29
import { tensor3d } from './tensor3d' ;
30
30
31
31
let fromPixels2DContext : CanvasRenderingContext2D ;
32
+ let hasToPixelsWarned = false ;
32
33
33
34
/**
34
35
* Creates a `tf.Tensor` from an image.
@@ -145,9 +146,8 @@ function fromPixels_(
145
146
'Reason: OffscreenCanvas Context2D rendering is not supported.' ) ;
146
147
}
147
148
} else {
148
- fromPixels2DContext =
149
- document . createElement ( 'canvas' ) . getContext (
150
- '2d' , { willReadFrequently : true } ) ;
149
+ fromPixels2DContext = document . createElement ( 'canvas' ) . getContext (
150
+ '2d' , { willReadFrequently : true } ) ;
151
151
}
152
152
}
153
153
fromPixels2DContext . canvas . width = width ;
@@ -269,6 +269,33 @@ export async function fromPixelsAsync(
269
269
return fromPixels_ ( inputs , numChannels ) ;
270
270
}
271
271
272
+ function validateImgTensor ( img : Tensor2D | Tensor3D ) {
273
+ if ( img . rank !== 2 && img . rank !== 3 ) {
274
+ throw new Error (
275
+ `toPixels only supports rank 2 or 3 tensors, got rank ${ img . rank } .` ) ;
276
+ }
277
+ const depth = img . rank === 2 ? 1 : img . shape [ 2 ] ;
278
+
279
+ if ( depth > 4 || depth === 2 ) {
280
+ throw new Error (
281
+ `toPixels only supports depth of size ` +
282
+ `1, 3 or 4 but got ${ depth } ` ) ;
283
+ }
284
+
285
+ if ( img . dtype !== 'float32' && img . dtype !== 'int32' ) {
286
+ throw new Error (
287
+ `Unsupported type for toPixels: ${ img . dtype } .` +
288
+ ` Please use float32 or int32 tensors.` ) ;
289
+ }
290
+ }
291
+
292
+ function validateImageOptions ( imageOptions : ImageOptions ) {
293
+ const alpha = imageOptions ?. alpha || 1 ;
294
+ if ( alpha > 1 || alpha < 0 ) {
295
+ throw new Error ( `Alpha value ${ alpha } is suppoed to be in range [0 - 1].` ) ;
296
+ }
297
+ }
298
+
272
299
/**
273
300
* Draws a `tf.Tensor` of pixel values to a byte array or optionally a
274
301
* canvas.
@@ -299,25 +326,10 @@ export async function toPixels(
299
326
$img = cast ( originalImgTensor , 'int32' ) ;
300
327
originalImgTensor . dispose ( ) ;
301
328
}
302
- if ( $img . rank !== 2 && $img . rank !== 3 ) {
303
- throw new Error (
304
- `toPixels only supports rank 2 or 3 tensors, got rank ${ $img . rank } .` ) ;
305
- }
329
+ validateImgTensor ( $img ) ;
330
+
306
331
const [ height , width ] = $img . shape . slice ( 0 , 2 ) ;
307
332
const depth = $img . rank === 2 ? 1 : $img . shape [ 2 ] ;
308
-
309
- if ( depth > 4 || depth === 2 ) {
310
- throw new Error (
311
- `toPixels only supports depth of size ` +
312
- `1, 3 or 4 but got ${ depth } ` ) ;
313
- }
314
-
315
- if ( $img . dtype !== 'float32' && $img . dtype !== 'int32' ) {
316
- throw new Error (
317
- `Unsupported type for toPixels: ${ $img . dtype } .` +
318
- ` Please use float32 or int32 tensors.` ) ;
319
- }
320
-
321
333
const data = await $img . data ( ) ;
322
334
const multiplier = $img . dtype === 'float32' ? 255 : 1 ;
323
335
const bytes = new Uint8ClampedArray ( width * height * 4 ) ;
@@ -359,6 +371,13 @@ export async function toPixels(
359
371
}
360
372
361
373
if ( canvas != null ) {
374
+ if ( ! hasToPixelsWarned ) {
375
+ console . warn (
376
+ 'tf.browser.toPixels is not efficient to draw tensor on canvas. ' +
377
+ 'Please try tf.browser.draw instead.' ) ;
378
+ hasToPixelsWarned = true ;
379
+ }
380
+
362
381
canvas . width = width ;
363
382
canvas . height = height ;
364
383
const ctx = canvas . getContext ( '2d' ) ;
@@ -371,4 +390,44 @@ export async function toPixels(
371
390
return bytes ;
372
391
}
373
392
393
+ /**
394
+ * Draws a `tf.Tensor` to a canvas.
395
+ *
396
+ * When the dtype of the input is 'float32', we assume values in the range
397
+ * [0-1]. Otherwise, when input is 'int32', we assume values in the range
398
+ * [0-255].
399
+ *
400
+ * @param image The tensor to draw on the canvas. Must match one of
401
+ * these shapes:
402
+ * - Rank-2 with shape `[height, width`]: Drawn as grayscale.
403
+ * - Rank-3 with shape `[height, width, 1]`: Drawn as grayscale.
404
+ * - Rank-3 with shape `[height, width, 3]`: Drawn as RGB with alpha set in
405
+ * `imageOptions` (defaults to 1, which is opaque).
406
+ * - Rank-3 with shape `[height, width, 4]`: Drawn as RGBA.
407
+ * @param canvas The canvas to draw to.
408
+ * @param options The configuration arguments for image to be drawn and the
409
+ * canvas to draw to.
410
+ *
411
+ * @doc {heading: 'Browser', namespace: 'browser'}
412
+ */
413
+ export function draw (
414
+ image : Tensor2D | Tensor3D | TensorLike , canvas : HTMLCanvasElement ,
415
+ options ?: DrawOptions ) : void {
416
+ let $img = convertToTensor ( image , 'img' , 'draw' ) ;
417
+ if ( ! ( image instanceof Tensor ) ) {
418
+ // Assume int32 if user passed a native array.
419
+ const originalImgTensor = $img ;
420
+ $img = cast ( originalImgTensor , 'int32' ) ;
421
+ originalImgTensor . dispose ( ) ;
422
+ }
423
+ validateImgTensor ( $img ) ;
424
+ validateImageOptions ( options ?. imageOptions ) ;
425
+
426
+ const inputs : DrawInputs = { image : $img } ;
427
+ const attrs : DrawAttrs = { canvas, options} ;
428
+ ENGINE . runKernel (
429
+ Draw , inputs as unknown as NamedTensorMap ,
430
+ attrs as unknown as NamedAttrMap ) ;
431
+ }
432
+
374
433
export const fromPixels = /* @__PURE__ */ op ( { fromPixels_} ) ;
0 commit comments