Skip to content

Commit 47ae42a

Browse files
committed
Rename Shape.size(int) to get, add toListOrNull
Signed-off-by: Ryan Nett <[email protected]>
1 parent be23015 commit 47ae42a

File tree

20 files changed

+86
-91
lines changed

20 files changed

+86
-91
lines changed

ndarray/src/main/java/org/tensorflow/ndarray/Shape.java

+24-3
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,9 @@
1717

1818
package org.tensorflow.ndarray;
1919

20+
import java.util.ArrayList;
2021
import java.util.Arrays;
22+
import java.util.List;
2123

2224
/**
2325
* The shape of a Tensor or {@link NdArray}.
@@ -114,7 +116,7 @@ public long size() {
114116
* @return The size of the dimension with the given index if known, {@link Shape#UNKNOWN_SIZE}
115117
* otherwise.
116118
*/
117-
public long size(int i) {
119+
public long get(int i) {
118120
if (dimensionSizes == null) {
119121
return UNKNOWN_SIZE;
120122
} else if (i >= 0) {
@@ -166,7 +168,7 @@ public boolean isUnknown() {
166168
}
167169

168170
/**
169-
* Returns a defensive copy of the this Shape's axes. Changes to the returned array to not change
171+
* Returns a defensive copy of the this Shape's axes. Changes to the returned array do not change
170172
* this Shape's state. Returns null if {@link Shape#isUnknown()} is true.
171173
*/
172174
public long[] asArray() {
@@ -177,6 +179,25 @@ public long[] asArray() {
177179
}
178180
}
179181

182+
183+
/**
184+
* Returns a defensive copy of the this Shape's axes. Changes to the returned list do not change
185+
* this Shape's state. Returns null if {@link Shape#isUnknown()} is true.
186+
*/
187+
public List<Long> toListOrNull() {
188+
long[] array = asArray();
189+
if (array == null){
190+
return null;
191+
}
192+
193+
List<Long> list = new ArrayList<>(array.length);
194+
for(long l : array) {
195+
list.add(l);
196+
}
197+
198+
return list;
199+
}
200+
180201
@Override
181202
public int hashCode() {
182203
return dimensionSizes != null ? Arrays.hashCode(dimensionSizes) : super.hashCode();
@@ -423,7 +444,7 @@ public boolean isCompatibleWith(Shape shape) {
423444
return false;
424445
}
425446
for (int i = 0; i < numDimensions(); i++) {
426-
if (!isCompatible(size(i), shape.size(i))) {
447+
if (!isCompatible(get(i), shape.get(i))) {
427448
return false;
428449
}
429450
}

ndarray/src/main/java/org/tensorflow/ndarray/StdArrays.java

+2-2
Original file line numberDiff line numberDiff line change
@@ -3798,9 +3798,9 @@ private static int[] computeArrayDims(NdArray<?> ndArray, int expectedRank) {
37983798
}
37993799
int[] arrayShape = new int[expectedRank];
38003800
for (int i = 0; i < expectedRank; ++i) {
3801-
long dimSize = shape.size(i);
3801+
long dimSize = shape.get(i);
38023802
if (dimSize > Integer.MAX_VALUE) {
3803-
throw new IllegalArgumentException("Dimension " + i + " is too large to fit in a standard array (" + shape.size(i) + ")");
3803+
throw new IllegalArgumentException("Dimension " + i + " is too large to fit in a standard array (" + shape.get(i) + ")");
38043804
}
38053805
arrayShape[i] = (int)dimSize;
38063806
}

ndarray/src/main/java/org/tensorflow/ndarray/impl/dimension/DimensionalSpace.java

+1-1
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,7 @@ public static DimensionalSpace create(Shape shape) {
2828

2929
// Start from the last dimension, where all elements are continuous
3030
for (int i = dimensions.length - 1, elementSize = 1; i >= 0; --i) {
31-
dimensions[i] = new Axis(shape.size(i), elementSize);
31+
dimensions[i] = new Axis(shape.get(i), elementSize);
3232
elementSize *= dimensions[i].numElements();
3333
}
3434
return new DimensionalSpace(dimensions, shape);

ndarray/src/test/java/org/tensorflow/ndarray/NdArrayTestBase.java

+3-3
Original file line numberDiff line numberDiff line change
@@ -132,12 +132,12 @@ public void iterateElements() {
132132
long value = 0L;
133133
for (NdArray<T> matrix : matrix3d.elements(0)) {
134134
assertEquals(2L, matrix.shape().numDimensions());
135-
assertEquals(4L, matrix.shape().size(0));
136-
assertEquals(5L, matrix.shape().size(1));
135+
assertEquals(4L, matrix.shape().get(0));
136+
assertEquals(5L, matrix.shape().get(1));
137137

138138
for (NdArray<T> vector : matrix.elements(0)) {
139139
assertEquals(1L, vector.shape().numDimensions()) ;
140-
assertEquals(5L, vector.shape().size(0));
140+
assertEquals(5L, vector.shape().get(0));
141141

142142
for (NdArray<T> scalar : vector.scalars()) {
143143
assertEquals(0L, scalar.shape().numDimensions()) ;

ndarray/src/test/java/org/tensorflow/ndarray/ShapeTest.java

+9-9
Original file line numberDiff line numberDiff line change
@@ -26,22 +26,22 @@ public class ShapeTest {
2626
public void allKnownDimensions() {
2727
Shape shape = Shape.of(5, 4, 5);
2828
assertEquals(3, shape.numDimensions());
29-
assertEquals(5, shape.size(0));
30-
assertEquals(4, shape.size(1));
31-
assertEquals(5, shape.size(2));
29+
assertEquals(5, shape.get(0));
30+
assertEquals(4, shape.get(1));
31+
assertEquals(5, shape.get(2));
3232
assertEquals(100, shape.size());
3333
assertArrayEquals(new long[] {5, 4, 5}, shape.asArray());
3434
try {
35-
shape.size(3);
35+
shape.get(3);
3636
fail();
3737
} catch (IndexOutOfBoundsException e) {
3838
// as expected
3939
}
40-
assertEquals(5, shape.size(-1));
41-
assertEquals(4, shape.size(-2));
42-
assertEquals(5, shape.size(-3));
40+
assertEquals(5, shape.get(-1));
41+
assertEquals(4, shape.get(-2));
42+
assertEquals(5, shape.get(-3));
4343
try {
44-
shape.size(-4);
44+
shape.get(-4);
4545
fail();
4646
} catch (IndexOutOfBoundsException e) {
4747
// as expected
@@ -133,7 +133,7 @@ public void testShapeModification() {
133133
long[] internalShape = one.asArray();
134134
assertNotNull(internalShape);
135135
internalShape[0] = 42L;
136-
assertEquals(2L, one.size(0));
136+
assertEquals(2L, one.get(0));
137137
}
138138

139139
@Test

tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/Signature.java

+1-1
Original file line numberDiff line numberDiff line change
@@ -121,7 +121,7 @@ private static TensorInfo toTensorInfo(Output<?> operand) {
121121
Shape shape = operand.shape();
122122
TensorShapeProto.Builder tensorShapeBuilder = TensorShapeProto.newBuilder();
123123
for (int i = 0; i < shape.numDimensions(); ++i) {
124-
tensorShapeBuilder.addDim(Dim.newBuilder().setSize(shape.size(i)));
124+
tensorShapeBuilder.addDim(Dim.newBuilder().setSize(shape.get(i)));
125125
}
126126
return TensorInfo.newBuilder()
127127
.setDtype(operand.dataType())

tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/op/nn/SigmoidCrossEntropyWithLogits.java

+2-2
Original file line numberDiff line numberDiff line change
@@ -95,8 +95,8 @@ public static <T extends TNumber> Operand<T> sigmoidCrossEntropyWithLogits(
9595
private static boolean isCompatible(Shape shape, Shape other) {
9696
if (shape.numDimensions() != other.numDimensions()) return false;
9797
for (int i = 0; i < shape.numDimensions(); i++) {
98-
long aShapeDim = shape.size(i);
99-
long bShapeDim = other.size(i);
98+
long aShapeDim = shape.get(i);
99+
long bShapeDim = other.get(i);
100100
if (aShapeDim == bShapeDim
101101
|| (aShapeDim == Shape.UNKNOWN_SIZE || bShapeDim == Shape.UNKNOWN_SIZE)) {
102102
continue;

tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/op/nn/SoftmaxCrossEntropyWithLogits.java

+4-5
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,6 @@
1414
import org.tensorflow.types.TFloat32;
1515
import org.tensorflow.types.TInt64;
1616
import org.tensorflow.types.family.TNumber;
17-
import org.tensorflow.types.family.TType;
1817

1918
import java.util.Arrays;
2019
import java.util.List;
@@ -124,10 +123,10 @@ public static <T extends TNumber, U extends TNumber> Operand<T> softmaxCrossEntr
124123
axis = shape.numDimensions() + axis;
125124
}
126125
for (int i = 0; i < axis; i++) {
127-
newArray[i] = shape.size(i);
126+
newArray[i] = shape.get(i);
128127
}
129128
for (int i = axis + 1; i < shape.numDimensions(); i++) {
130-
newArray[i - 1] = shape.size(i);
129+
newArray[i - 1] = shape.get(i);
131130
}
132131
cost = Reshape.create(scope, cost, Constant.vectorOf(scope, newArray));
133132
}
@@ -152,15 +151,15 @@ private static <T extends TNumber> Operand<T> flattenOuterDims(Scope scope, Oper
152151
long product = 1L;
153152
boolean productValid = true;
154153
for (int i = ndims - 2; i >= 0; i--) {
155-
long d = shape.size(i);
154+
long d = shape.get(i);
156155
if (d == org.tensorflow.ndarray.Shape.UNKNOWN_SIZE) {
157156
productValid = false;
158157
break;
159158
}
160159
product *= d;
161160
}
162161
if (productValid) {
163-
return Reshape.create(scope, logits, Constant.arrayOf(scope, product, shape.size(-1)));
162+
return Reshape.create(scope, logits, Constant.arrayOf(scope, product, shape.get(-1)));
164163
}
165164
}
166165

tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/op/nn/SparseSoftmaxCrossEntropyWithLogits.java

+1-1
Original file line numberDiff line numberDiff line change
@@ -140,7 +140,7 @@ public static <T extends TNumber, U extends TNumber> Operand sparseSoftmaxCrossE
140140
}
141141

142142
// Reshape logits to 2 dims, labels to 1 dim.
143-
long numClassses = logitsShape.size(-1);
143+
long numClassses = logitsShape.get(-1);
144144

145145
preciseLogits = Reshape.create(scope, preciseLogits, Constant.arrayOf(scope, -1L, numClassses));
146146
labels = Reshape.create(scope, labels, Constant.scalarOf(scope, -1));

tensorflow-core/tensorflow-core-api/src/test/java/org/tensorflow/EagerOperationTest.java

+2-2
Original file line numberDiff line numberDiff line change
@@ -57,8 +57,8 @@ public void outputDataTypeAndShape() {
5757
.setAttr("value", t)
5858
.build();
5959
assertEquals(DataType.DT_INT32, op.dtype(0));
60-
assertEquals(2, op.shape(0).size(0));
61-
assertEquals(3, op.shape(0).size(1));
60+
assertEquals(2, op.shape(0).get(0));
61+
assertEquals(3, op.shape(0).get(1));
6262
}
6363
}
6464

tensorflow-core/tensorflow-core-api/src/test/java/org/tensorflow/GraphOperationBuilderTest.java

+2-2
Original file line numberDiff line numberDiff line change
@@ -129,8 +129,8 @@ public void setAttrShape() {
129129
.build()
130130
.output(0);
131131
assertEquals(2, n.shape().numDimensions());
132-
assertEquals(-1, n.shape().size(0));
133-
assertEquals(784, n.shape().size(1));
132+
assertEquals(-1, n.shape().get(0));
133+
assertEquals(784, n.shape().get(1));
134134
assertEquals(DataType.DT_FLOAT, n.dataType());
135135
}
136136
}

tensorflow-core/tensorflow-core-api/src/test/java/org/tensorflow/SavedModelBundleTest.java

+1-1
Original file line numberDiff line numberDiff line change
@@ -146,7 +146,7 @@ public void exportFunctionWithVariables() throws IOException {
146146
assertNotNull(inputInfo);
147147
assertEquals(xyShape.numDimensions(), inputInfo.getTensorShape().getDimCount());
148148
for (int i = 0; i < xyShape.numDimensions(); ++i) {
149-
assertEquals(xyShape.size(i), inputInfo.getTensorShape().getDim(i).getSize());
149+
assertEquals(xyShape.get(i), inputInfo.getTensorShape().getDim(i).getSize());
150150
}
151151

152152
TensorInfo outputInfo = signatureDef.getOutputsMap().get("reducedSum");

tensorflow-core/tensorflow-core-api/src/test/java/org/tensorflow/TensorTest.java

+16-16
Original file line numberDiff line numberDiff line change
@@ -325,7 +325,7 @@ public void nDimensional() {
325325
assertEquals(TFloat64.class, t.type());
326326
assertEquals(DataType.DT_DOUBLE, t.dataType());
327327
assertEquals(1, t.shape().numDimensions());
328-
assertEquals(3, t.shape().size(0));
328+
assertEquals(3, t.shape().get(0));
329329
assertEquals(vector, t);
330330
}
331331

@@ -334,8 +334,8 @@ public void nDimensional() {
334334
assertEquals(TInt32.class, t.type());
335335
assertEquals(DataType.DT_INT32, t.dataType());
336336
assertEquals(2, t.shape().numDimensions());
337-
assertEquals(2, t.shape().size(0));
338-
assertEquals(3, t.shape().size(1));
337+
assertEquals(2, t.shape().get(0));
338+
assertEquals(3, t.shape().get(1));
339339
assertEquals(matrix, t);
340340
}
341341

@@ -346,9 +346,9 @@ public void nDimensional() {
346346
assertEquals(TInt64.class, t.type());
347347
assertEquals(DataType.DT_INT64, t.dataType());
348348
assertEquals(3, t.shape().numDimensions());
349-
assertEquals(2, t.shape().size(0));
350-
assertEquals(5, t.shape().size(1));
351-
assertEquals(1, t.shape().size(2));
349+
assertEquals(2, t.shape().get(0));
350+
assertEquals(5, t.shape().get(1));
351+
assertEquals(1, t.shape().get(2));
352352
assertEquals(threeD, t);
353353
}
354354

@@ -361,10 +361,10 @@ public void nDimensional() {
361361
assertEquals(TBool.class, t.type());
362362
assertEquals(DataType.DT_BOOL, t.dataType());
363363
assertEquals(4, t.shape().numDimensions());
364-
assertEquals(3, t.shape().size(0));
365-
assertEquals(1, t.shape().size(1));
366-
assertEquals(2, t.shape().size(2));
367-
assertEquals(4, t.shape().size(3));
364+
assertEquals(3, t.shape().get(0));
365+
assertEquals(1, t.shape().get(1));
366+
assertEquals(2, t.shape().get(2));
367+
assertEquals(4, t.shape().get(3));
368368
assertEquals(fourD, t);
369369
}
370370
}
@@ -381,8 +381,8 @@ public void testNDimensionalStringTensor() {
381381
assertEquals(TString.class, t.type());
382382
assertEquals(DataType.DT_STRING, t.dataType());
383383
assertEquals(2, t.shape().numDimensions());
384-
assertEquals(4, t.shape().size(0));
385-
assertEquals(3, t.shape().size(1));
384+
assertEquals(4, t.shape().get(0));
385+
assertEquals(3, t.shape().get(1));
386386
assertEquals(matrix, t);
387387
}
388388

@@ -392,8 +392,8 @@ public void testNDimensionalStringTensor() {
392392
assertEquals(TString.class, t.type());
393393
assertEquals(DataType.DT_STRING, t.dataType());
394394
assertEquals(2, t.shape().numDimensions());
395-
assertEquals(4, t.shape().size(0));
396-
assertEquals(3, t.shape().size(1));
395+
assertEquals(4, t.shape().get(0));
396+
assertEquals(3, t.shape().get(1));
397397
assertEquals(byteMatrix, t.asBytes());
398398
assertEquals(matrix, t);
399399
}
@@ -406,7 +406,7 @@ public void testUint8TensorFromArray() {
406406
assertEquals(TUint8.class, t.type());
407407
assertEquals(DataType.DT_UINT8, t.dataType());
408408
assertEquals(1, t.shape().numDimensions());
409-
assertEquals(4, t.shape().size(0));
409+
assertEquals(4, t.shape().get(0));
410410

411411
byte[] got = new byte[4];
412412
t.read(DataBuffers.of(got));
@@ -421,7 +421,7 @@ public void testCreateFromArrayOfBoxed() {
421421
assertEquals(TInt32.class, t.type());
422422
assertEquals(DataType.DT_INT32, t.dataType());
423423
assertEquals(1, t.shape().numDimensions());
424-
assertEquals(4, t.shape().size(0));
424+
assertEquals(4, t.shape().get(0));
425425

426426
Integer[] got = new Integer[4];
427427
t.read(DataBuffers.ofObjects(got));

tensorflow-framework/src/main/java/org/tensorflow/framework/initializers/Identity.java

+4-4
Original file line numberDiff line numberDiff line change
@@ -70,8 +70,8 @@ public Operand<T> call(Operand<TInt64> dims, Class<T> type) {
7070
if (shape.numDimensions() != 2) {
7171
throw new IllegalArgumentException("2D matrix required, got " + shape.numDimensions());
7272
}
73-
boolean isSquare = shape.size(0) == shape.size(1);
74-
long diagSize = Math.min(shape.size(0), shape.size(1));
73+
boolean isSquare = shape.get(0) == shape.get(1);
74+
long diagSize = Math.min(shape.get(0), shape.get(1));
7575
Shape diagShape = Shape.of(diagSize);
7676

7777
Operand<T> op;
@@ -83,8 +83,8 @@ public Operand<T> call(Operand<TInt64> dims, Class<T> type) {
8383
tf.linalg.matrixDiag(
8484
diagOnes,
8585
tf.constant(0), // don't cast here, expecting TInt32
86-
tf.constant((int) shape.size(0)),
87-
tf.constant((int) shape.size(1)),
86+
tf.constant((int) shape.get(0)),
87+
tf.constant((int) shape.get(1)),
8888
zero);
8989
} else {
9090
Operand<T> zeroMatrix = tf.zeros(dims, type);

tensorflow-framework/src/main/java/org/tensorflow/framework/initializers/Orthogonal.java

+2-2
Original file line numberDiff line numberDiff line change
@@ -90,8 +90,8 @@ public Operand<T> call(Operand<TInt64> dims, Class<T> type) {
9090
}
9191
long numRows = 1;
9292
int i = 0;
93-
for (; i < dimsShape.numDimensions() - 1; i++) numRows *= dimsShape.size(i);
94-
long numCols = dimsShape.size(i);
93+
for (; i < dimsShape.numDimensions() - 1; i++) numRows *= dimsShape.get(i);
94+
long numCols = dimsShape.get(i);
9595
Shape flatShape = Shape.of(Math.max(numRows, numCols), Math.min(numRows, numCols));
9696
long[] seeds = {seed, 0};
9797
Operand<T> op =

tensorflow-framework/src/main/java/org/tensorflow/framework/losses/Losses.java

+2-2
Original file line numberDiff line numberDiff line change
@@ -566,7 +566,7 @@ public static <T extends TNumber> Operand<T> sparseCategoricalCrossentropy(
566566
tf.reshape(
567567
predictions,
568568
tf.constant(
569-
new long[] {-1L, predictionsShape.size(predictionsShape.numDimensions() - 1)}));
569+
new long[] {-1L, predictionsShape.get(predictionsShape.numDimensions() - 1)}));
570570
}
571571

572572
@SuppressWarnings("unchecked")
@@ -643,7 +643,7 @@ private static <T extends TNumber> Operand<T> smoothCategoricalLabels(
643643
Operand<T> smoothing = cast(tf, tf.constant(labelSmoothing), labelType);
644644
Shape labelsShape = labels.shape();
645645
int numDims = labelsShape.numDimensions();
646-
Operand<T> numClasses = cast(tf, tf.constant(labelsShape.size(numDims - 1)), labelType);
646+
Operand<T> numClasses = cast(tf, tf.constant(labelsShape.get(numDims - 1)), labelType);
647647
Operand<T> oneMinusSmoothing = cast(tf, tf.constant(1.f - labelSmoothing), labelType);
648648
return tf.math.add(tf.math.mul(labels, oneMinusSmoothing), tf.math.div(smoothing, numClasses));
649649
}

0 commit comments

Comments
 (0)