Skip to content

Commit a9d3f25

Browse files
authored
Value streaming for NdArrays (#15)
1 parent f54dd15 commit a9d3f25

File tree

7 files changed

+150
-12
lines changed

7 files changed

+150
-12
lines changed

ndarray/src/main/java/org/tensorflow/ndarray/DoubleNdArray.java

+16-1
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
/*
2-
Copyright 2019 The TensorFlow Authors. All Rights Reserved.
2+
Copyright 2019-2023 The TensorFlow Authors. All Rights Reserved.
33
44
Licensed under the Apache License, Version 2.0 (the "License");
55
you may not use this file except in compliance with the License.
@@ -20,6 +20,9 @@
2020
import org.tensorflow.ndarray.buffer.DoubleDataBuffer;
2121
import org.tensorflow.ndarray.index.Index;
2222

23+
import java.util.stream.DoubleStream;
24+
import java.util.stream.StreamSupport;
25+
2326
/**
2427
* An {@link NdArray} of doubles.
2528
*/
@@ -68,6 +71,18 @@ public interface DoubleNdArray extends NdArray<Double> {
6871
*/
6972
DoubleNdArray setDouble(double value, long... coordinates);
7073

74+
/**
75+
* Retrieve all scalar values of this array as a stream of doubles.
76+
*
77+
* <p>For {@code rank() > 1} arrays, all vectors of the last dimension are collated so that the scalar values are
78+
* returned in sequential order.</p>
79+
*
80+
* @return scalar values as a stream
81+
*/
82+
default DoubleStream streamOfDoubles() {
83+
return StreamSupport.stream(scalars().spliterator(), false).mapToDouble(DoubleNdArray::getDouble);
84+
}
85+
7186
@Override
7287
DoubleNdArray slice(Index... indices);
7388

ndarray/src/main/java/org/tensorflow/ndarray/IntNdArray.java

+16-1
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
/*
2-
Copyright 2019 The TensorFlow Authors. All Rights Reserved.
2+
Copyright 2019-2023 The TensorFlow Authors. All Rights Reserved.
33
44
Licensed under the Apache License, Version 2.0 (the "License");
55
you may not use this file except in compliance with the License.
@@ -20,6 +20,9 @@
2020
import org.tensorflow.ndarray.buffer.IntDataBuffer;
2121
import org.tensorflow.ndarray.index.Index;
2222

23+
import java.util.stream.IntStream;
24+
import java.util.stream.StreamSupport;
25+
2326
/**
2427
* An {@link NdArray} of integers.
2528
*/
@@ -68,6 +71,18 @@ public interface IntNdArray extends NdArray<Integer> {
6871
*/
6972
IntNdArray setInt(int value, long... coordinates);
7073

74+
/**
75+
* Retrieve all scalar values of this array as a stream of integers.
76+
*
77+
* <p>For {@code rank() > 1} arrays, all vectors of the last dimension are collated so that the scalar values are
78+
* returned in sequential order.</p>
79+
*
80+
* @return scalar values as a stream
81+
*/
82+
default IntStream streamOfInts() {
83+
return StreamSupport.stream(scalars().spliterator(), false).mapToInt(IntNdArray::getInt);
84+
}
85+
7186
@Override
7287
IntNdArray slice(Index... indices);
7388

ndarray/src/main/java/org/tensorflow/ndarray/LongNdArray.java

+16-1
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
/*
2-
Copyright 2019 The TensorFlow Authors. All Rights Reserved.
2+
Copyright 2019-2023 The TensorFlow Authors. All Rights Reserved.
33
44
Licensed under the Apache License, Version 2.0 (the "License");
55
you may not use this file except in compliance with the License.
@@ -20,6 +20,9 @@
2020
import org.tensorflow.ndarray.buffer.LongDataBuffer;
2121
import org.tensorflow.ndarray.index.Index;
2222

23+
import java.util.stream.LongStream;
24+
import java.util.stream.StreamSupport;
25+
2326
/**
2427
* An {@link NdArray} of longs.
2528
*/
@@ -68,6 +71,18 @@ public interface LongNdArray extends NdArray<Long> {
6871
*/
6972
LongNdArray setLong(long value, long... coordinates);
7073

74+
/**
75+
* Retrieve all scalar values of this array as a stream of longs.
76+
*
77+
* <p>For {@code rank() > 1} arrays, all vectors of the last dimension are collated so that the scalar values are
78+
* returned in sequential order.</p>
79+
*
80+
* @return scalar values as a stream
81+
*/
82+
default LongStream streamOfLongs() {
83+
return StreamSupport.stream(scalars().spliterator(), false).mapToLong(LongNdArray::getLong);
84+
}
85+
7186
@Override
7287
LongNdArray slice(Index... indices);
7388

ndarray/src/main/java/org/tensorflow/ndarray/NdArray.java

+16-1
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
/*
2-
Copyright 2019 The TensorFlow Authors. All Rights Reserved.
2+
Copyright 2019-2023 The TensorFlow Authors. All Rights Reserved.
33
44
Licensed under the Apache License, Version 2.0 (the "License");
55
you may not use this file except in compliance with the License.
@@ -18,6 +18,9 @@
1818

1919
import java.util.function.BiConsumer;
2020
import java.util.function.Consumer;
21+
import java.util.stream.Stream;
22+
import java.util.stream.StreamSupport;
23+
2124
import org.tensorflow.ndarray.buffer.DataBuffer;
2225
import org.tensorflow.ndarray.index.Index;
2326

@@ -229,6 +232,18 @@ public interface NdArray<T> extends Shaped {
229232
*/
230233
NdArray<T> setObject(T value, long... coordinates);
231234

235+
/**
236+
* Retrieve all scalar values of this array as a stream of objects.
237+
*
238+
* <p>For {@code rank() > 1} arrays, all vectors of the last dimension are collated so that the scalar values are
239+
* returned in sequential order.</p>
240+
*
241+
* @return scalar values as a stream
242+
*/
243+
default Stream<T> streamOfObjects() {
244+
return StreamSupport.stream(scalars().spliterator(), false).map(NdArray::getObject);
245+
}
246+
232247
/**
233248
* Copy the content of this array to the destination array.
234249
*

ndarray/src/test/java/org/tensorflow/ndarray/IntNdArrayTestBase.java

+29-3
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
/*
2-
Copyright 2019 The TensorFlow Authors. All Rights Reserved.
2+
Copyright 2019-2023 The TensorFlow Authors. All Rights Reserved.
33
44
Licensed under the Apache License, Version 2.0 (the "License");
55
you may not use this file except in compliance with the License.
@@ -16,10 +16,11 @@
1616
*/
1717
package org.tensorflow.ndarray;
1818

19-
import static org.junit.jupiter.api.Assertions.assertEquals;
20-
2119
import org.junit.jupiter.api.Test;
2220

21+
import static org.junit.jupiter.api.Assertions.assertArrayEquals;
22+
import static org.junit.jupiter.api.Assertions.assertEquals;
23+
2324
public abstract class IntNdArrayTestBase extends NdArrayTestBase<Integer> {
2425

2526
@Override
@@ -52,4 +53,29 @@ public void iteratePrimitiveElements() {
5253
assertEquals(9, matrix3d.getInt(0, 0, 4));
5354
assertEquals(7, matrix3d.getInt(0, 1, 2));
5455
}
56+
57+
@Test
58+
public void streamingInts() {
59+
IntNdArray scalar = allocate(Shape.scalar());
60+
scalar.setInt(1);
61+
var values = scalar.streamOfInts().toArray();
62+
assertArrayEquals(new int[]{1}, values);
63+
64+
IntNdArray vector = allocate(Shape.of(5));
65+
vector.setInt(1, 0);
66+
vector.setInt(2, 1);
67+
vector.setInt(3, 2);
68+
vector.setInt(4, 3);
69+
vector.setInt(5, 4);
70+
values = vector.streamOfInts().toArray();
71+
assertArrayEquals(new int[]{1, 2, 3, 4, 5}, values);
72+
73+
IntNdArray matrix = allocate(Shape.of(2, 2));
74+
matrix.setInt(1, 0, 0);
75+
matrix.setInt(2, 0, 1);
76+
matrix.setInt(3, 1, 0);
77+
matrix.setInt(4, 1, 1);
78+
values = matrix.streamOfInts().toArray();
79+
assertArrayEquals(new int[]{1, 2, 3, 4}, values);
80+
}
5581
}

ndarray/src/test/java/org/tensorflow/ndarray/LongNdArrayTestBase.java

+27-1
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
/*
2-
Copyright 2019 The TensorFlow Authors. All Rights Reserved.
2+
Copyright 2019-2023 The TensorFlow Authors. All Rights Reserved.
33
44
Licensed under the Apache License, Version 2.0 (the "License");
55
you may not use this file except in compliance with the License.
@@ -16,6 +16,7 @@
1616
*/
1717
package org.tensorflow.ndarray;
1818

19+
import static org.junit.jupiter.api.Assertions.assertArrayEquals;
1920
import static org.junit.jupiter.api.Assertions.assertEquals;
2021

2122
import org.junit.jupiter.api.Test;
@@ -52,4 +53,29 @@ public void iteratePrimitiveElements() {
5253
assertEquals(9, matrix3d.getLong(0, 0, 4));
5354
assertEquals(7, matrix3d.getLong(0, 1, 2));
5455
}
56+
57+
@Test
58+
public void streamingLongs() {
59+
LongNdArray scalar = allocate(Shape.scalar());
60+
scalar.setLong(1L);
61+
var values = scalar.streamOfLongs().toArray();
62+
assertArrayEquals(new long[]{1L}, values);
63+
64+
LongNdArray vector = allocate(Shape.of(5));
65+
vector.setLong(1L, 0);
66+
vector.setLong(2L, 1);
67+
vector.setLong(3L, 2);
68+
vector.setLong(4L, 3);
69+
vector.setLong(5L, 4);
70+
values = vector.streamOfLongs().toArray();
71+
assertArrayEquals(new long[]{1L, 2L, 3L, 4L, 5L}, values);
72+
73+
LongNdArray matrix = allocate(Shape.of(2, 2));
74+
matrix.setLong(1L, 0, 0);
75+
matrix.setLong(2L, 0, 1);
76+
matrix.setLong(3L, 1, 0);
77+
matrix.setLong(4L, 1, 1);
78+
values = matrix.streamOfLongs().toArray();
79+
assertArrayEquals(new long[]{1L, 2L, 3L, 4L}, values);
80+
}
5581
}

ndarray/src/test/java/org/tensorflow/ndarray/NdArrayTestBase.java

+30-4
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
/*
2-
Copyright 2019 The TensorFlow Authors. All Rights Reserved.
2+
Copyright 2019-2023 The TensorFlow Authors. All Rights Reserved.
33
44
Licensed under the Apache License, Version 2.0 (the "License");
55
you may not use this file except in compliance with the License.
@@ -16,9 +16,7 @@
1616
*/
1717
package org.tensorflow.ndarray;
1818

19-
import static org.junit.jupiter.api.Assertions.assertEquals;
20-
import static org.junit.jupiter.api.Assertions.assertNotEquals;
21-
import static org.junit.jupiter.api.Assertions.fail;
19+
import static org.junit.jupiter.api.Assertions.*;
2220
import static org.tensorflow.ndarray.NdArrays.vectorOfObjects;
2321
import static org.tensorflow.ndarray.index.Indices.all;
2422
import static org.tensorflow.ndarray.index.Indices.at;
@@ -32,6 +30,9 @@
3230

3331
import java.nio.BufferOverflowException;
3432
import java.nio.BufferUnderflowException;
33+
import java.util.List;
34+
import java.util.stream.Collectors;
35+
3536
import org.junit.jupiter.api.Test;
3637
import org.tensorflow.ndarray.buffer.DataBuffer;
3738
import org.tensorflow.ndarray.index.Indices;
@@ -358,4 +359,29 @@ public void iterateScalarsOnSegmentedElements() {
358359
});
359360
});
360361
}
362+
363+
@Test
364+
public void streamingObjects() {
365+
NdArray<T> scalar = allocate(Shape.scalar());
366+
scalar.setObject(valueOf(1L));
367+
var values = scalar.streamOfObjects().collect(Collectors.toList());
368+
assertIterableEquals(List.of(valueOf(1L)), values);
369+
370+
NdArray<T> vector = allocate(Shape.of(5));
371+
vector.setObject(valueOf(1L), 0);
372+
vector.setObject(valueOf(2L), 1);
373+
vector.setObject(valueOf(3L), 2);
374+
vector.setObject(valueOf(4L), 3);
375+
vector.setObject(valueOf(5L), 4);
376+
values = vector.streamOfObjects().collect(Collectors.toList());
377+
assertIterableEquals(List.of(valueOf(1L), valueOf(2L), valueOf(3L), valueOf(4L), valueOf(5L)), values);
378+
379+
NdArray<T> matrix = allocate(Shape.of(2, 2));
380+
matrix.setObject(valueOf(1L), 0, 0);
381+
matrix.setObject(valueOf(2L), 0, 1);
382+
matrix.setObject(valueOf(3L), 1, 0);
383+
matrix.setObject(valueOf(4L), 1, 1);
384+
values = matrix.streamOfObjects().collect(Collectors.toList());
385+
assertIterableEquals(List.of(valueOf(1L), valueOf(2L), valueOf(3L), valueOf(4L)), values);
386+
}
361387
}

0 commit comments

Comments
 (0)