1515 * =============================================================================
1616 */
1717
18- import { DataType , KernelFunc , TypedArray , UnaryInputs , util } from '@tensorflow/tfjs-core' ;
18+ import { backend_util , DataTypeFor , KernelFunc , UnaryInputs } from '@tensorflow/tfjs-core' ;
1919
2020import { MathBackendCPU } from '../backend_cpu' ;
2121import { assertNotComplex } from '../cpu_util' ;
22+ import { createSimpleUnaryImpl } from './unary_impl' ;
2223
2324import { SimpleUnaryImpl , SimpleUnaryOperation } from './unary_types' ;
2425
@@ -30,25 +31,14 @@ import {SimpleUnaryImpl, SimpleUnaryOperation} from './unary_types';
3031 * result has the same dtype as the input. This is mainly used in certain
3132 * kernels that return bool type, such as isFinite, isInf, etc.
3233 */
33- export function unaryKernelFunc (
34- name : string , op : SimpleUnaryOperation , dtype ?: DataType ) : KernelFunc {
35- return ( { inputs, attrs, backend} ) => {
36- const { x} = inputs as UnaryInputs ;
37- assertNotComplex ( x , name ) ;
38- if ( x . dtype === 'string' || dtype === 'string' ) {
39- throw new Error ( 'unaryKernelFunc does not support string input/output' ) ;
40- }
34+ export function unaryKernelFunc < I extends number | string = number ,
35+ O extends number | string = number > (
36+ name : string , op : SimpleUnaryOperation < I , O > ,
37+ dtype ?: DataTypeFor < O > ) : KernelFunc {
4138
42- const cpuBackend = backend as MathBackendCPU ;
43- const values = cpuBackend . data . get ( x . dataId ) . values as TypedArray ;
44- const xSize = util . sizeFromShape ( x . shape ) ;
45- const $dtype = dtype || x . dtype ;
46- const newValues = util . getArrayFromDType ( $dtype , xSize ) ;
47- for ( let i = 0 ; i < xSize ; ++ i ) {
48- newValues [ i ] = op ( values [ i ] , attrs ) ;
49- }
50- return cpuBackend . makeTensorInfo ( x . shape , $dtype , newValues ) ;
51- } ;
39+ const impl = createSimpleUnaryImpl < I , O > ( op ) ;
40+
41+ return unaryKernelFuncFromImpl < I , O > ( name , impl , dtype ) ;
5242}
5343
5444/**
@@ -60,19 +50,30 @@ export function unaryKernelFunc(
6050 * result has the same dtype as the input. This is mainly used in certain
6151 * kernels that return bool type, such as isFinite, isInf, etc.
6252 */
63- export function unaryKernelFuncFromImpl (
64- name : string , unaryImpl : SimpleUnaryImpl , dtype ?: DataType ) : KernelFunc {
53+ export function unaryKernelFuncFromImpl < I extends number | string = number ,
54+ O extends number | string = number > (
55+ name : string , unaryImpl : SimpleUnaryImpl < I , O > ,
56+ dtype ?: DataTypeFor < O > ) : KernelFunc {
57+
6558 return ( { inputs, attrs, backend} ) => {
6659 const { x} = inputs as UnaryInputs ;
6760 assertNotComplex ( x , name ) ;
68- if ( x . dtype === 'string' || dtype === 'string' ) {
69- throw new Error ( 'unaryKernelFunc does not support string input/output' ) ;
70- }
7161
7262 const cpuBackend = backend as MathBackendCPU ;
73- const values = cpuBackend . data . get ( x . dataId ) . values as TypedArray ;
74- const $dtype = dtype || x . dtype ;
75- const newValues = unaryImpl ( values , $dtype , attrs ) ;
63+ const values = cpuBackend . data . get ( x . dataId ) . values ;
64+ let decoded : ArrayLike < I > ;
65+ if ( x . dtype === 'string' ) {
66+ if ( ! Array . isArray ( values ) ) {
67+ throw new Error ( 'String tensor\'s value was not an instance of Array' ) ;
68+ }
69+ decoded = backend_util . fromUint8ToStringArray ( values ) as unknown as
70+ ArrayLike < I > ;
71+ } else {
72+ decoded = values as unknown as ArrayLike < I > ;
73+ }
74+
75+ const $dtype = dtype || x . dtype as DataTypeFor < O > ;
76+ const newValues = unaryImpl ( decoded , $dtype , attrs ) ;
7677 return cpuBackend . makeTensorInfo ( x . shape , $dtype , newValues ) ;
7778 } ;
7879}
0 commit comments