15
15
* =============================================================================
16
16
*/
17
17
18
- import { DataType , KernelFunc , TypedArray , UnaryInputs , util } from '@tensorflow/tfjs-core' ;
18
+ import { backend_util , DataTypeFor , KernelFunc , UnaryInputs } from '@tensorflow/tfjs-core' ;
19
19
20
20
import { MathBackendCPU } from '../backend_cpu' ;
21
21
import { assertNotComplex } from '../cpu_util' ;
22
+ import { createSimpleUnaryImpl } from './unary_impl' ;
22
23
23
24
import { SimpleUnaryImpl , SimpleUnaryOperation } from './unary_types' ;
24
25
@@ -30,25 +31,14 @@ import {SimpleUnaryImpl, SimpleUnaryOperation} from './unary_types';
30
31
* result has the same dtype as the input. This is mainly used in certain
31
32
* kernels that return bool type, such as isFinite, isInf, etc.
32
33
*/
33
- export function unaryKernelFunc (
34
- name : string , op : SimpleUnaryOperation , dtype ?: DataType ) : KernelFunc {
35
- return ( { inputs, attrs, backend} ) => {
36
- const { x} = inputs as UnaryInputs ;
37
- assertNotComplex ( x , name ) ;
38
- if ( x . dtype === 'string' || dtype === 'string' ) {
39
- throw new Error ( 'unaryKernelFunc does not support string input/output' ) ;
40
- }
34
+ export function unaryKernelFunc < I extends number | string = number ,
35
+ O extends number | string = number > (
36
+ name : string , op : SimpleUnaryOperation < I , O > ,
37
+ dtype ?: DataTypeFor < O > ) : KernelFunc {
41
38
42
- const cpuBackend = backend as MathBackendCPU ;
43
- const values = cpuBackend . data . get ( x . dataId ) . values as TypedArray ;
44
- const xSize = util . sizeFromShape ( x . shape ) ;
45
- const $dtype = dtype || x . dtype ;
46
- const newValues = util . getArrayFromDType ( $dtype , xSize ) ;
47
- for ( let i = 0 ; i < xSize ; ++ i ) {
48
- newValues [ i ] = op ( values [ i ] , attrs ) ;
49
- }
50
- return cpuBackend . makeTensorInfo ( x . shape , $dtype , newValues ) ;
51
- } ;
39
+ const impl = createSimpleUnaryImpl < I , O > ( op ) ;
40
+
41
+ return unaryKernelFuncFromImpl < I , O > ( name , impl , dtype ) ;
52
42
}
53
43
54
44
/**
@@ -60,19 +50,30 @@ export function unaryKernelFunc(
60
50
* result has the same dtype as the input. This is mainly used in certain
61
51
* kernels that return bool type, such as isFinite, isInf, etc.
62
52
*/
63
- export function unaryKernelFuncFromImpl (
64
- name : string , unaryImpl : SimpleUnaryImpl , dtype ?: DataType ) : KernelFunc {
53
+ export function unaryKernelFuncFromImpl < I extends number | string = number ,
54
+ O extends number | string = number > (
55
+ name : string , unaryImpl : SimpleUnaryImpl < I , O > ,
56
+ dtype ?: DataTypeFor < O > ) : KernelFunc {
57
+
65
58
return ( { inputs, attrs, backend} ) => {
66
59
const { x} = inputs as UnaryInputs ;
67
60
assertNotComplex ( x , name ) ;
68
- if ( x . dtype === 'string' || dtype === 'string' ) {
69
- throw new Error ( 'unaryKernelFunc does not support string input/output' ) ;
70
- }
71
61
72
62
const cpuBackend = backend as MathBackendCPU ;
73
- const values = cpuBackend . data . get ( x . dataId ) . values as TypedArray ;
74
- const $dtype = dtype || x . dtype ;
75
- const newValues = unaryImpl ( values , $dtype , attrs ) ;
63
+ const values = cpuBackend . data . get ( x . dataId ) . values ;
64
+ let decoded : ArrayLike < I > ;
65
+ if ( x . dtype === 'string' ) {
66
+ if ( ! Array . isArray ( values ) ) {
67
+ throw new Error ( 'String tensor\'s value was not an instance of Array' ) ;
68
+ }
69
+ decoded = backend_util . fromUint8ToStringArray ( values ) as unknown as
70
+ ArrayLike < I > ;
71
+ } else {
72
+ decoded = values as unknown as ArrayLike < I > ;
73
+ }
74
+
75
+ const $dtype = dtype || x . dtype as DataTypeFor < O > ;
76
+ const newValues = unaryImpl ( decoded , $dtype , attrs ) ;
76
77
return cpuBackend . makeTensorInfo ( x . shape , $dtype , newValues ) ;
77
78
} ;
78
79
}
0 commit comments