|
1 | 1 | /*
|
2 |
| - Copyright 2019 The TensorFlow Authors. All Rights Reserved. |
| 2 | + Copyright 2019-2023 The TensorFlow Authors. All Rights Reserved. |
3 | 3 |
|
4 | 4 | Licensed under the Apache License, Version 2.0 (the "License");
|
5 | 5 | you may not use this file except in compliance with the License.
|
|
16 | 16 | */
|
17 | 17 | package org.tensorflow.ndarray;
|
18 | 18 |
|
19 |
| -import static org.junit.jupiter.api.Assertions.assertEquals; |
20 |
| -import static org.junit.jupiter.api.Assertions.assertNotEquals; |
21 |
| -import static org.junit.jupiter.api.Assertions.fail; |
| 19 | +import static org.junit.jupiter.api.Assertions.*; |
22 | 20 | import static org.tensorflow.ndarray.NdArrays.vectorOfObjects;
|
23 | 21 | import static org.tensorflow.ndarray.index.Indices.all;
|
24 | 22 | import static org.tensorflow.ndarray.index.Indices.at;
|
|
32 | 30 |
|
33 | 31 | import java.nio.BufferOverflowException;
|
34 | 32 | import java.nio.BufferUnderflowException;
|
| 33 | +import java.util.List; |
| 34 | +import java.util.stream.Collectors; |
| 35 | + |
35 | 36 | import org.junit.jupiter.api.Test;
|
36 | 37 | import org.tensorflow.ndarray.buffer.DataBuffer;
|
37 | 38 | import org.tensorflow.ndarray.index.Indices;
|
@@ -358,4 +359,29 @@ public void iterateScalarsOnSegmentedElements() {
|
358 | 359 | });
|
359 | 360 | });
|
360 | 361 | }
|
| 362 | + |
| 363 | + @Test |
| 364 | + public void streamingObjects() { |
| 365 | + NdArray<T> scalar = allocate(Shape.scalar()); |
| 366 | + scalar.setObject(valueOf(1L)); |
| 367 | + var values = scalar.streamOfObjects().collect(Collectors.toList()); |
| 368 | + assertIterableEquals(List.of(valueOf(1L)), values); |
| 369 | + |
| 370 | + NdArray<T> vector = allocate(Shape.of(5)); |
| 371 | + vector.setObject(valueOf(1L), 0); |
| 372 | + vector.setObject(valueOf(2L), 1); |
| 373 | + vector.setObject(valueOf(3L), 2); |
| 374 | + vector.setObject(valueOf(4L), 3); |
| 375 | + vector.setObject(valueOf(5L), 4); |
| 376 | + values = vector.streamOfObjects().collect(Collectors.toList()); |
| 377 | + assertIterableEquals(List.of(valueOf(1L), valueOf(2L), valueOf(3L), valueOf(4L), valueOf(5L)), values); |
| 378 | + |
| 379 | + NdArray<T> matrix = allocate(Shape.of(2, 2)); |
| 380 | + matrix.setObject(valueOf(1L), 0, 0); |
| 381 | + matrix.setObject(valueOf(2L), 0, 1); |
| 382 | + matrix.setObject(valueOf(3L), 1, 0); |
| 383 | + matrix.setObject(valueOf(4L), 1, 1); |
| 384 | + values = matrix.streamOfObjects().collect(Collectors.toList()); |
| 385 | + assertIterableEquals(List.of(valueOf(1L), valueOf(2L), valueOf(3L), valueOf(4L)), values); |
| 386 | + } |
361 | 387 | }
|
0 commit comments