1717
1818import { ENGINE } from '../engine' ;
1919import { env } from '../environment' ;
20- import { FromPixels , FromPixelsAttrs , FromPixelsInputs } from '../kernel_names' ;
20+ import { Draw , DrawAttrs , DrawInputs , FromPixels , FromPixelsAttrs , FromPixelsInputs } from '../kernel_names' ;
2121import { getKernel , NamedAttrMap } from '../kernel_registry' ;
2222import { Tensor , Tensor2D , Tensor3D } from '../tensor' ;
2323import { NamedTensorMap } from '../tensor_types' ;
2424import { convertToTensor } from '../tensor_util_env' ;
25- import { PixelData , TensorLike } from '../types' ;
25+ import { DrawOptions , ImageOptions , PixelData , TensorLike } from '../types' ;
2626
2727import { cast } from './cast' ;
2828import { op } from './operation' ;
2929import { tensor3d } from './tensor3d' ;
3030
3131let fromPixels2DContext : CanvasRenderingContext2D ;
32+ let hasToPixelsWarned = false ;
3233
3334/**
3435 * Creates a `tf.Tensor` from an image.
@@ -145,9 +146,8 @@ function fromPixels_(
145146 'Reason: OffscreenCanvas Context2D rendering is not supported.' ) ;
146147 }
147148 } else {
148- fromPixels2DContext =
149- document . createElement ( 'canvas' ) . getContext (
150- '2d' , { willReadFrequently : true } ) ;
149+ fromPixels2DContext = document . createElement ( 'canvas' ) . getContext (
150+ '2d' , { willReadFrequently : true } ) ;
151151 }
152152 }
153153 fromPixels2DContext . canvas . width = width ;
@@ -269,6 +269,33 @@ export async function fromPixelsAsync(
269269 return fromPixels_ ( inputs , numChannels ) ;
270270}
271271
272+ function validateImgTensor ( img : Tensor2D | Tensor3D ) {
273+ if ( img . rank !== 2 && img . rank !== 3 ) {
274+ throw new Error (
275+ `toPixels only supports rank 2 or 3 tensors, got rank ${ img . rank } .` ) ;
276+ }
277+ const depth = img . rank === 2 ? 1 : img . shape [ 2 ] ;
278+
279+ if ( depth > 4 || depth === 2 ) {
280+ throw new Error (
281+ `toPixels only supports depth of size ` +
282+ `1, 3 or 4 but got ${ depth } ` ) ;
283+ }
284+
285+ if ( img . dtype !== 'float32' && img . dtype !== 'int32' ) {
286+ throw new Error (
287+ `Unsupported type for toPixels: ${ img . dtype } .` +
288+ ` Please use float32 or int32 tensors.` ) ;
289+ }
290+ }
291+
292+ function validateImageOptions ( imageOptions : ImageOptions ) {
293+ const alpha = imageOptions ?. alpha || 1 ;
294+ if ( alpha > 1 || alpha < 0 ) {
295+ throw new Error ( `Alpha value ${ alpha } is suppoed to be in range [0 - 1].` ) ;
296+ }
297+ }
298+
272299/**
273300 * Draws a `tf.Tensor` of pixel values to a byte array or optionally a
274301 * canvas.
@@ -299,25 +326,10 @@ export async function toPixels(
299326 $img = cast ( originalImgTensor , 'int32' ) ;
300327 originalImgTensor . dispose ( ) ;
301328 }
302- if ( $img . rank !== 2 && $img . rank !== 3 ) {
303- throw new Error (
304- `toPixels only supports rank 2 or 3 tensors, got rank ${ $img . rank } .` ) ;
305- }
329+ validateImgTensor ( $img ) ;
330+
306331 const [ height , width ] = $img . shape . slice ( 0 , 2 ) ;
307332 const depth = $img . rank === 2 ? 1 : $img . shape [ 2 ] ;
308-
309- if ( depth > 4 || depth === 2 ) {
310- throw new Error (
311- `toPixels only supports depth of size ` +
312- `1, 3 or 4 but got ${ depth } ` ) ;
313- }
314-
315- if ( $img . dtype !== 'float32' && $img . dtype !== 'int32' ) {
316- throw new Error (
317- `Unsupported type for toPixels: ${ $img . dtype } .` +
318- ` Please use float32 or int32 tensors.` ) ;
319- }
320-
321333 const data = await $img . data ( ) ;
322334 const multiplier = $img . dtype === 'float32' ? 255 : 1 ;
323335 const bytes = new Uint8ClampedArray ( width * height * 4 ) ;
@@ -359,6 +371,13 @@ export async function toPixels(
359371 }
360372
361373 if ( canvas != null ) {
374+ if ( ! hasToPixelsWarned ) {
375+ console . warn (
376+ 'tf.browser.toPixels is not efficient to draw tensor on canvas. ' +
377+ 'Please try tf.browser.draw instead.' ) ;
378+ hasToPixelsWarned = true ;
379+ }
380+
362381 canvas . width = width ;
363382 canvas . height = height ;
364383 const ctx = canvas . getContext ( '2d' ) ;
@@ -371,4 +390,44 @@ export async function toPixels(
371390 return bytes ;
372391}
373392
393+ /**
394+ * Draws a `tf.Tensor` to a canvas.
395+ *
396+ * When the dtype of the input is 'float32', we assume values in the range
397+ * [0-1]. Otherwise, when input is 'int32', we assume values in the range
398+ * [0-255].
399+ *
400+ * @param image The tensor to draw on the canvas. Must match one of
401+ * these shapes:
402+ * - Rank-2 with shape `[height, width`]: Drawn as grayscale.
403+ * - Rank-3 with shape `[height, width, 1]`: Drawn as grayscale.
404+ * - Rank-3 with shape `[height, width, 3]`: Drawn as RGB with alpha set in
405+ * `imageOptions` (defaults to 1, which is opaque).
406+ * - Rank-3 with shape `[height, width, 4]`: Drawn as RGBA.
407+ * @param canvas The canvas to draw to.
408+ * @param options The configuration arguments for image to be drawn and the
409+ * canvas to draw to.
410+ *
411+ * @doc {heading: 'Browser', namespace: 'browser'}
412+ */
413+ export function draw (
414+ image : Tensor2D | Tensor3D | TensorLike , canvas : HTMLCanvasElement ,
415+ options ?: DrawOptions ) : void {
416+ let $img = convertToTensor ( image , 'img' , 'draw' ) ;
417+ if ( ! ( image instanceof Tensor ) ) {
418+ // Assume int32 if user passed a native array.
419+ const originalImgTensor = $img ;
420+ $img = cast ( originalImgTensor , 'int32' ) ;
421+ originalImgTensor . dispose ( ) ;
422+ }
423+ validateImgTensor ( $img ) ;
424+ validateImageOptions ( options ?. imageOptions ) ;
425+
426+ const inputs : DrawInputs = { image : $img } ;
427+ const attrs : DrawAttrs = { canvas, options} ;
428+ ENGINE . runKernel (
429+ Draw , inputs as unknown as NamedTensorMap ,
430+ attrs as unknown as NamedAttrMap ) ;
431+ }
432+
374433export const fromPixels = /* @__PURE__ */ op ( { fromPixels_} ) ;
0 commit comments