Skip to content

Kotlin friendly names #1

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 6 commits into from
May 18, 2021
Merged
Show file tree
Hide file tree
Changes from 4 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
@@ -1,2 +1,3 @@
*.iml
.idea
target
55 changes: 43 additions & 12 deletions ndarray/src/main/java/org/tensorflow/ndarray/Shape.java
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,9 @@

package org.tensorflow.ndarray;

import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;

/**
* The shape of a Tensor or {@link NdArray}.
Expand Down Expand Up @@ -74,8 +76,8 @@ public static Shape scalar() {
* Shape scalar = Shape.of()
* }</pre>
*
* @param dimensionSizes number of elements in each dimension of this shape, if any, or
* {@link Shape#UNKNOWN_SIZE} if unknown.
* @param dimensionSizes number of elements in each dimension of this shape, if any, or {@link
* Shape#UNKNOWN_SIZE} if unknown.
* @return a new shape
*/
public static Shape of(long... dimensionSizes) {
Expand Down Expand Up @@ -108,13 +110,14 @@ public long size() {
* an unknown size, {@link Shape#UNKNOWN_SIZE} is returned.
*
* @param i the index of the dimension to get the size for. If this Shape has a known number of
* dimensions, it must be &lt; {@link Shape#numDimensions()}. The index may be negative, in which
* case the position is counted from the end of the shape. E.g.: {@code size(-1)} returns the
* size of the last dimension, {@code size(-2)} the size of the second to last dimension etc.
* dimensions, it must be &lt; {@link Shape#numDimensions()}. The index may be negative, in
* which case the position is counted from the end of the shape. E.g.: {@code size(-1)}
* returns the size of the last dimension, {@code size(-2)} the size of the second to last
* dimension etc.
* @return The size of the dimension with the given index if known, {@link Shape#UNKNOWN_SIZE}
* otherwise.
*/
public long size(int i) {
public long get(int i) {
if (dimensionSizes == null) {
return UNKNOWN_SIZE;
} else if (i >= 0) {
Expand Down Expand Up @@ -177,6 +180,24 @@ public long[] asArray() {
}
}

/**
* Returns a defensive copy of the this Shape's axes. Changes to the returned list do not change
* this Shape's state. Returns null if {@link Shape#isUnknown()} is true.
*/
public List<Long> toListOrNull() {
long[] array = asArray();
if (array == null) {
return null;
}

List<Long> list = new ArrayList<>(array.length);
for (long l : array) {
list.add(l);
}

return list;
}

@Override
public int hashCode() {
return dimensionSizes != null ? Arrays.hashCode(dimensionSizes) : super.hashCode();
Expand All @@ -186,6 +207,7 @@ public int hashCode() {
* Equals implementation for Shapes. Two Shapes are considered equal iff:
*
* <p>
*
* <ul>
* <li>the number of dimensions is defined and equal for both
* <li>the size of each dimension is defined and equal for both
Expand Down Expand Up @@ -236,7 +258,8 @@ public Shape head() {
* Returns an n-dimensional Shape with the dimensions matching the first n dimensions of this
* shape
*
* @param n the number of leading dimensions to get, must be &lt;= than {@link Shape#numDimensions()}
* @param n the number of leading dimensions to get, must be &lt;= than {@link
* Shape#numDimensions()}
* @return an n-dimensional Shape with the first n dimensions matching the first n dimensions of
* this Shape
*/
Expand All @@ -252,7 +275,9 @@ public Shape take(int n) {

/** Returns a new Shape, with this Shape's first dimension removed. */
public Shape tail() {
if (dimensionSizes.length < 2) return Shape.of();
if (dimensionSizes.length < 2) {
return Shape.of();
}
return Shape.of(Arrays.copyOfRange(dimensionSizes, 1, dimensionSizes.length));
}

Expand All @@ -276,15 +301,21 @@ public Shape takeLast(int n) {
}

/**
* Return a {@code end - begin} dimensional shape with dimensions matching this Shape from {@code begin} to {@code end}.
* Return a {@code end - begin} dimensional shape with dimensions matching this Shape from {@code
* begin} to {@code end}.
*
* @param begin Where to start the sub-shape.
* @param end Where to end the sub-shape, exclusive.
* @return the sub-shape bounded by begin and end.
*/
public Shape subShape(int begin, int end){
public Shape subShape(int begin, int end) {
if (end > numDimensions()) {
throw new ArrayIndexOutOfBoundsException(
"End index " + end + " out of bounds: shape only has " + numDimensions() + " dimensions.");
"End index "
+ end
+ " out of bounds: shape only has "
+ numDimensions()
+ " dimensions.");
}
if (begin < 0) {
throw new ArrayIndexOutOfBoundsException(
Expand Down Expand Up @@ -423,7 +454,7 @@ public boolean isCompatibleWith(Shape shape) {
return false;
}
for (int i = 0; i < numDimensions(); i++) {
if (!isCompatible(size(i), shape.size(i))) {
if (!isCompatible(get(i), shape.get(i))) {
return false;
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3798,7 +3798,7 @@ private static int[] computeArrayDims(NdArray<?> ndArray, int expectedRank) {
}
int[] arrayShape = new int[expectedRank];
for (int i = 0; i < expectedRank; ++i) {
long dimSize = shape.size(i);
long dimSize = shape.get(i);
if (dimSize > Integer.MAX_VALUE) {
throw new IllegalArgumentException("Dimension " + i + " is too large to fit in a standard array (" + shape.size(i) + ")");
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ public static DimensionalSpace create(Shape shape) {

// Start from the last dimension, where all elements are continuous
for (int i = dimensions.length - 1, elementSize = 1; i >= 0; --i) {
dimensions[i] = new Axis(shape.size(i), elementSize);
dimensions[i] = new Axis(shape.get(i), elementSize);
elementSize *= dimensions[i].numElements();
}
return new DimensionalSpace(dimensions, shape);
Expand Down Expand Up @@ -189,7 +189,9 @@ public long positionOf(long[] coords) {
return position;
}

/** Succinct description of the shape meant for debugging. */
/**
* Succinct description of the shape meant for debugging.
*/
@Override
public String toString() {
return Arrays.toString(dimensions);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -24,10 +24,10 @@
import static org.tensorflow.ndarray.index.Indices.at;
import static org.tensorflow.ndarray.index.Indices.even;
import static org.tensorflow.ndarray.index.Indices.flip;
import static org.tensorflow.ndarray.index.Indices.sliceFrom;
import static org.tensorflow.ndarray.index.Indices.odd;
import static org.tensorflow.ndarray.index.Indices.range;
import static org.tensorflow.ndarray.index.Indices.seq;
import static org.tensorflow.ndarray.index.Indices.sliceFrom;
import static org.tensorflow.ndarray.index.Indices.sliceTo;

import java.nio.BufferOverflowException;
Expand Down Expand Up @@ -132,15 +132,15 @@ public void iterateElements() {
long value = 0L;
for (NdArray<T> matrix : matrix3d.elements(0)) {
assertEquals(2L, matrix.shape().numDimensions());
assertEquals(4L, matrix.shape().size(0));
assertEquals(5L, matrix.shape().size(1));
assertEquals(4L, matrix.shape().get(0));
assertEquals(5L, matrix.shape().get(1));

for (NdArray<T> vector : matrix.elements(0)) {
assertEquals(1L, vector.shape().numDimensions()) ;
assertEquals(5L, vector.shape().size(0));
assertEquals(1L, vector.shape().numDimensions());
assertEquals(5L, vector.shape().get(0));

for (NdArray<T> scalar : vector.scalars()) {
assertEquals(0L, scalar.shape().numDimensions()) ;
assertEquals(0L, scalar.shape().numDimensions());
scalar.setObject(valueOf(value++));
try {
scalar.elements(0);
Expand All @@ -162,7 +162,7 @@ public void iterateElements() {
@Test
public void slices() {
NdArray<T> matrix3d = allocate(Shape.of(5, 4, 5));

T val100 = valueOf(100L);
matrix3d.setObject(val100, 1, 0, 0);
T val101 = valueOf(101L);
Expand Down Expand Up @@ -318,8 +318,8 @@ public void equalsAndHashCode() {
NdArray<T> array4 = allocate(Shape.of(1, 2, 2));

@SuppressWarnings("unchecked")
T[][][] values = (T[][][])(new Object[][][] {
{ { valueOf(0L), valueOf(1L) }, { valueOf(2L), valueOf(0L) } }
T[][][] values = (T[][][]) (new Object[][][]{
{{valueOf(0L), valueOf(1L)}, {valueOf(2L), valueOf(0L)}}
});

StdArrays.copyTo(values[0], array1);
Expand Down
30 changes: 18 additions & 12 deletions ndarray/src/test/java/org/tensorflow/ndarray/ShapeTest.java
Original file line number Diff line number Diff line change
Expand Up @@ -16,32 +16,38 @@
*/
package org.tensorflow.ndarray;

import org.junit.jupiter.api.Test;
import static org.junit.jupiter.api.Assertions.assertArrayEquals;
import static org.junit.jupiter.api.Assertions.assertEquals;
import static org.junit.jupiter.api.Assertions.assertFalse;
import static org.junit.jupiter.api.Assertions.assertNotEquals;
import static org.junit.jupiter.api.Assertions.assertNotNull;
import static org.junit.jupiter.api.Assertions.assertTrue;
import static org.junit.jupiter.api.Assertions.fail;

import static org.junit.jupiter.api.Assertions.*;
import org.junit.jupiter.api.Test;

public class ShapeTest {

@Test
public void allKnownDimensions() {
Shape shape = Shape.of(5, 4, 5);
assertEquals(3, shape.numDimensions());
assertEquals(5, shape.size(0));
assertEquals(4, shape.size(1));
assertEquals(5, shape.size(2));
assertEquals(5, shape.get(0));
assertEquals(4, shape.get(1));
assertEquals(5, shape.get(2));
assertEquals(100, shape.size());
assertArrayEquals(new long[] {5, 4, 5}, shape.asArray());
assertArrayEquals(new long[]{5, 4, 5}, shape.asArray());
try {
shape.size(3);
shape.get(3);
fail();
} catch (IndexOutOfBoundsException e) {
// as expected
}
assertEquals(5, shape.size(-1));
assertEquals(4, shape.size(-2));
assertEquals(5, shape.size(-3));
assertEquals(5, shape.get(-1));
assertEquals(4, shape.get(-2));
assertEquals(5, shape.get(-3));
try {
shape.size(-4);
shape.get(-4);
fail();
} catch (IndexOutOfBoundsException e) {
// as expected
Expand Down Expand Up @@ -133,7 +139,7 @@ public void testShapeModification() {
long[] internalShape = one.asArray();
assertNotNull(internalShape);
internalShape[0] = 42L;
assertEquals(2L, one.size(0));
assertEquals(2L, one.get(0));
}

@Test
Expand Down