Skip to content

Commit 130a56d

Browse files
Add StaticRegexReplace Op (#7379)
1 parent a09edd1 commit 130a56d

23 files changed

+369
-55
lines changed
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,37 @@
1+
/**
2+
* @license
3+
* Copyright 2023 Google LLC.
4+
* Licensed under the Apache License, Version 2.0 (the "License");
5+
* you may not use this file except in compliance with the License.
6+
* You may obtain a copy of the License at
7+
*
8+
* http://www.apache.org/licenses/LICENSE-2.0
9+
*
10+
* Unless required by applicable law or agreed to in writing, software
11+
* distributed under the License is distributed on an "AS IS" BASIS,
12+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
* See the License for the specific language governing permissions and
14+
* limitations under the License.
15+
* =============================================================================
16+
*/
17+
18+
import {KernelConfig, StaticRegexReplace, StaticRegexReplaceAttrs} from '@tensorflow/tfjs-core';
19+
import {createSimpleUnaryImpl} from '../utils/unary_impl';
20+
import {unaryKernelFuncFromImpl} from '../utils/unary_utils';
21+
22+
// Single-entry memo of the most recently compiled pattern. The kernel applies
// the same `pattern` / `replaceGlobal` attrs to every element of a tensor, so
// remembering the last compiled regex avoids recompiling it per element.
let cachedPattern: string|undefined;
let cachedFlags: string|undefined;
let cachedRegex: RegExp|undefined;

/**
 * Returns a RegExp for `pattern`, reusing the previously compiled instance
 * when both the pattern and the flags are unchanged.
 *
 * Sharing the instance across calls is safe here because it is only ever
 * passed to `String.prototype.replace`, which starts matching from index 0
 * regardless of any `lastIndex` left over from a previous call.
 *
 * @param pattern Regular expression source, handed to `new RegExp`.
 * @param replaceGlobal When truthy the regex is compiled with the `g` flag.
 * @returns The compiled (possibly cached) regular expression.
 * @throws SyntaxError (from `new RegExp`) if `pattern` is not a valid regex;
 *     nothing is cached in that case, so every call throws as before.
 */
function getStaticRegex(pattern: string, replaceGlobal: boolean): RegExp {
  const flags = replaceGlobal ? 'g' : '';
  if (cachedRegex === undefined || cachedPattern !== pattern ||
      cachedFlags !== flags) {
    cachedRegex = new RegExp(pattern, flags);
    cachedPattern = pattern;
    cachedFlags = flags;
  }
  return cachedRegex;
}

/**
 * Element-wise implementation of StaticRegexReplace: replaces the first match
 * (or every match, when `replaceGlobal` is set) of `pattern` in each input
 * string with `rewrite`, using `String.prototype.replace`.
 */
export const staticRegexReplaceImpl = createSimpleUnaryImpl<string, string>(
    (x: string, attrs) => {
      // `attrs` arrives typed as the generic attribute map; the kernel
      // contract is that it carries the StaticRegexReplace attributes.
      const {pattern, replaceGlobal, rewrite} =
          attrs as unknown as StaticRegexReplaceAttrs;
      return x.replace(getStaticRegex(pattern, replaceGlobal), rewrite);
    });
29+
30+
/**
 * CPU kernel function for StaticRegexReplace, obtained by wrapping the
 * element-wise `staticRegexReplaceImpl` with `unaryKernelFuncFromImpl`.
 */
const staticRegexReplace = unaryKernelFuncFromImpl(
    StaticRegexReplace,
    staticRegexReplaceImpl,
);
32+
33+
/**
 * Kernel registration entry that binds the StaticRegexReplace kernel name to
 * its implementation for the 'cpu' backend.
 */
export const staticRegexReplaceConfig: KernelConfig = {
  kernelName: StaticRegexReplace,
  backendName: 'cpu',
  kernelFunc: staticRegexReplace,
};

tfjs-backend-cpu/src/register_all_kernels.ts

+2
Original file line numberDiff line numberDiff line change
@@ -170,6 +170,7 @@ import {splitVConfig} from './kernels/SplitV';
170170
import {sqrtConfig} from './kernels/Sqrt';
171171
import {squareConfig} from './kernels/Square';
172172
import {squaredDifferenceConfig} from './kernels/SquaredDifference';
173+
import {staticRegexReplaceConfig} from './kernels/StaticRegexReplace';
173174
import {stepConfig} from './kernels/Step';
174175
import {stridedSliceConfig} from './kernels/StridedSlice';
175176
import {stringNGramsConfig} from './kernels/StringNGrams';
@@ -342,6 +343,7 @@ const kernelConfigs: KernelConfig[] = [
342343
sqrtConfig,
343344
squareConfig,
344345
squaredDifferenceConfig,
346+
staticRegexReplaceConfig,
345347
stepConfig,
346348
stridedSliceConfig,
347349
stringNGramsConfig,

tfjs-backend-cpu/src/shared.ts

+1
Original file line numberDiff line numberDiff line change
@@ -55,6 +55,7 @@ export {sparseReshapeImpl} from './kernels/SparseReshape_impl';
5555
export {sparseSegmentReductionImpl} from './kernels/SparseSegmentReduction_impl';
5656
export {sqrtImpl} from './kernels/Sqrt';
5757
export {squaredDifferenceImpl} from './kernels/SquaredDifference';
58+
export {staticRegexReplaceImpl} from './kernels/StaticRegexReplace';
5859
export {stridedSliceImpl} from './kernels/StridedSlice_impl';
5960
export {stringNGramsImpl} from './kernels/StringNGrams_impl';
6061
export {stringSplitImpl} from './kernels/StringSplit_impl';

tfjs-backend-cpu/src/utils/unary_impl.ts

+5-4
Original file line numberDiff line numberDiff line change
@@ -15,18 +15,19 @@
1515
* =============================================================================
1616
*/
1717

18-
import {NumericDataType, util} from '@tensorflow/tfjs-core';
18+
import {util} from '@tensorflow/tfjs-core';
1919

2020
import {SimpleUnaryImpl, SimpleUnaryOperation} from './unary_types';
2121

2222
/**
2323
* Template that creates implementation for unary op.
2424
*/
25-
export function createSimpleUnaryImpl(op: SimpleUnaryOperation):
26-
SimpleUnaryImpl {
25+
export function createSimpleUnaryImpl<I extends number | string = number,
26+
O extends number | string = number>(op: SimpleUnaryOperation<I, O>):
27+
SimpleUnaryImpl<I, O> {
2728
return (values, dtype, attrs) => {
2829
const newValues =
29-
util.getTypedArrayFromDType(dtype as NumericDataType, values.length);
30+
util.getArrayFromDType(dtype, values.length);
3031
for (let i = 0; i < values.length; ++i) {
3132
newValues[i] = op(values[i], attrs);
3233
}

tfjs-backend-cpu/src/utils/unary_types.ts

+8-4
Original file line numberDiff line numberDiff line change
@@ -15,8 +15,12 @@
1515
* =============================================================================
1616
*/
1717

18-
import {DataType, NamedAttrMap, TypedArray} from '@tensorflow/tfjs-core';
18+
import {DataTypeFor, DataTypeMap, NamedAttrMap} from '@tensorflow/tfjs-core';
1919

20-
export type SimpleUnaryOperation = (x: number, attrs?: NamedAttrMap) => number;
21-
export type SimpleUnaryImpl =
22-
(values: TypedArray, dtype: DataType, attrs?: NamedAttrMap) => TypedArray;
20+
/**
 * A per-element unary operation: maps one input element `x` (number or
 * string) and the optional kernel attrs to one output element. Both type
 * parameters default to `number`.
 */
export type SimpleUnaryOperation<
    In extends number | string = number,
    Out extends number | string = number> =
    (x: In, attrs?: NamedAttrMap) => Out;
22+
23+
export type SimpleUnaryImpl<I extends number | string = number | string,
24+
O extends number | string = number | string> =
25+
(values: ArrayLike<I>, dtype: DataTypeFor<O>,
26+
attrs?: NamedAttrMap) => DataTypeMap[DataTypeFor<O>];

tfjs-backend-cpu/src/utils/unary_utils.ts

+28-27
Original file line numberDiff line numberDiff line change
@@ -15,10 +15,11 @@
1515
* =============================================================================
1616
*/
1717

18-
import {DataType, KernelFunc, TypedArray, UnaryInputs, util} from '@tensorflow/tfjs-core';
18+
import {backend_util, DataTypeFor, KernelFunc, UnaryInputs} from '@tensorflow/tfjs-core';
1919

2020
import {MathBackendCPU} from '../backend_cpu';
2121
import {assertNotComplex} from '../cpu_util';
22+
import {createSimpleUnaryImpl} from './unary_impl';
2223

2324
import {SimpleUnaryImpl, SimpleUnaryOperation} from './unary_types';
2425

@@ -30,25 +31,14 @@ import {SimpleUnaryImpl, SimpleUnaryOperation} from './unary_types';
3031
* result has the same dtype as the input. This is mainly used in certain
3132
* kernels that return bool type, such as isFinite, isInf, etc.
3233
*/
33-
export function unaryKernelFunc(
34-
name: string, op: SimpleUnaryOperation, dtype?: DataType): KernelFunc {
35-
return ({inputs, attrs, backend}) => {
36-
const {x} = inputs as UnaryInputs;
37-
assertNotComplex(x, name);
38-
if (x.dtype === 'string' || dtype === 'string') {
39-
throw new Error('unaryKernelFunc does not support string input/output');
40-
}
34+
export function unaryKernelFunc<I extends number | string = number,
35+
O extends number | string = number>(
36+
name: string, op: SimpleUnaryOperation<I, O>,
37+
dtype?: DataTypeFor<O>): KernelFunc {
4138

42-
const cpuBackend = backend as MathBackendCPU;
43-
const values = cpuBackend.data.get(x.dataId).values as TypedArray;
44-
const xSize = util.sizeFromShape(x.shape);
45-
const $dtype = dtype || x.dtype;
46-
const newValues = util.getArrayFromDType($dtype, xSize);
47-
for (let i = 0; i < xSize; ++i) {
48-
newValues[i] = op(values[i], attrs);
49-
}
50-
return cpuBackend.makeTensorInfo(x.shape, $dtype, newValues);
51-
};
39+
const impl = createSimpleUnaryImpl<I, O>(op);
40+
41+
return unaryKernelFuncFromImpl<I, O>(name, impl, dtype);
5242
}
5343

5444
/**
@@ -60,19 +50,30 @@ export function unaryKernelFunc(
6050
* result has the same dtype as the input. This is mainly used in certain
6151
* kernels that return bool type, such as isFinite, isInf, etc.
6252
*/
63-
export function unaryKernelFuncFromImpl(
64-
name: string, unaryImpl: SimpleUnaryImpl, dtype?: DataType): KernelFunc {
53+
export function unaryKernelFuncFromImpl<I extends number | string = number,
54+
O extends number | string = number>(
55+
name: string, unaryImpl: SimpleUnaryImpl<I, O>,
56+
dtype?: DataTypeFor<O>): KernelFunc {
57+
6558
return ({inputs, attrs, backend}) => {
6659
const {x} = inputs as UnaryInputs;
6760
assertNotComplex(x, name);
68-
if (x.dtype === 'string' || dtype === 'string') {
69-
throw new Error('unaryKernelFunc does not support string input/output');
70-
}
7161

7262
const cpuBackend = backend as MathBackendCPU;
73-
const values = cpuBackend.data.get(x.dataId).values as TypedArray;
74-
const $dtype = dtype || x.dtype;
75-
const newValues = unaryImpl(values, $dtype, attrs);
63+
const values = cpuBackend.data.get(x.dataId).values;
64+
let decoded: ArrayLike<I>;
65+
if (x.dtype === 'string') {
66+
if (!Array.isArray(values)) {
67+
throw new Error('String tensor\'s value was not an instance of Array');
68+
}
69+
decoded = backend_util.fromUint8ToStringArray(values) as unknown as
70+
ArrayLike<I>;
71+
} else {
72+
decoded = values as unknown as ArrayLike<I>;
73+
}
74+
75+
const $dtype = dtype || x.dtype as DataTypeFor<O>;
76+
const newValues = unaryImpl(decoded, $dtype, attrs);
7677
return cpuBackend.makeTensorInfo(x.shape, $dtype, newValues);
7778
};
7879
}

tfjs-backend-webgl/src/kernel_utils/kernel_funcs_utils.ts

+2-2
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@
1515
* =============================================================================
1616
*/
1717

18-
import {backend_util, BinaryInputs, DataType, env, KernelFunc, TypedArray, UnaryInputs, upcastType} from '@tensorflow/tfjs-core';
18+
import { backend_util, BinaryInputs, DataType, env, KernelFunc, TypedArray, UnaryInputs, upcastType} from '@tensorflow/tfjs-core';
1919

2020
import {MathBackendWebGL} from '../backend_webgl';
2121
import {BinaryOpProgram} from '../binaryop_gpu';
@@ -36,7 +36,7 @@ type UnaryKernelFuncConfig = {
3636
opSnippet: string,
3737
packedOpSnippet?: string,
3838
cpuKernelImpl?: SimpleUnaryKernelImplCPU,
39-
dtype?: DataType
39+
dtype?: DataType,
4040
};
4141

4242
/**

tfjs-backend-webgl/src/kernel_utils/shared.ts

+2
Original file line numberDiff line numberDiff line change
@@ -66,6 +66,7 @@ const {
6666
sparseReshapeImpl: sparseReshapeImplCPU,
6767
sparseSegmentReductionImpl: sparseSegmentReductionImplCPU,
6868
sqrtImpl: sqrtImplCPU,
69+
staticRegexReplaceImpl: staticRegexReplaceImplCPU,
6970
stridedSliceImpl: stridedSliceImplCPU,
7071
stringNGramsImpl: stringNGramsImplCPU,
7172
stringSplitImpl: stringSplitImplCPU,
@@ -114,6 +115,7 @@ export {
114115
sparseReshapeImplCPU,
115116
sparseSegmentReductionImplCPU,
116117
sqrtImplCPU,
118+
staticRegexReplaceImplCPU,
117119
stridedSliceImplCPU,
118120
stringNGramsImplCPU,
119121
stringSplitImplCPU,
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,47 @@
1+
/**
2+
* @license
3+
* Copyright 2023 Google LLC.
4+
* Licensed under the Apache License, Version 2.0 (the "License");
5+
* you may not use this file except in compliance with the License.
6+
* You may obtain a copy of the License at
7+
*
8+
* http://www.apache.org/licenses/LICENSE-2.0
9+
*
10+
* Unless required by applicable law or agreed to in writing, software
11+
* distributed under the License is distributed on an "AS IS" BASIS,
12+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
* See the License for the specific language governing permissions and
14+
* limitations under the License.
15+
* =============================================================================
16+
*/
17+
18+
import {backend_util, KernelConfig, KernelFunc, NamedAttrMap, StaticRegexReplace, StaticRegexReplaceAttrs, StaticRegexReplaceInputs, TensorInfo} from '@tensorflow/tfjs-core';
19+
import {MathBackendWebGL} from '../backend_webgl';
20+
import {staticRegexReplaceImplCPU} from '../kernel_utils/shared';
21+
22+
/**
 * WebGL kernel for StaticRegexReplace. The string data is read back
 * synchronously from the backend and the replacement itself is delegated to
 * the shared CPU implementation.
 *
 * @param args Kernel arguments: `inputs.x` is the string tensor to rewrite,
 *     `backend` is the WebGL backend holding its data, and `attrs` carries
 *     `pattern`, `rewrite` and `replaceGlobal`.
 * @returns A new string tensor info with the same shape as `x`.
 * @throws Error if `x` is not of dtype 'string'.
 */
export function staticRegexReplace(args: {
  inputs: StaticRegexReplaceInputs,
  backend: MathBackendWebGL,
  attrs: StaticRegexReplaceAttrs,
}): TensorInfo {
  const {inputs, backend: webglBackend, attrs: regexAttrs} = args;
  const {x} = inputs;

  if (x.dtype !== 'string') {
    throw new Error('Input must be of datatype string');
  }

  // The values are read back as byte arrays (one per element) and decoded to
  // JS strings before the regex is applied.
  const encoded = webglBackend.readSync(x.dataId) as Uint8Array[];
  const decoded = backend_util.fromUint8ToStringArray(encoded);

  const replaced = staticRegexReplaceImplCPU(
      decoded, 'string', regexAttrs as unknown as NamedAttrMap);

  return webglBackend.makeTensorInfo(x.shape, 'string', replaced);
}
42+
43+
/**
 * Kernel registration entry that binds the StaticRegexReplace kernel name to
 * its implementation for the 'webgl' backend.
 */
export const staticRegexReplaceConfig: KernelConfig = {
  kernelName: StaticRegexReplace,
  backendName: 'webgl',
  // The kernel is declared with concrete input/backend/attr types, so it is
  // cast to the generic KernelFunc signature expected by KernelConfig.
  kernelFunc: staticRegexReplace as unknown as KernelFunc,
};

tfjs-backend-webgl/src/register_all_kernels.ts

+2
Original file line numberDiff line numberDiff line change
@@ -166,6 +166,7 @@ import {splitVConfig} from './kernels/SplitV';
166166
import {sqrtConfig} from './kernels/Sqrt';
167167
import {squareConfig} from './kernels/Square';
168168
import {squaredDifferenceConfig} from './kernels/SquaredDifference';
169+
import {staticRegexReplaceConfig} from './kernels/StaticRegexReplace';
169170
import {stepConfig} from './kernels/Step';
170171
import {stridedSliceConfig} from './kernels/StridedSlice';
171172
import {stringNGramsConfig} from './kernels/StringNGrams';
@@ -337,6 +338,7 @@ const kernelConfigs: KernelConfig[] = [
337338
sqrtConfig,
338339
squareConfig,
339340
squaredDifferenceConfig,
341+
staticRegexReplaceConfig,
340342
stepConfig,
341343
stridedSliceConfig,
342344
stringNGramsConfig,

tfjs-converter/python/tensorflowjs/op_list/string.json

+29-1
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,32 @@
11
[
2+
{
3+
"tfOpName": "StaticRegexReplace",
4+
"category": "string",
5+
"inputs": [
6+
{
7+
"start": 0,
8+
"name": "input",
9+
"type": "tensor"
10+
}
11+
],
12+
"attrs": [
13+
{
14+
"tfName": "pattern",
15+
"name": "pattern",
16+
"type": "string"
17+
},
18+
{
19+
"tfName": "rewrite",
20+
"name": "rewrite",
21+
"type": "string"
22+
},
23+
{
24+
"tfName": "replace_global",
25+
"name": "replaceGlobal",
26+
"type": "bool"
27+
}
28+
]
29+
},
230
{
331
"tfOpName": "StringNGrams",
432
"category": "string",
@@ -97,4 +125,4 @@
97125
}
98126
]
99127
}
100-
]
128+
]

tfjs-converter/src/operations/executors/string_executor.ts

+8
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,14 @@ export const executeOp: InternalOpExecutor =
2929
(node: Node, tensorMap: NamedTensorsMap,
3030
context: ExecutionContext, ops = tfOps): Tensor[] => {
3131
switch (node.op) {
32+
case 'StaticRegexReplace': {
33+
return [ops.string.staticRegexReplace(
34+
getParamValue('input', node, tensorMap, context) as Tensor,
35+
getParamValue('pattern', node, tensorMap, context) as string,
36+
getParamValue('rewrite', node, tensorMap, context) as string,
37+
getParamValue('replaceGlobal', node, tensorMap, context) as boolean,
38+
)];
39+
}
3240
case 'StringNGrams': {
3341
const {nGrams, nGramsSplits} = ops.string.stringNGrams(
3442
getParamValue('data', node, tensorMap, context) as Tensor1D,

tfjs-converter/src/operations/executors/string_executor_test.ts

+24
Original file line numberDiff line numberDiff line change
@@ -49,6 +49,30 @@ describe('string', () => {
4949
});
5050

5151
describe('executeOp', () => {
52+
describe('StaticRegexReplace', () => {
53+
it('should call tfOps.string.staticRegexReplace', async () => {
54+
node.op = 'StaticRegexReplace';
55+
node.inputParams = {
56+
input: createTensorAttr(0),
57+
};
58+
node.attrParams = {
59+
pattern: createStrAttr('foo'),
60+
rewrite: createStrAttr('bar'),
61+
replaceGlobal: createBoolAttr(true),
62+
};
63+
node.inputNames = ['input'];
64+
65+
const input = [tfOps.tensor1d(['a', 'b', 'foo', 'd'])];
66+
const result = executeOp(node, {input}, context,
67+
spyOpsAsTfOps) as Tensor[];
68+
69+
expect(spyOps.string.staticRegexReplace)
70+
.toHaveBeenCalledWith(input[0], 'foo', 'bar', true);
71+
72+
test_util.expectArraysEqual(
73+
await result[0].data(), ['a', 'b', 'bar', 'd']);
74+
});
75+
});
5276
describe('StringNGrams', () => {
5377
it('should call tfOps.string.stringNGrams', async () => {
5478
node.op = 'StringNGrams';

tfjs-core/src/base.ts

+1-1
Original file line numberDiff line numberDiff line change
@@ -55,7 +55,7 @@ export {RMSPropOptimizer} from './optimizers/rmsprop_optimizer';
5555
export {SGDOptimizer} from './optimizers/sgd_optimizer';
5656
export {DataToGPUOptions, DataToGPUWebGLOption, GPUData, Scalar, Tensor, Tensor1D, Tensor2D, Tensor3D, Tensor4D, Tensor5D, TensorBuffer, Variable} from './tensor';
5757
export {GradSaveFunc, NamedTensorMap, TensorContainer, TensorContainerArray, TensorContainerObject} from './tensor_types';
58-
export {BackendValues, DataType, DataTypeMap, DataValues, NumericDataType, PixelData, Rank, RecursiveArray, ScalarLike, ShapeMap, sumOutType, TensorLike, TypedArray, upcastType, WebGLData, WebGPUData} from './types';
58+
export {BackendValues, DataType, DataTypeMap, DataTypeFor, DataValues, NumericDataType, PixelData, Rank, RecursiveArray, ScalarLike, ShapeMap, sumOutType, TensorLike, TypedArray, upcastType, WebGLData, WebGPUData} from './types';
5959

6060
export * from './ops/ops';
6161
export {Reduction} from './ops/loss_ops_utils';

tfjs-core/src/kernel_names.ts

+8
Original file line numberDiff line numberDiff line change
@@ -852,6 +852,14 @@ export type SquaredDifferenceInputs = BinaryInputs;
852852
export const Square = 'Square';
853853
export type SquareInputs = Pick<NamedTensorInfoMap, 'x'>;
854854

855+
/** Name under which StaticRegexReplace kernels are registered. */
export const StaticRegexReplace = 'StaticRegexReplace';
/** Inputs of StaticRegexReplace: the same shape as any unary kernel. */
export type StaticRegexReplaceInputs = UnaryInputs;
857+
/** Attributes accepted by the StaticRegexReplace kernel. */
export interface StaticRegexReplaceAttrs {
  /** Source of the regular expression to search for. */
  pattern: string;
  /** Replacement text substituted for each match. */
  rewrite: string;
  /** Whether every match is replaced rather than only the first one. */
  replaceGlobal: boolean;
}
862+
855863
export const StridedSlice = 'StridedSlice';
856864
export type StridedSliceInputs = Pick<NamedTensorInfoMap, 'x'>;
857865
export interface StridedSliceAttrs {

0 commit comments

Comments
 (0)