Skip to content

Commit f4271a5

Browse files
authored
[tfjs-node] fixed range issue for int32 tensor with size larger than 2 ^ 24 (#7931)
* fixed range issue when creating int32 tensor with larger than 2 ^ 24 size * reduce the size of the test tensor * remove use of max op * fixed test * fixed the test
1 parent a0115ea commit f4271a5

File tree

2 files changed

+13
-6
lines changed

2 files changed

+13
-6
lines changed

tfjs-core/src/ops/range_test.ts

+9
Original file line numberDiff line numberDiff line change
@@ -148,4 +148,13 @@ describeWithFlags('range', ALL_ENVS, () => {
148148
expect(a.dtype).toEqual('int32');
149149
expect(a.shape).toEqual([3]);
150150
});
151+
152+
it('should support large number for int32 dtype', async () => {
153+
const length = Math.pow(2, 24) + 10;
154+
const a = tf.range(1, length, 1, 'int32');
155+
const data = await a.data();
156+
expect(data[length - 2]).toEqual(length - 1);
157+
expect(a.dtype).toEqual('int32');
158+
expect(a.shape).toEqual([length - 1]);
159+
});
151160
});

tfjs-node/src/kernels/Range.ts

+4-6
Original file line numberDiff line numberDiff line change
@@ -44,18 +44,16 @@ export const rangeConfig: KernelConfig = {
4444
}
4545

4646
const opAttrs = [createTensorsTypeOpAttr('Tidx', dtype)];
47-
const startTensor = scalar(start);
48-
const stopTensor = scalar(stop);
49-
const stepTensor = scalar(step);
47+
const startTensor = scalar(start, dtype);
48+
const stopTensor = scalar(stop, dtype);
49+
const stepTensor = scalar(step, dtype);
5050
const res = backend.executeSingleOutput(
5151
Range, opAttrs, [startTensor, stopTensor, stepTensor]);
52-
const castedRes = res.cast(dtype);
5352

5453
startTensor.dispose();
5554
stopTensor.dispose();
5655
stepTensor.dispose();
57-
res.dispose();
5856

59-
return castedRes;
57+
return res;
6058
}
6159
};

0 commit comments

Comments
 (0)