Skip to content

Commit e9ff657

Browse files
authored
Viewing arrays with different shapes (#18)
1 parent 05202d9 commit e9ff657

20 files changed

+112
-14
lines changed

ndarray/src/main/java/org/tensorflow/ndarray/BooleanNdArray.java

+3
Original file line numberDiff line numberDiff line change
@@ -68,6 +68,9 @@ public interface BooleanNdArray extends NdArray<Boolean> {
6868
*/
6969
BooleanNdArray setBoolean(boolean value, long... coordinates);
7070

71+
@Override
72+
BooleanNdArray withShape(Shape shape);
73+
7174
@Override
7275
BooleanNdArray slice(Index... indices);
7376

ndarray/src/main/java/org/tensorflow/ndarray/ByteNdArray.java

+3
Original file line numberDiff line numberDiff line change
@@ -68,6 +68,9 @@ public interface ByteNdArray extends NdArray<Byte> {
6868
*/
6969
ByteNdArray setByte(byte value, long... coordinates);
7070

71+
@Override
72+
ByteNdArray withShape(Shape shape);
73+
7174
@Override
7275
ByteNdArray slice(Index... indices);
7376

ndarray/src/main/java/org/tensorflow/ndarray/DoubleNdArray.java

+3
Original file line numberDiff line numberDiff line change
@@ -83,6 +83,9 @@ default DoubleStream streamOfDoubles() {
8383
return StreamSupport.stream(scalars().spliterator(), false).mapToDouble(DoubleNdArray::getDouble);
8484
}
8585

86+
@Override
87+
DoubleNdArray withShape(Shape shape);
88+
8689
@Override
8790
DoubleNdArray slice(Index... indices);
8891

ndarray/src/main/java/org/tensorflow/ndarray/FloatNdArray.java

+3
Original file line numberDiff line numberDiff line change
@@ -68,6 +68,9 @@ public interface FloatNdArray extends NdArray<Float> {
6868
*/
6969
FloatNdArray setFloat(float value, long... coordinates);
7070

71+
@Override
72+
FloatNdArray withShape(Shape shape);
73+
7174
@Override
7275
FloatNdArray slice(Index... coordinates);
7376

ndarray/src/main/java/org/tensorflow/ndarray/IntNdArray.java

+3
Original file line numberDiff line numberDiff line change
@@ -83,6 +83,9 @@ default IntStream streamOfInts() {
8383
return StreamSupport.stream(scalars().spliterator(), false).mapToInt(IntNdArray::getInt);
8484
}
8585

86+
@Override
87+
IntNdArray withShape(Shape shape);
88+
8689
@Override
8790
IntNdArray slice(Index... indices);
8891

ndarray/src/main/java/org/tensorflow/ndarray/LongNdArray.java

+3
Original file line numberDiff line numberDiff line change
@@ -83,6 +83,9 @@ default LongStream streamOfLongs() {
8383
return StreamSupport.stream(scalars().spliterator(), false).mapToLong(LongNdArray::getLong);
8484
}
8585

86+
@Override
87+
LongNdArray withShape(Shape shape);
88+
8689
@Override
8790
LongNdArray slice(Index... indices);
8891

ndarray/src/main/java/org/tensorflow/ndarray/NdArray.java

+29-3
Original file line numberDiff line numberDiff line change
@@ -16,14 +16,14 @@
1616
*/
1717
package org.tensorflow.ndarray;
1818

19+
import org.tensorflow.ndarray.buffer.DataBuffer;
20+
import org.tensorflow.ndarray.index.Index;
21+
1922
import java.util.function.BiConsumer;
2023
import java.util.function.Consumer;
2124
import java.util.stream.Stream;
2225
import java.util.stream.StreamSupport;
2326

24-
import org.tensorflow.ndarray.buffer.DataBuffer;
25-
import org.tensorflow.ndarray.index.Index;
26-
2727
/**
2828
* A data structure of N-dimensions.
2929
*
@@ -101,6 +101,32 @@ public interface NdArray<T> extends Shaped {
101101
*/
102102
NdArraySequence<? extends NdArray<T>> scalars();
103103

104+
/**
105+
* Returns a new N-dimensional view of this array with the given {@code shape}.
106+
*
107+
* <p>The provided {@code shape} must comply to the following characteristics:
108+
* <ul>
109+
* <li>new shape is known (i.e. has no unknown dimension)</li>
110+
* <li>new shape size is equal to the size of the current shape (i.e. same number of elements)</li>
111+
* </ul>
112+
* For example,
113+
* <pre>{@code
114+
* NdArrays.ofInts(Shape.scalar()).withShape(Shape.of(1, 1)); // ok
115+
* NdArrays.ofInts(Shape.of(2, 3).withShape(Shape.of(3, 2)); // ok
116+
* NdArrays.ofInts(Shape.scalar()).withShape(Shape.of(1, 2)); // not ok, sizes are different (1 != 2)
117+
* NdArrays.ofInts(Shape.of(2, 3)).withShape(Shape.unknown()); // not ok, new shape unknown
118+
* }</pre>
119+
*
120+
* <p>Any changes applied to the returned view affect the data of this array as well, as there
121+
* is no copy involved.
122+
*
123+
* @param shape the new shape to apply
124+
* @return a new array viewing the data according to the new shape, or this array if shapes are the same
125+
* @throws IllegalArgumentException if the provided {@code shape} is not compliant
126+
* @throws UnsupportedOperationException if this array does not support this operation
127+
*/
128+
NdArray<T> withShape(Shape shape);
129+
104130
/**
105131
* Creates a multi-dimensional view (or slice) of this array by mapping one or more dimensions
106132
* to the given index selectors.

ndarray/src/main/java/org/tensorflow/ndarray/ShortNdArray.java

+3
Original file line numberDiff line numberDiff line change
@@ -68,6 +68,9 @@ public interface ShortNdArray extends NdArray<Short> {
6868
*/
6969
ShortNdArray setShort(short value, long... coordinates);
7070

71+
@Override
72+
ShortNdArray withShape(Shape shape);
73+
7174
@Override
7275
ShortNdArray slice(Index... coordinates);
7376

ndarray/src/main/java/org/tensorflow/ndarray/impl/dense/AbstractDenseNdArray.java

+15-3
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@
1818

1919
import org.tensorflow.ndarray.NdArray;
2020
import org.tensorflow.ndarray.NdArraySequence;
21+
import org.tensorflow.ndarray.Shape;
2122
import org.tensorflow.ndarray.impl.AbstractNdArray;
2223
import org.tensorflow.ndarray.impl.dimension.RelativeDimensionalSpace;
2324
import org.tensorflow.ndarray.impl.sequence.FastElementSequence;
@@ -43,18 +44,29 @@ public NdArraySequence<U> elements(int dimensionIdx) {
4344
DimensionalSpace elemDims = dimensions().from(dimensionIdx + 1);
4445
try {
4546
DataBufferWindow<? extends DataBuffer<T>> elemWindow = buffer().window(elemDims.physicalSize());
46-
U element = instantiate(elemWindow.buffer(), elemDims);
47+
U element = instantiateView(elemWindow.buffer(), elemDims);
4748
return new FastElementSequence(this, dimensionIdx, element, elemWindow);
4849
} catch (UnsupportedOperationException e) {
4950
// If buffer windows are not supported, fallback to slicing (and slower) sequence
5051
return new SlicingElementSequence<>(this, dimensionIdx, elemDims);
5152
}
5253
}
5354

55+
@Override
56+
public U withShape(Shape shape) {
57+
if (shape == null || shape.isUnknown() || shape.size() != this.shape().size()) {
58+
throw new IllegalArgumentException("Shape " + shape + " cannot be used to reshape ndarray of shape " + this.shape());
59+
}
60+
if (shape.equals(this.shape())) {
61+
return (U)this;
62+
}
63+
return instantiateView(buffer(), DimensionalSpace.create(shape));
64+
}
65+
5466
@Override
5567
public U slice(long position, DimensionalSpace sliceDimensions) {
5668
DataBuffer<T> sliceBuffer = buffer().slice(position, sliceDimensions.physicalSize());
57-
return instantiate(sliceBuffer, sliceDimensions);
69+
return instantiateView(sliceBuffer, sliceDimensions);
5870
}
5971

6072
@Override
@@ -147,7 +159,7 @@ protected AbstractDenseNdArray(DimensionalSpace dimensions) {
147159

148160
abstract protected DataBuffer<T> buffer();
149161

150-
abstract U instantiate(DataBuffer<T> buffer, DimensionalSpace dimensions);
162+
abstract U instantiateView(DataBuffer<T> buffer, DimensionalSpace dimensions);
151163

152164
long positionOf(long[] coords, boolean isValue) {
153165
if (coords == null || coords.length == 0) {

ndarray/src/main/java/org/tensorflow/ndarray/impl/dense/BooleanDenseNdArray.java

+1-1
Original file line numberDiff line numberDiff line change
@@ -73,7 +73,7 @@ protected BooleanDenseNdArray(BooleanDataBuffer buffer, Shape shape) {
7373
}
7474

7575
@Override
76-
BooleanDenseNdArray instantiate(DataBuffer<Boolean> buffer, DimensionalSpace dimensions) {
76+
BooleanDenseNdArray instantiateView(DataBuffer<Boolean> buffer, DimensionalSpace dimensions) {
7777
return new BooleanDenseNdArray((BooleanDataBuffer)buffer, dimensions);
7878
}
7979

ndarray/src/main/java/org/tensorflow/ndarray/impl/dense/ByteDenseNdArray.java

+1-1
Original file line numberDiff line numberDiff line change
@@ -73,7 +73,7 @@ protected ByteDenseNdArray(ByteDataBuffer buffer, Shape shape) {
7373
}
7474

7575
@Override
76-
ByteDenseNdArray instantiate(DataBuffer<Byte> buffer, DimensionalSpace dimensions) {
76+
ByteDenseNdArray instantiateView(DataBuffer<Byte> buffer, DimensionalSpace dimensions) {
7777
return new ByteDenseNdArray((ByteDataBuffer)buffer, dimensions);
7878
}
7979

ndarray/src/main/java/org/tensorflow/ndarray/impl/dense/DenseNdArray.java

+1-1
Original file line numberDiff line numberDiff line change
@@ -45,7 +45,7 @@ protected DenseNdArray(DataBuffer<T> buffer, Shape shape) {
4545
}
4646

4747
@Override
48-
DenseNdArray<T> instantiate(DataBuffer<T> buffer, DimensionalSpace dimensions) {
48+
DenseNdArray<T> instantiateView(DataBuffer<T> buffer, DimensionalSpace dimensions) {
4949
return new DenseNdArray<>(buffer, dimensions);
5050
}
5151

ndarray/src/main/java/org/tensorflow/ndarray/impl/dense/DoubleDenseNdArray.java

+1-1
Original file line numberDiff line numberDiff line change
@@ -73,7 +73,7 @@ protected DoubleDenseNdArray(DoubleDataBuffer buffer, Shape shape) {
7373
}
7474

7575
@Override
76-
DoubleDenseNdArray instantiate(DataBuffer<Double> buffer, DimensionalSpace dimensions) {
76+
DoubleDenseNdArray instantiateView(DataBuffer<Double> buffer, DimensionalSpace dimensions) {
7777
return new DoubleDenseNdArray((DoubleDataBuffer)buffer, dimensions);
7878
}
7979

ndarray/src/main/java/org/tensorflow/ndarray/impl/dense/FloatDenseNdArray.java

+1-1
Original file line numberDiff line numberDiff line change
@@ -73,7 +73,7 @@ protected FloatDenseNdArray(FloatDataBuffer buffer, Shape shape) {
7373
}
7474

7575
@Override
76-
FloatDenseNdArray instantiate(DataBuffer<Float> buffer, DimensionalSpace dimensions) {
76+
FloatDenseNdArray instantiateView(DataBuffer<Float> buffer, DimensionalSpace dimensions) {
7777
return new FloatDenseNdArray((FloatDataBuffer) buffer, dimensions);
7878
}
7979

ndarray/src/main/java/org/tensorflow/ndarray/impl/dense/IntDenseNdArray.java

+1-1
Original file line numberDiff line numberDiff line change
@@ -73,7 +73,7 @@ protected IntDenseNdArray(IntDataBuffer buffer, Shape shape) {
7373
}
7474

7575
@Override
76-
IntDenseNdArray instantiate(DataBuffer<Integer> buffer, DimensionalSpace dimensions) {
76+
IntDenseNdArray instantiateView(DataBuffer<Integer> buffer, DimensionalSpace dimensions) {
7777
return new IntDenseNdArray((IntDataBuffer)buffer, dimensions);
7878
}
7979

ndarray/src/main/java/org/tensorflow/ndarray/impl/dense/LongDenseNdArray.java

+1-1
Original file line numberDiff line numberDiff line change
@@ -73,7 +73,7 @@ protected LongDenseNdArray(LongDataBuffer buffer, Shape shape) {
7373
}
7474

7575
@Override
76-
LongDenseNdArray instantiate(DataBuffer<Long> buffer, DimensionalSpace dimensions) {
76+
LongDenseNdArray instantiateView(DataBuffer<Long> buffer, DimensionalSpace dimensions) {
7777
return new LongDenseNdArray((LongDataBuffer)buffer, dimensions);
7878
}
7979

ndarray/src/main/java/org/tensorflow/ndarray/impl/dense/ShortDenseNdArray.java

+1-1
Original file line numberDiff line numberDiff line change
@@ -73,7 +73,7 @@ protected ShortDenseNdArray(ShortDataBuffer buffer, Shape shape) {
7373
}
7474

7575
@Override
76-
ShortDenseNdArray instantiate(DataBuffer<Short> buffer, DimensionalSpace dimensions) {
76+
ShortDenseNdArray instantiateView(DataBuffer<Short> buffer, DimensionalSpace dimensions) {
7777
return new ShortDenseNdArray((ShortDataBuffer)buffer, dimensions);
7878
}
7979

ndarray/src/main/java/org/tensorflow/ndarray/impl/sparse/AbstractSparseNdArray.java

+5
Original file line numberDiff line numberDiff line change
@@ -212,6 +212,11 @@ protected long[] getIndicesCoordinates(LongNdArray l) {
212212
*/
213213
public abstract U toDense();
214214

215+
@Override
216+
public U withShape(Shape shape) {
217+
throw new UnsupportedOperationException("Sparse NdArrays cannot be viewed with a different shape");
218+
}
219+
215220
/** {@inheritDoc} */
216221
@Override
217222
public NdArray<T> slice(Index... indices) {

ndarray/src/test/java/org/tensorflow/ndarray/NdArrayTestBase.java

+27
Original file line numberDiff line numberDiff line change
@@ -384,4 +384,31 @@ public void streamingObjects() {
384384
values = matrix.streamOfObjects().collect(Collectors.toList());
385385
assertIterableEquals(List.of(valueOf(1L), valueOf(2L), valueOf(3L), valueOf(4L)), values);
386386
}
387+
388+
@Test
389+
public void withShape() {
390+
Shape originalShape = Shape.scalar();
391+
Shape newShape = originalShape.prepend(1).prepend(1); // [1, 1]
392+
393+
NdArray<T> originalArray = allocate(originalShape);
394+
originalArray.setObject(valueOf(10L));
395+
assertEquals(valueOf(10L), originalArray.getObject());
396+
397+
NdArray<T> newArray = originalArray.withShape(newShape);
398+
assertNotNull(newArray);
399+
assertEquals(newShape, newArray.shape());
400+
assertEquals(originalShape, originalArray.shape());
401+
assertEquals(valueOf(10L), newArray.getObject(0, 0));
402+
403+
NdArray<T> sameArray = originalArray.withShape(Shape.scalar());
404+
assertSame(originalArray, sameArray);
405+
406+
assertThrows(IllegalArgumentException.class, () -> originalArray.withShape(Shape.of(2)));
407+
assertThrows(IllegalArgumentException.class, () -> originalArray.withShape(Shape.unknown()));
408+
409+
NdArray<T> originalMatrix = allocate(Shape.of(2, 3));
410+
assertThrows(IllegalArgumentException.class, () -> originalMatrix.withShape(Shape.scalar()));
411+
NdArray<T> newMatrix = originalMatrix.withShape(Shape.of(3, 2));
412+
assertEquals(Shape.of(3, 2), newMatrix.shape());
413+
}
387414
}

ndarray/src/test/java/org/tensorflow/ndarray/SparseNdArrayTest.java

+7
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,7 @@
2525

2626
import static org.junit.jupiter.api.Assertions.assertEquals;
2727
import static org.junit.jupiter.api.Assertions.assertFalse;
28+
import static org.junit.jupiter.api.Assertions.assertThrows;
2829
import static org.junit.jupiter.api.Assertions.assertTrue;
2930

3031
public class SparseNdArrayTest {
@@ -188,4 +189,10 @@ public void testShort() {
188189
assertEquals((short) 0, instance.getShort(2, 2));
189190
assertEquals((short) 0xff00, instance.getShort(2, 3));
190191
}
192+
193+
@Test
194+
public void withShape() {
195+
NdArray<?> sparseArray = NdArrays.sparseOf(indices, NdArrays.vectorOf(1, 2, 3), shape);
196+
assertThrows(UnsupportedOperationException.class, () -> sparseArray.withShape(shape.prepend(1)));
197+
}
191198
}

0 commit comments

Comments
 (0)