Skip to content

Commit 130a56d

Browse files
Add StaticRegexReplace Op (#7379)
1 parent a09edd1 commit 130a56d

File tree

23 files changed

+369
-55
lines changed

23 files changed

+369
-55
lines changed
Lines changed: 37 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,37 @@
1+
/**
2+
* @license
3+
* Copyright 2023 Google LLC.
4+
* Licensed under the Apache License, Version 2.0 (the "License");
5+
* you may not use this file except in compliance with the License.
6+
* You may obtain a copy of the License at
7+
*
8+
* http://www.apache.org/licenses/LICENSE-2.0
9+
*
10+
* Unless required by applicable law or agreed to in writing, software
11+
* distributed under the License is distributed on an "AS IS" BASIS,
12+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
* See the License for the specific language governing permissions and
14+
* limitations under the License.
15+
* =============================================================================
16+
*/
17+
18+
import {KernelConfig, StaticRegexReplace, StaticRegexReplaceAttrs} from '@tensorflow/tfjs-core';
19+
import {createSimpleUnaryImpl} from '../utils/unary_impl';
20+
import {unaryKernelFuncFromImpl} from '../utils/unary_utils';
21+
22+
/**
 * CPU implementation of the StaticRegexReplace op.
 *
 * For every string element `x`, replaces matches of `attrs.pattern` with
 * `attrs.rewrite`. When `attrs.replaceGlobal` is truthy every match is
 * replaced (regex flag `g`); otherwise only the first match is replaced.
 *
 * The compiled `RegExp` is memoized on (pattern, global flag) so that a
 * tensor with many elements does not recompile the same pattern once per
 * element. Reusing the instance is safe: `String.prototype.replace` resets
 * `lastIndex` on global regexes before matching, and ignores it for
 * non-global, non-sticky ones.
 */
export const staticRegexReplaceImpl = (() => {
  // Most recently compiled regex and the attribute values it was built from.
  let cachedPattern: string | undefined;
  let cachedGlobal: boolean | undefined;
  let cachedRegex: RegExp | undefined;

  return createSimpleUnaryImpl<string, string>((x: string, attrs) => {
    const {pattern, replaceGlobal, rewrite} =
        attrs as unknown as StaticRegexReplaceAttrs;
    const isGlobal = Boolean(replaceGlobal);

    // Recompile only when the pattern or the global flag changed. If the
    // constructor throws (invalid pattern) the cache is left untouched, so
    // the error surfaces on every call exactly as before.
    if (cachedRegex === undefined || cachedPattern !== pattern ||
        cachedGlobal !== isGlobal) {
      cachedRegex = new RegExp(pattern, isGlobal ? 'g' : '');
      cachedPattern = pattern;
      cachedGlobal = isGlobal;
    }

    return x.replace(cachedRegex, rewrite);
  });
})();
29+
30+
/**
 * Kernel registration entry that binds the `StaticRegexReplace` kernel name
 * to its implementation for the 'cpu' backend.
 */
export const staticRegexReplaceConfig: KernelConfig = {
  kernelName: StaticRegexReplace,
  backendName: 'cpu',
  // The array-level impl is wrapped into a KernelFunc once, when this module
  // is evaluated.
  kernelFunc:
      unaryKernelFuncFromImpl(StaticRegexReplace, staticRegexReplaceImpl),
};

tfjs-backend-cpu/src/register_all_kernels.ts

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -170,6 +170,7 @@ import {splitVConfig} from './kernels/SplitV';
170170
import {sqrtConfig} from './kernels/Sqrt';
171171
import {squareConfig} from './kernels/Square';
172172
import {squaredDifferenceConfig} from './kernels/SquaredDifference';
173+
import {staticRegexReplaceConfig} from './kernels/StaticRegexReplace';
173174
import {stepConfig} from './kernels/Step';
174175
import {stridedSliceConfig} from './kernels/StridedSlice';
175176
import {stringNGramsConfig} from './kernels/StringNGrams';
@@ -342,6 +343,7 @@ const kernelConfigs: KernelConfig[] = [
342343
sqrtConfig,
343344
squareConfig,
344345
squaredDifferenceConfig,
346+
staticRegexReplaceConfig,
345347
stepConfig,
346348
stridedSliceConfig,
347349
stringNGramsConfig,

tfjs-backend-cpu/src/shared.ts

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -55,6 +55,7 @@ export {sparseReshapeImpl} from './kernels/SparseReshape_impl';
5555
export {sparseSegmentReductionImpl} from './kernels/SparseSegmentReduction_impl';
5656
export {sqrtImpl} from './kernels/Sqrt';
5757
export {squaredDifferenceImpl} from './kernels/SquaredDifference';
58+
export {staticRegexReplaceImpl} from './kernels/StaticRegexReplace';
5859
export {stridedSliceImpl} from './kernels/StridedSlice_impl';
5960
export {stringNGramsImpl} from './kernels/StringNGrams_impl';
6061
export {stringSplitImpl} from './kernels/StringSplit_impl';

tfjs-backend-cpu/src/utils/unary_impl.ts

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -15,18 +15,19 @@
1515
* =============================================================================
1616
*/
1717

18-
import {NumericDataType, util} from '@tensorflow/tfjs-core';
18+
import {util} from '@tensorflow/tfjs-core';
1919

2020
import {SimpleUnaryImpl, SimpleUnaryOperation} from './unary_types';
2121

2222
/**
2323
* Template that creates implementation for unary op.
2424
*/
25-
export function createSimpleUnaryImpl(op: SimpleUnaryOperation):
26-
SimpleUnaryImpl {
25+
export function createSimpleUnaryImpl<I extends number | string = number,
26+
O extends number | string = number>(op: SimpleUnaryOperation<I, O>):
27+
SimpleUnaryImpl<I, O> {
2728
return (values, dtype, attrs) => {
2829
const newValues =
29-
util.getTypedArrayFromDType(dtype as NumericDataType, values.length);
30+
util.getArrayFromDType(dtype, values.length);
3031
for (let i = 0; i < values.length; ++i) {
3132
newValues[i] = op(values[i], attrs);
3233
}

tfjs-backend-cpu/src/utils/unary_types.ts

Lines changed: 8 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -15,8 +15,12 @@
1515
* =============================================================================
1616
*/
1717

18-
import {DataType, NamedAttrMap, TypedArray} from '@tensorflow/tfjs-core';
18+
import {DataTypeFor, DataTypeMap, NamedAttrMap} from '@tensorflow/tfjs-core';
1919

20-
export type SimpleUnaryOperation = (x: number, attrs?: NamedAttrMap) => number;
21-
export type SimpleUnaryImpl =
22-
(values: TypedArray, dtype: DataType, attrs?: NamedAttrMap) => TypedArray;
20+
/**
 * Element-wise unary operation: takes a single input value `x` of type `I`
 * plus optional kernel attributes, and returns a single value of type `O`.
 * Both type parameters may be `number` or `string` and default to `number`.
 */
export type SimpleUnaryOperation<I extends number | string = number,
21+
O extends number | string = number> = (x: I, attrs?: NamedAttrMap) => O;
22+
23+
/**
 * Array-level unary implementation: takes an array-like of input values of
 * type `I`, the output dtype (the dtype corresponding to `O`), and optional
 * kernel attributes, and returns the container type that `DataTypeMap`
 * associates with that output dtype.
 *
 * Note that, unlike `SimpleUnaryOperation`, the type parameters here default
 * to `number | string` rather than `number`.
 */
export type SimpleUnaryImpl<I extends number | string = number | string,
24+
O extends number | string = number | string> =
25+
(values: ArrayLike<I>, dtype: DataTypeFor<O>,
26+
attrs?: NamedAttrMap) => DataTypeMap[DataTypeFor<O>];

tfjs-backend-cpu/src/utils/unary_utils.ts

Lines changed: 28 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -15,10 +15,11 @@
1515
* =============================================================================
1616
*/
1717

18-
import {DataType, KernelFunc, TypedArray, UnaryInputs, util} from '@tensorflow/tfjs-core';
18+
import {backend_util, DataTypeFor, KernelFunc, UnaryInputs} from '@tensorflow/tfjs-core';
1919

2020
import {MathBackendCPU} from '../backend_cpu';
2121
import {assertNotComplex} from '../cpu_util';
22+
import {createSimpleUnaryImpl} from './unary_impl';
2223

2324
import {SimpleUnaryImpl, SimpleUnaryOperation} from './unary_types';
2425

@@ -30,25 +31,14 @@ import {SimpleUnaryImpl, SimpleUnaryOperation} from './unary_types';
3031
* result has the same dtype as the input. This is mainly used in certain
3132
* kernels that return bool type, such as isFinite, isInf, etc.
3233
*/
33-
export function unaryKernelFunc(
34-
name: string, op: SimpleUnaryOperation, dtype?: DataType): KernelFunc {
35-
return ({inputs, attrs, backend}) => {
36-
const {x} = inputs as UnaryInputs;
37-
assertNotComplex(x, name);
38-
if (x.dtype === 'string' || dtype === 'string') {
39-
throw new Error('unaryKernelFunc does not support string input/output');
40-
}
34+
export function unaryKernelFunc<I extends number | string = number,
35+
O extends number | string = number>(
36+
name: string, op: SimpleUnaryOperation<I, O>,
37+
dtype?: DataTypeFor<O>): KernelFunc {
4138

42-
const cpuBackend = backend as MathBackendCPU;
43-
const values = cpuBackend.data.get(x.dataId).values as TypedArray;
44-
const xSize = util.sizeFromShape(x.shape);
45-
const $dtype = dtype || x.dtype;
46-
const newValues = util.getArrayFromDType($dtype, xSize);
47-
for (let i = 0; i < xSize; ++i) {
48-
newValues[i] = op(values[i], attrs);
49-
}
50-
return cpuBackend.makeTensorInfo(x.shape, $dtype, newValues);
51-
};
39+
const impl = createSimpleUnaryImpl<I, O>(op);
40+
41+
return unaryKernelFuncFromImpl<I, O>(name, impl, dtype);
5242
}
5343

5444
/**
@@ -60,19 +50,30 @@ export function unaryKernelFunc(
6050
* result has the same dtype as the input. This is mainly used in certain
6151
* kernels that return bool type, such as isFinite, isInf, etc.
6252
*/
63-
export function unaryKernelFuncFromImpl(
64-
name: string, unaryImpl: SimpleUnaryImpl, dtype?: DataType): KernelFunc {
53+
export function unaryKernelFuncFromImpl<I extends number | string = number,
54+
O extends number | string = number>(
55+
name: string, unaryImpl: SimpleUnaryImpl<I, O>,
56+
dtype?: DataTypeFor<O>): KernelFunc {
57+
6558
return ({inputs, attrs, backend}) => {
6659
const {x} = inputs as UnaryInputs;
6760
assertNotComplex(x, name);
68-
if (x.dtype === 'string' || dtype === 'string') {
69-
throw new Error('unaryKernelFunc does not support string input/output');
70-
}
7161

7262
const cpuBackend = backend as MathBackendCPU;
73-
const values = cpuBackend.data.get(x.dataId).values as TypedArray;
74-
const $dtype = dtype || x.dtype;
75-
const newValues = unaryImpl(values, $dtype, attrs);
63+
const values = cpuBackend.data.get(x.dataId).values;
64+
let decoded: ArrayLike<I>;
65+
if (x.dtype === 'string') {
66+
if (!Array.isArray(values)) {
67+
throw new Error('String tensor\'s value was not an instance of Array');
68+
}
69+
decoded = backend_util.fromUint8ToStringArray(values) as unknown as
70+
ArrayLike<I>;
71+
} else {
72+
decoded = values as unknown as ArrayLike<I>;
73+
}
74+
75+
const $dtype = dtype || x.dtype as DataTypeFor<O>;
76+
const newValues = unaryImpl(decoded, $dtype, attrs);
7677
return cpuBackend.makeTensorInfo(x.shape, $dtype, newValues);
7778
};
7879
}

tfjs-backend-webgl/src/kernel_utils/kernel_funcs_utils.ts

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@
1515
* =============================================================================
1616
*/
1717

18-
import {backend_util, BinaryInputs, DataType, env, KernelFunc, TypedArray, UnaryInputs, upcastType} from '@tensorflow/tfjs-core';
18+
import { backend_util, BinaryInputs, DataType, env, KernelFunc, TypedArray, UnaryInputs, upcastType} from '@tensorflow/tfjs-core';
1919

2020
import {MathBackendWebGL} from '../backend_webgl';
2121
import {BinaryOpProgram} from '../binaryop_gpu';
@@ -36,7 +36,7 @@ type UnaryKernelFuncConfig = {
3636
opSnippet: string,
3737
packedOpSnippet?: string,
3838
cpuKernelImpl?: SimpleUnaryKernelImplCPU,
39-
dtype?: DataType
39+
dtype?: DataType,
4040
};
4141

4242
/**

tfjs-backend-webgl/src/kernel_utils/shared.ts

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -66,6 +66,7 @@ const {
6666
sparseReshapeImpl: sparseReshapeImplCPU,
6767
sparseSegmentReductionImpl: sparseSegmentReductionImplCPU,
6868
sqrtImpl: sqrtImplCPU,
69+
staticRegexReplaceImpl: staticRegexReplaceImplCPU,
6970
stridedSliceImpl: stridedSliceImplCPU,
7071
stringNGramsImpl: stringNGramsImplCPU,
7172
stringSplitImpl: stringSplitImplCPU,
@@ -114,6 +115,7 @@ export {
114115
sparseReshapeImplCPU,
115116
sparseSegmentReductionImplCPU,
116117
sqrtImplCPU,
118+
staticRegexReplaceImplCPU,
117119
stridedSliceImplCPU,
118120
stringNGramsImplCPU,
119121
stringSplitImplCPU,
Lines changed: 47 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,47 @@
1+
/**
2+
* @license
3+
* Copyright 2023 Google LLC.
4+
* Licensed under the Apache License, Version 2.0 (the "License");
5+
* you may not use this file except in compliance with the License.
6+
* You may obtain a copy of the License at
7+
*
8+
* http://www.apache.org/licenses/LICENSE-2.0
9+
*
10+
* Unless required by applicable law or agreed to in writing, software
11+
* distributed under the License is distributed on an "AS IS" BASIS,
12+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
* See the License for the specific language governing permissions and
14+
* limitations under the License.
15+
* =============================================================================
16+
*/
17+
18+
import {backend_util, KernelConfig, KernelFunc, NamedAttrMap, StaticRegexReplace, StaticRegexReplaceAttrs, StaticRegexReplaceInputs, TensorInfo} from '@tensorflow/tfjs-core';
19+
import {MathBackendWebGL} from '../backend_webgl';
20+
import {staticRegexReplaceImplCPU} from '../kernel_utils/shared';
21+
22+
/**
 * WebGL kernel for StaticRegexReplace.
 *
 * Reads the input tensor's data synchronously from the backend and delegates
 * the replacement to the shared CPU implementation.
 *
 * @param args.inputs Kernel inputs; `x` must have dtype 'string'.
 * @param args.backend The WebGL backend that owns `x`.
 * @param args.attrs Op attributes, forwarded unchanged to the CPU impl.
 * @returns A new 'string' tensor info with the same shape as `x`.
 * @throws Error if `x` is not of dtype 'string'.
 */
export function staticRegexReplace(args: {
  inputs: StaticRegexReplaceInputs,
  backend: MathBackendWebGL,
  attrs: StaticRegexReplaceAttrs,
}): TensorInfo {
  const {x} = args.inputs;
  const webglBackend = args.backend;

  if (x.dtype !== 'string') {
    throw new Error('Input must be of datatype string');
  }

  // The backend data is treated as one byte array per element and converted
  // to JS strings before the regex replacement runs.
  const encoded = webglBackend.readSync(x.dataId) as Uint8Array[];
  const decoded = backend_util.fromUint8ToStringArray(encoded);

  const replaced = staticRegexReplaceImplCPU(
      decoded, 'string', args.attrs as unknown as NamedAttrMap);

  return webglBackend.makeTensorInfo(x.shape, 'string', replaced);
}
42+
43+
/**
 * Kernel registration entry that binds the `StaticRegexReplace` kernel name
 * to the `staticRegexReplace` function for the 'webgl' backend.
 */
export const staticRegexReplaceConfig: KernelConfig = {
  kernelName: StaticRegexReplace,
  backendName: 'webgl',
  // The kernel takes narrowly-typed args, so it is cast to the generic
  // KernelFunc signature expected by the registry.
  kernelFunc: staticRegexReplace as unknown as KernelFunc,
};

tfjs-backend-webgl/src/register_all_kernels.ts

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -166,6 +166,7 @@ import {splitVConfig} from './kernels/SplitV';
166166
import {sqrtConfig} from './kernels/Sqrt';
167167
import {squareConfig} from './kernels/Square';
168168
import {squaredDifferenceConfig} from './kernels/SquaredDifference';
169+
import {staticRegexReplaceConfig} from './kernels/StaticRegexReplace';
169170
import {stepConfig} from './kernels/Step';
170171
import {stridedSliceConfig} from './kernels/StridedSlice';
171172
import {stringNGramsConfig} from './kernels/StringNGrams';
@@ -337,6 +338,7 @@ const kernelConfigs: KernelConfig[] = [
337338
sqrtConfig,
338339
squareConfig,
339340
squaredDifferenceConfig,
341+
staticRegexReplaceConfig,
340342
stepConfig,
341343
stridedSliceConfig,
342344
stringNGramsConfig,

0 commit comments

Comments
 (0)