diff --git a/ndarray/src/main/java/module-info.java b/ndarray/src/main/java/module-info.java index 6b33ed6..10d0b84 100644 --- a/ndarray/src/main/java/module-info.java +++ b/ndarray/src/main/java/module-info.java @@ -21,6 +21,7 @@ exports org.tensorflow.ndarray.buffer; exports org.tensorflow.ndarray.buffer.layout; exports org.tensorflow.ndarray.index; + exports org.tensorflow.ndarray.initializer; // Expose all implementions of our interfaces, so consumers can write custom // implementations easily by extending from them diff --git a/ndarray/src/main/java/org/tensorflow/ndarray/NdArraySequence.java b/ndarray/src/main/java/org/tensorflow/ndarray/NdArraySequence.java index afb930e..b97639a 100644 --- a/ndarray/src/main/java/org/tensorflow/ndarray/NdArraySequence.java +++ b/ndarray/src/main/java/org/tensorflow/ndarray/NdArraySequence.java @@ -33,12 +33,12 @@ public interface NdArraySequence> extends Iterable { /** - * Visit each elements of this iteration and their respective coordinates. + * Visit each element of this iteration and their respective coordinates. * *

Important: the consumer method should not keep a reference to the coordinates * as they might be mutable and reused during the iteration to improve performance. * - * @param consumer method to invoke for each elements + * @param consumer method to invoke for each element */ void forEachIndexed(BiConsumer consumer); @@ -60,7 +60,7 @@ public interface NdArraySequence> extends Iterable { * ndArray.elements(0).asSlices().forEach(e -> vectors::add); // Safe, each `e` is a distinct NdArray instance * } * - * @return a sequence that returns each elements iterated as a new slice + * @return a sequence that returns each element iterated as a new slice * @see DataBufferWindow */ NdArraySequence asSlices(); diff --git a/ndarray/src/main/java/org/tensorflow/ndarray/NdArrays.java b/ndarray/src/main/java/org/tensorflow/ndarray/NdArrays.java index d79b781..aa4e736 100644 --- a/ndarray/src/main/java/org/tensorflow/ndarray/NdArrays.java +++ b/ndarray/src/main/java/org/tensorflow/ndarray/NdArrays.java @@ -33,7 +33,10 @@ import org.tensorflow.ndarray.impl.dense.IntDenseNdArray; import org.tensorflow.ndarray.impl.dense.LongDenseNdArray; import org.tensorflow.ndarray.impl.dense.ShortDenseNdArray; +import org.tensorflow.ndarray.impl.dense.initializer.DenseNdArrayInitializer; +import org.tensorflow.ndarray.impl.dense.initializer.DoubleDenseNdArrayInitializer; import org.tensorflow.ndarray.impl.dimension.DimensionalSpace; +import org.tensorflow.ndarray.impl.sparse.AbstractSparseNdArray; import org.tensorflow.ndarray.impl.sparse.BooleanSparseNdArray; import org.tensorflow.ndarray.impl.sparse.ByteSparseNdArray; import org.tensorflow.ndarray.impl.sparse.DoubleSparseNdArray; @@ -41,6 +44,12 @@ import org.tensorflow.ndarray.impl.sparse.IntSparseNdArray; import org.tensorflow.ndarray.impl.sparse.LongSparseNdArray; import org.tensorflow.ndarray.impl.sparse.ShortSparseNdArray; +import org.tensorflow.ndarray.impl.sparse.initializer.DoubleSparseNdArrayInitializer; +import org.tensorflow.ndarray.impl.sparse.initializer.SparseNdArrayInitializer; +import org.tensorflow.ndarray.initializer.DoubleNdArrayInitializer; +import org.tensorflow.ndarray.initializer.NdArrayInitializer; + +import java.util.function.Consumer; /** Utility class for instantiating {@link NdArray} objects. */ public final class NdArrays { @@ -555,6 +564,20 @@ public static DoubleNdArray ofDoubles(Shape shape) { return wrap(shape, DataBuffers.ofDoubles(shape.size())); } + /** + * Creates an N-dimensional array of doubles of the given shape, initializing its data after allocation. + * + * @param shape shape of the array + * @param init invoked to initialize the data of the allocated array + * @return new double N-dimensional array + * @throws IllegalArgumentException if shape is null or has unknown dimensions + */ + public static DoubleNdArray ofDoubles(Shape shape, Consumer init) { + DoubleDenseNdArray array = (DoubleDenseNdArray)ofDoubles(shape); + init.accept(new DoubleDenseNdArrayInitializer(array)); + return array; + } + /** * Wraps a buffer in a double N-dimensional array of a given shape. * @@ -568,6 +591,23 @@ public static DoubleNdArray wrap(Shape shape, DoubleDataBuffer buffer) { return DoubleDenseNdArray.create(buffer, shape); } + /** + * Creates an Sparse array of doubles of the given shape, hydrating it with data after its allocation + * + * @param shape shape of the array + * @param numValues number of double value actually set in the array, others defaulting to the zero value + * @param hydrate initialize the data of the created array, using a hydrator + * @return new double N-dimensional array + * @throws IllegalArgumentException if shape is null or has unknown dimensions + */ + public static DoubleSparseNdArray sparseOfDoubles(Shape shape, long numValues, Consumer hydrate) { + LongNdArray indices = ofLongs(Shape.of(numValues, shape.numDimensions())); + DoubleNdArray values = ofDoubles(Shape.of(numValues)); + DoubleSparseNdArray array = DoubleSparseNdArray.create(indices, values, DimensionalSpace.create(shape)); + hydrate.accept(new DoubleSparseNdArrayInitializer(array)); + return array; + } + /** * Creates a Sparse array of double values with a default value of zero * @@ -756,6 +796,22 @@ public static NdArray ofObjects(Class clazz, Shape shape) { return wrap(shape, DataBuffers.ofObjects(clazz, shape.size())); } + /** + * Creates an N-dimensional array of objects of the given shape, hydrating it with data after its allocation + * + * @param clazz class of the data to be stored in this array + * @param shape shape of the array + * @param hydrate initialize the data of the created array, using a hydrator + * @param type of object to store in this array + * @return new N-dimensional array + * @throws IllegalArgumentException if shape is null or has unknown dimensions + */ + public static NdArray ofObjects(Class clazz, Shape shape, Consumer> hydrate) { + var array = (DenseNdArray)ofObjects(clazz, shape); + hydrate.accept(new DenseNdArrayInitializer<>(array)); + return array; + } + /** * Wraps a buffer in an N-dimensional array of a given shape. * @@ -770,6 +826,25 @@ public static NdArray wrap(Shape shape, DataBuffer buffer) { return DenseNdArray.wrap(buffer, shape); } + /** + * Creates a Sparse array of objects of the given shape, hydrating it with data after its allocation + * + * @param type the class type represented by this sparse array. + * @param shape shape of the array + * @param numValues number of values actually set in the array, others defaulting to the zero value + * @param hydrate initialize the data of the created array, using a hydrator + * @param type of object to store in this array + * @return new N-dimensional array + * @throws IllegalArgumentException if shape is null or has unknown dimensions + */ + public static NdArray sparseOfObjects(Class type, Shape shape, long numValues, Consumer> hydrate) { + LongNdArray indices = ofLongs(Shape.of(numValues, shape.numDimensions())); + NdArray values = ofObjects(type, Shape.of(numValues)); + AbstractSparseNdArray array = (AbstractSparseNdArray)sparseOfObjects(type, indices, values, shape); + hydrate.accept(new SparseNdArrayInitializer<>(array)); + return array; + } + /** * Creates a Sparse array of values with a null default value * @@ -783,6 +858,7 @@ public static NdArray wrap(Shape shape, DataBuffer buffer) { * values=["one", "two"]} specifies that element {@code [1,3,1]} of the sparse NdArray has a * value of "one", and element {@code [2,4,0]} of the NdArray has a value of "two"". All other * values are null. + * @param type of object to store in this array * @param shape the shape of the dense array represented by this sparse array. * @return the float sparse array. */ @@ -807,6 +883,7 @@ public static NdArray sparseOfObjects( * values are null. * @param defaultValue Scalar value to set for indices not specified in 'indices' * @param shape the shape of the dense array represented by this sparse array. + * @param type of object to store in this array * @return the float sparse array. */ public static NdArray sparseOfObjects( diff --git a/ndarray/src/main/java/org/tensorflow/ndarray/impl/Validator.java b/ndarray/src/main/java/org/tensorflow/ndarray/impl/Validator.java index 285d099..11c8453 100644 --- a/ndarray/src/main/java/org/tensorflow/ndarray/impl/Validator.java +++ b/ndarray/src/main/java/org/tensorflow/ndarray/impl/Validator.java @@ -16,11 +16,12 @@ */ package org.tensorflow.ndarray.impl; -import java.nio.BufferOverflowException; -import java.nio.BufferUnderflowException; import org.tensorflow.ndarray.NdArray; import org.tensorflow.ndarray.buffer.DataBuffer; +import java.nio.BufferOverflowException; +import java.nio.BufferUnderflowException; + public class Validator { public static void copyToNdArrayArgs(NdArray ndArray, NdArray otherNdArray) { @@ -42,14 +43,5 @@ public static void writeFromBufferArgs(NdArray ndArray, DataBuffer src) { } } - private static void copyArrayArgs(int arrayLength, int arrayOffset) { - if (arrayOffset < 0) { - throw new IndexOutOfBoundsException("Offset must be non-negative"); - } - if (arrayOffset > arrayLength) { - throw new IndexOutOfBoundsException("Offset must be no larger than array length"); - } - } - protected Validator() {} } diff --git a/ndarray/src/main/java/org/tensorflow/ndarray/impl/dense/AbstractDenseNdArray.java b/ndarray/src/main/java/org/tensorflow/ndarray/impl/dense/AbstractDenseNdArray.java index 30af952..9ef3749 100644 --- a/ndarray/src/main/java/org/tensorflow/ndarray/impl/dense/AbstractDenseNdArray.java +++ b/ndarray/src/main/java/org/tensorflow/ndarray/impl/dense/AbstractDenseNdArray.java @@ -31,6 +31,20 @@ @SuppressWarnings("unchecked") public abstract class AbstractDenseNdArray> extends AbstractNdArray { + abstract public DataBuffer buffer(); + + public NdArraySequence elementsAt(long[] startCoords) { + DimensionalSpace elemDims = dimensions().from(startCoords.length); + try { + DataBufferWindow> elemWindow = buffer().window(elemDims.physicalSize()); + U element = instantiate(elemWindow.buffer(), elemDims); + return new FastElementSequence(this, startCoords, element, elemWindow); + } catch (UnsupportedOperationException e) { + // If buffer windows are not supported, fallback to slicing (and slower) sequence + return new SlicingElementSequence(this, startCoords, elemDims); + } + } + @Override public NdArraySequence elements(int dimensionIdx) { if (dimensionIdx >= shape().numDimensions()) { @@ -40,15 +54,7 @@ public NdArraySequence elements(int dimensionIdx) { if (rank() == 0 && dimensionIdx < 0) { return new SingleElementSequence<>(this); } - DimensionalSpace elemDims = dimensions().from(dimensionIdx + 1); - try { - DataBufferWindow> elemWindow = buffer().window(elemDims.physicalSize()); - U element = instantiate(elemWindow.buffer(), elemDims); - return new FastElementSequence(this, dimensionIdx, element, elemWindow); - } catch (UnsupportedOperationException e) { - // If buffer windows are not supported, fallback to slicing (and slower) sequence - return new SlicingElementSequence<>(this, dimensionIdx, elemDims); - } + return elementsAt(new long[dimensionIdx + 1]); } @Override @@ -145,8 +151,6 @@ protected AbstractDenseNdArray(DimensionalSpace dimensions) { super(dimensions); } - abstract protected DataBuffer buffer(); - abstract U instantiate(DataBuffer buffer, DimensionalSpace dimensions); long positionOf(long[] coords, boolean isValue) { diff --git a/ndarray/src/main/java/org/tensorflow/ndarray/impl/dense/BooleanDenseNdArray.java b/ndarray/src/main/java/org/tensorflow/ndarray/impl/dense/BooleanDenseNdArray.java index 0764146..a3caf40 100644 --- a/ndarray/src/main/java/org/tensorflow/ndarray/impl/dense/BooleanDenseNdArray.java +++ b/ndarray/src/main/java/org/tensorflow/ndarray/impl/dense/BooleanDenseNdArray.java @@ -31,6 +31,11 @@ public static BooleanNdArray create(BooleanDataBuffer buffer, Shape shape) { return new BooleanDenseNdArray(buffer, shape); } + @Override + public BooleanDataBuffer buffer() { + return buffer; + } + @Override public boolean getBoolean(long... indices) { return buffer.getBoolean(positionOf(indices, true)); @@ -77,11 +82,6 @@ BooleanDenseNdArray instantiate(DataBuffer buffer, DimensionalSpace dim return new BooleanDenseNdArray((BooleanDataBuffer)buffer, dimensions); } - @Override - protected BooleanDataBuffer buffer() { - return buffer; - } - private final BooleanDataBuffer buffer; private BooleanDenseNdArray(BooleanDataBuffer buffer, DimensionalSpace dimensions) { diff --git a/ndarray/src/main/java/org/tensorflow/ndarray/impl/dense/ByteDenseNdArray.java b/ndarray/src/main/java/org/tensorflow/ndarray/impl/dense/ByteDenseNdArray.java index 172432b..fa3b722 100644 --- a/ndarray/src/main/java/org/tensorflow/ndarray/impl/dense/ByteDenseNdArray.java +++ b/ndarray/src/main/java/org/tensorflow/ndarray/impl/dense/ByteDenseNdArray.java @@ -31,6 +31,11 @@ public static ByteNdArray create(ByteDataBuffer buffer, Shape shape) { return new ByteDenseNdArray(buffer, shape); } + @Override + public ByteDataBuffer buffer() { + return buffer; + } + @Override public byte getByte(long... indices) { return buffer.getByte(positionOf(indices, true)); @@ -77,11 +82,6 @@ ByteDenseNdArray instantiate(DataBuffer buffer, DimensionalSpace dimension return new ByteDenseNdArray((ByteDataBuffer)buffer, dimensions); } - @Override - protected ByteDataBuffer buffer() { - return buffer; - } - private final ByteDataBuffer buffer; private ByteDenseNdArray(ByteDataBuffer buffer, DimensionalSpace dimensions) { diff --git a/ndarray/src/main/java/org/tensorflow/ndarray/impl/dense/DenseNdArray.java b/ndarray/src/main/java/org/tensorflow/ndarray/impl/dense/DenseNdArray.java index 819d95d..54d337b 100644 --- a/ndarray/src/main/java/org/tensorflow/ndarray/impl/dense/DenseNdArray.java +++ b/ndarray/src/main/java/org/tensorflow/ndarray/impl/dense/DenseNdArray.java @@ -50,7 +50,7 @@ DenseNdArray instantiate(DataBuffer buffer, DimensionalSpace dimensions) { } @Override - protected DataBuffer buffer() { + public DataBuffer buffer() { return buffer; } diff --git a/ndarray/src/main/java/org/tensorflow/ndarray/impl/dense/DoubleDenseNdArray.java b/ndarray/src/main/java/org/tensorflow/ndarray/impl/dense/DoubleDenseNdArray.java index f54b8d0..d30c350 100644 --- a/ndarray/src/main/java/org/tensorflow/ndarray/impl/dense/DoubleDenseNdArray.java +++ b/ndarray/src/main/java/org/tensorflow/ndarray/impl/dense/DoubleDenseNdArray.java @@ -31,6 +31,11 @@ public static DoubleNdArray create(DoubleDataBuffer buffer, Shape shape) { return new DoubleDenseNdArray(buffer, shape); } + @Override + public DoubleDataBuffer buffer() { + return buffer; + } + @Override public double getDouble(long... indices) { return buffer.getDouble(positionOf(indices, true)); @@ -77,11 +82,6 @@ DoubleDenseNdArray instantiate(DataBuffer buffer, DimensionalSpace dimen return new DoubleDenseNdArray((DoubleDataBuffer)buffer, dimensions); } - @Override - protected DoubleDataBuffer buffer() { - return buffer; - } - private final DoubleDataBuffer buffer; private DoubleDenseNdArray(DoubleDataBuffer buffer, DimensionalSpace dimensions) { diff --git a/ndarray/src/main/java/org/tensorflow/ndarray/impl/dense/FloatDenseNdArray.java b/ndarray/src/main/java/org/tensorflow/ndarray/impl/dense/FloatDenseNdArray.java index 196b5ef..b164211 100644 --- a/ndarray/src/main/java/org/tensorflow/ndarray/impl/dense/FloatDenseNdArray.java +++ b/ndarray/src/main/java/org/tensorflow/ndarray/impl/dense/FloatDenseNdArray.java @@ -31,6 +31,11 @@ public static FloatNdArray create(FloatDataBuffer buffer, Shape shape) { return new FloatDenseNdArray(buffer, shape); } + @Override + public FloatDataBuffer buffer() { + return buffer; + } + @Override public float getFloat(long... indices) { return buffer.getFloat(positionOf(indices, true)); @@ -77,11 +82,6 @@ FloatDenseNdArray instantiate(DataBuffer buffer, DimensionalSpace dimensi return new FloatDenseNdArray((FloatDataBuffer) buffer, dimensions); } - @Override - public FloatDataBuffer buffer() { - return buffer; - } - private final FloatDataBuffer buffer; private FloatDenseNdArray(FloatDataBuffer buffer, DimensionalSpace dimensions) { diff --git a/ndarray/src/main/java/org/tensorflow/ndarray/impl/dense/IntDenseNdArray.java b/ndarray/src/main/java/org/tensorflow/ndarray/impl/dense/IntDenseNdArray.java index a7af498..3cbd15e 100644 --- a/ndarray/src/main/java/org/tensorflow/ndarray/impl/dense/IntDenseNdArray.java +++ b/ndarray/src/main/java/org/tensorflow/ndarray/impl/dense/IntDenseNdArray.java @@ -31,6 +31,11 @@ public static IntNdArray create(IntDataBuffer buffer, Shape shape) { return new IntDenseNdArray(buffer, shape); } + @Override + public IntDataBuffer buffer() { + return buffer; + } + @Override public int getInt(long... indices) { return buffer.getInt(positionOf(indices, true)); @@ -77,11 +82,6 @@ IntDenseNdArray instantiate(DataBuffer buffer, DimensionalSpace dimensi return new IntDenseNdArray((IntDataBuffer)buffer, dimensions); } - @Override - protected IntDataBuffer buffer() { - return buffer; - } - private final IntDataBuffer buffer; private IntDenseNdArray(IntDataBuffer buffer, DimensionalSpace dimensions) { diff --git a/ndarray/src/main/java/org/tensorflow/ndarray/impl/dense/LongDenseNdArray.java b/ndarray/src/main/java/org/tensorflow/ndarray/impl/dense/LongDenseNdArray.java index cd56dad..8c33528 100644 --- a/ndarray/src/main/java/org/tensorflow/ndarray/impl/dense/LongDenseNdArray.java +++ b/ndarray/src/main/java/org/tensorflow/ndarray/impl/dense/LongDenseNdArray.java @@ -31,6 +31,11 @@ public static LongNdArray create(LongDataBuffer buffer, Shape shape) { return new LongDenseNdArray(buffer, shape); } + @Override + public LongDataBuffer buffer() { + return buffer; + } + @Override public long getLong(long... indices) { return buffer.getLong(positionOf(indices, true)); @@ -77,11 +82,6 @@ LongDenseNdArray instantiate(DataBuffer buffer, DimensionalSpace dimension return new LongDenseNdArray((LongDataBuffer)buffer, dimensions); } - @Override - protected LongDataBuffer buffer() { - return buffer; - } - private final LongDataBuffer buffer; private LongDenseNdArray(LongDataBuffer buffer, DimensionalSpace dimensions) { diff --git a/ndarray/src/main/java/org/tensorflow/ndarray/impl/dense/ShortDenseNdArray.java b/ndarray/src/main/java/org/tensorflow/ndarray/impl/dense/ShortDenseNdArray.java index 291c01a..a44a81a 100644 --- a/ndarray/src/main/java/org/tensorflow/ndarray/impl/dense/ShortDenseNdArray.java +++ b/ndarray/src/main/java/org/tensorflow/ndarray/impl/dense/ShortDenseNdArray.java @@ -31,6 +31,11 @@ public static ShortNdArray create(ShortDataBuffer buffer, Shape shape) { return new ShortDenseNdArray(buffer, shape); } + @Override + public ShortDataBuffer buffer() { + return buffer; + } + @Override public short getShort(long... indices) { return buffer.getShort(positionOf(indices, true)); @@ -77,11 +82,6 @@ ShortDenseNdArray instantiate(DataBuffer buffer, DimensionalSpace dimensi return new ShortDenseNdArray((ShortDataBuffer)buffer, dimensions); } - @Override - protected ShortDataBuffer buffer() { - return buffer; - } - private final ShortDataBuffer buffer; private ShortDenseNdArray(ShortDataBuffer buffer, DimensionalSpace dimensions) { diff --git a/ndarray/src/main/java/org/tensorflow/ndarray/impl/dense/initializer/BaseDenseNdArrayInitializer.java b/ndarray/src/main/java/org/tensorflow/ndarray/impl/dense/initializer/BaseDenseNdArrayInitializer.java new file mode 100644 index 0000000..3cde108 --- /dev/null +++ b/ndarray/src/main/java/org/tensorflow/ndarray/impl/dense/initializer/BaseDenseNdArrayInitializer.java @@ -0,0 +1,134 @@ +/* + Copyright 2022 The TensorFlow Authors. All Rights Reserved. + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. + ======================================================================= + */ +package org.tensorflow.ndarray.impl.dense.initializer; + +import org.tensorflow.ndarray.NdArray; +import org.tensorflow.ndarray.impl.dense.AbstractDenseNdArray; +import org.tensorflow.ndarray.impl.initializer.AbstractNdArrayInitializer; +import org.tensorflow.ndarray.impl.sequence.PositionIterator; +import org.tensorflow.ndarray.initializer.BaseNdArrayInitializer; + +import java.util.Collection; +import java.util.Iterator; + +abstract class BaseDenseNdArrayInitializer, V extends AbstractDenseNdArray> extends AbstractNdArrayInitializer implements BaseNdArrayInitializer { + + @Override + public Scalars byScalars(long... coordinates) { + return new ScalarsImpl(coordinates); + } + + @Override + public Elements byElements(int dimensionIdx, long... coordinates) { + return new ElementsImpl(dimensionIdx, coordinates); + } + + /** + * Per-scalar initializer for dense arrays + */ + class ScalarsImpl implements Scalars { + + public Scalars to(long... coordinates) { + jumpTo(coordinates); + positionIterator = PositionIterator.create(array.dimensions(), coords); + return this; + } + + @Override + public Scalars put(T value) { + array.buffer().setObject(value, positionIterator.nextLong()); + next(); + return this; + } + + ScalarsImpl(long[] coordinates) { + resetTo(validateRankCoords(0, coordinates)); + positionIterator = PositionIterator.create(array.dimensions(), coords); + } + + protected PositionIterator positionIterator; + } + + /** + * Per-vector initializer for dense arrays + */ + class VectorsImpl implements Vectors { + + @Override + public Vectors to(long... coordinates) { + jumpTo(coordinates); + positionIterator = PositionIterator.create(array.dimensions(), coords); + return this; + } + + @Override + public Vectors put(Collection values) { + validateVectorLength(values.size()); + for (T v : values) { + array.buffer().setObject(v, positionIterator.nextLong()); + } + next(values.size()); + return this; + } + + protected void next(int numValues) { + BaseDenseNdArrayInitializer.this.next(); + // If the number of values is less that the size of a vector, we need to reposition our iterator + if (numValues < array.shape().get(-1)) { + positionIterator = PositionIterator.create(array.dimensions(), coords); + } + } + + VectorsImpl(long[] coordinates) { + resetTo(validateRankCoords(1, coordinates)); + positionIterator = PositionIterator.create(array.dimensions(), coords); + } + + protected PositionIterator positionIterator; + } + + /** + * Per-element initializer for dense arrays. + */ + class ElementsImpl implements Elements { + + @Override + public Elements to(long... coordinates) { + jumpTo(coordinates); + elementIterator = array.elementsAt(coords).iterator(); + return this; + } + + @Override + public Elements put(NdArray values) { + values.copyTo(elementIterator.next()); + next(); + return this; + } + + ElementsImpl(int dimensionIdx, long[] coordinates) { + resetTo(validateDimensionCoords(dimensionIdx, coordinates)); + elementIterator = array.elementsAt(coords).iterator(); + } + + protected Iterator elementIterator; + } + + protected BaseDenseNdArrayInitializer(V array) { + super(array); + } +} diff --git a/ndarray/src/main/java/org/tensorflow/ndarray/impl/dense/initializer/DenseNdArrayInitializer.java b/ndarray/src/main/java/org/tensorflow/ndarray/impl/dense/initializer/DenseNdArrayInitializer.java new file mode 100644 index 0000000..fb00569 --- /dev/null +++ b/ndarray/src/main/java/org/tensorflow/ndarray/impl/dense/initializer/DenseNdArrayInitializer.java @@ -0,0 +1,63 @@ +/* + Copyright 2022 The TensorFlow Authors. All Rights Reserved. + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. + ======================================================================= + */ +package org.tensorflow.ndarray.impl.dense.initializer; + +import org.tensorflow.ndarray.NdArray; +import org.tensorflow.ndarray.impl.dense.DenseNdArray; +import org.tensorflow.ndarray.initializer.NdArrayInitializer; + +import java.util.Collection; + +public class DenseNdArrayInitializer extends BaseDenseNdArrayInitializer, DenseNdArray> implements NdArrayInitializer { + + public DenseNdArrayInitializer(DenseNdArray array) { + super(array); + } + + @Override + public NdArrayInitializer.Vectors byVectors(long... coordinates) { + return new ObjectVectorsImpl(coordinates); + } + + /** + * Per-vector initializer for dense arrays. + */ + class ObjectVectorsImpl extends VectorsImpl implements NdArrayInitializer.Vectors { + + @Override + public NdArrayInitializer.Vectors to(long... coordinates) { + return (ObjectVectorsImpl) super.to(coordinates); + } + + @Override + public NdArrayInitializer.Vectors put(Collection values) { + return (ObjectVectorsImpl) super.put(values); + } + + @Override + public NdArrayInitializer.Vectors put(T... values) { + validateVectorLength(values.length); + array.buffer().offset(positionIterator.nextLong()).write(values); + next(values.length); + return this; + } + + ObjectVectorsImpl(long[] coordinates) { + super(coordinates); + } + } +} diff --git a/ndarray/src/main/java/org/tensorflow/ndarray/impl/dense/initializer/DoubleDenseNdArrayInitializer.java b/ndarray/src/main/java/org/tensorflow/ndarray/impl/dense/initializer/DoubleDenseNdArrayInitializer.java new file mode 100644 index 0000000..440668f --- /dev/null +++ b/ndarray/src/main/java/org/tensorflow/ndarray/impl/dense/initializer/DoubleDenseNdArrayInitializer.java @@ -0,0 +1,95 @@ +/* + Copyright 2022 The TensorFlow Authors. All Rights Reserved. + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. + ======================================================================= + */ +package org.tensorflow.ndarray.impl.dense.initializer; + +import org.tensorflow.ndarray.DoubleNdArray; +import org.tensorflow.ndarray.impl.dense.DoubleDenseNdArray; +import org.tensorflow.ndarray.initializer.DoubleNdArrayInitializer; + +import java.util.Collection; + +public class DoubleDenseNdArrayInitializer extends BaseDenseNdArrayInitializer implements DoubleNdArrayInitializer { + + public DoubleDenseNdArrayInitializer(DoubleDenseNdArray array) { + super(array); + } + + @Override + public DoubleNdArrayInitializer.Scalars byScalars(long... coordinates) { + return new DoubleScalarsImpl(coordinates); + } + + @Override + public DoubleNdArrayInitializer.Vectors byVectors(long... coordinates) { + return new DoubleVectorsImpl(coordinates); + } + + /** + * Per-scalar initializer for dense double arrays. + */ + class DoubleScalarsImpl extends ScalarsImpl implements DoubleNdArrayInitializer.Scalars { + + @Override + public DoubleNdArrayInitializer.Scalars to(long... coordinates) { + return (DoubleScalarsImpl) super.to(coordinates); + } + + @Override + public DoubleNdArrayInitializer.Scalars put(Double value) { + return (DoubleScalarsImpl) super.put(value); + } + + @Override + public DoubleNdArrayInitializer.Scalars put(double value) { + array.buffer().setDouble(value, positionIterator.nextLong()); + next(); + return this; + } + + DoubleScalarsImpl(long[] coordinates) { + super(coordinates); + } + } + + /** + * Per-vector initializer for dense double arrays. + */ + class DoubleVectorsImpl extends VectorsImpl implements DoubleNdArrayInitializer.Vectors { + + @Override + public DoubleNdArrayInitializer.Vectors to(long... coordinates) { + return (DoubleVectorsImpl) super.to(coordinates); + } + + @Override + public DoubleNdArrayInitializer.Vectors put(Collection values) { + return (DoubleVectorsImpl) super.put(values); + } + + @Override + public DoubleNdArrayInitializer.Vectors put(double... values) { + validateVectorLength(values.length); + array.buffer().offset(positionIterator.nextLong()).write(values); + next(values.length); + return this; + } + + DoubleVectorsImpl(long[] coordinates) { + super(coordinates); + } + } +} diff --git a/ndarray/src/main/java/org/tensorflow/ndarray/impl/dimension/DimensionalSpace.java b/ndarray/src/main/java/org/tensorflow/ndarray/impl/dimension/DimensionalSpace.java index 71d1677..7a6a6f6 100644 --- a/ndarray/src/main/java/org/tensorflow/ndarray/impl/dimension/DimensionalSpace.java +++ b/ndarray/src/main/java/org/tensorflow/ndarray/impl/dimension/DimensionalSpace.java @@ -189,6 +189,15 @@ public long positionOf(long[] coords) { return position; } + public boolean increment(long[] coords) { + for (int i = coords.length - 1; i >= 0; --i) { + if ((coords[i] = (coords[i] + 1) % shape.get(i)) > 0) { + return true; + } + } + return false; + } + /** * Succinct description of the shape meant for debugging. */ diff --git a/ndarray/src/main/java/org/tensorflow/ndarray/impl/initializer/AbstractNdArrayInitializer.java b/ndarray/src/main/java/org/tensorflow/ndarray/impl/initializer/AbstractNdArrayInitializer.java new file mode 100644 index 0000000..f4efaab --- /dev/null +++ b/ndarray/src/main/java/org/tensorflow/ndarray/impl/initializer/AbstractNdArrayInitializer.java @@ -0,0 +1,101 @@ +/* + Copyright 2022 The TensorFlow Authors. All Rights Reserved. + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. + ======================================================================= + */ +package org.tensorflow.ndarray.impl.initializer; + +import org.tensorflow.ndarray.impl.AbstractNdArray; +import org.tensorflow.ndarray.impl.dimension.DimensionalSpace; + +import java.util.Arrays; + +public abstract class AbstractNdArrayInitializer> { + + protected static long[] validateNewCoords(long[] actualCoords, long[] newCoords) { + if (actualCoords != null) { + // Make sure we always move forward + boolean smaller = false; + + int i = actualCoords.length; + while (i > newCoords.length) { + // If current coords is of a higher rank, any non-zero value for a dimension missing in the new coordinates + // requires other dimensions to be higher + if (actualCoords[--i] > 0) { + smaller = true; + } + } + while (--i >= 0) { + if (newCoords[i] < actualCoords[i]) { + smaller = true; + } else if (smaller && newCoords[i] > actualCoords[i]) { + smaller = false; + } + } + if (smaller) { + throw new IllegalArgumentException("Cannot move backward during array initialization"); + } + } + return newCoords; + } + + protected long[] validateDimensionCoords(int dimensionIdx, long[] coords) { + if (coords == null || coords.length == 0) { + return new long[dimensionIdx + 1]; + } + if ((coords.length - 1) != dimensionIdx) { + throw new IllegalArgumentException(Arrays.toString(coords) + " are not valid coordinates for dimension " + + dimensionIdx + " in an array of shape " + array.shape()); + } + return Arrays.copyOf(coords, coords.length); + } + + protected long[] validateRankCoords(int elementRank, long[] coords) { + DimensionalSpace dimensions = array.dimensions(); + int dimensionIdx = dimensions.numDimensions() - elementRank - 1; + if (dimensionIdx < 0) { + throw new IllegalArgumentException("Cannot initialize array of shape " + array.shape() + " with elements of rank " + elementRank); + } + return validateDimensionCoords(dimensionIdx, coords); + } + + protected void validateVectorLength(int numValues) { + if (numValues > array.shape().get(-1)) { + throw new IllegalArgumentException("Vector values exceeds limit of " + array.shape().get(-1) + " elements"); + } + } + + protected void next() { + array.dimensions().increment(coords); + } + + protected void jumpTo(long[] newCoordinates) { + if (coords.length != newCoordinates.length) { + throw new IllegalArgumentException("New coordinates are not for the initialized dimension"); + } + resetTo(Arrays.copyOf(newCoordinates, newCoordinates.length)); + } + + protected void resetTo(long[] newCoords) { + coords = validateNewCoords(coords, newCoords); + } + + protected final V array; + + protected long[] coords = null; // must be explicitly reset + + protected AbstractNdArrayInitializer(V array) { + this.array = array; + } +} diff --git a/ndarray/src/main/java/org/tensorflow/ndarray/impl/sequence/CoordinatesIncrementor.java b/ndarray/src/main/java/org/tensorflow/ndarray/impl/sequence/CoordinatesIncrementor.java deleted file mode 100644 index 8c9c9f8..0000000 --- a/ndarray/src/main/java/org/tensorflow/ndarray/impl/sequence/CoordinatesIncrementor.java +++ /dev/null @@ -1,38 +0,0 @@ -/* - * Copyright 2020 The TensorFlow Authors. All Rights Reserved. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - * ======================================================================= - */ - -package org.tensorflow.ndarray.impl.sequence; - -final class CoordinatesIncrementor { - - boolean increment() { - for (int i = coords.length - 1; i >= 0; --i) { - if ((coords[i] = (coords[i] + 1) % shape[i]) > 0) { - return true; - } - } - return false; - } - - CoordinatesIncrementor(long[] shape, int dimensionIdx) { - this.shape = shape; - this.coords = new long[dimensionIdx + 1]; - } - - final long[] shape; - final long[] coords; -} diff --git a/ndarray/src/main/java/org/tensorflow/ndarray/impl/sequence/FastElementSequence.java b/ndarray/src/main/java/org/tensorflow/ndarray/impl/sequence/FastElementSequence.java index 92cebeb..2430030 100644 --- a/ndarray/src/main/java/org/tensorflow/ndarray/impl/sequence/FastElementSequence.java +++ b/ndarray/src/main/java/org/tensorflow/ndarray/impl/sequence/FastElementSequence.java @@ -34,8 +34,12 @@ public final class FastElementSequence> implements NdArraySequence { public FastElementSequence(AbstractNdArray ndArray, int dimensionIdx, U element, DataBufferWindow elementWindow) { + this(ndArray, new long[dimensionIdx + 1], element, elementWindow); + } + + public FastElementSequence(AbstractNdArray ndArray, long[] startCoords, U element, DataBufferWindow elementWindow) { this.ndArray = ndArray; - this.dimensionIdx = dimensionIdx; + this.startCoords = startCoords; this.element = element; this.elementWindow = elementWindow; } @@ -47,7 +51,7 @@ public Iterator iterator() { @Override public void forEachIndexed(BiConsumer consumer) { - PositionIterator.createIndexed(ndArray.dimensions(), dimensionIdx).forEachIndexed((long[] coords, long position) -> { + PositionIterator.createIndexed(ndArray.dimensions(), startCoords).forEachIndexed((long[] coords, long position) -> { elementWindow.slideTo(position); consumer.accept(coords, element); }); @@ -55,7 +59,7 @@ public void forEachIndexed(BiConsumer consumer) { @Override public NdArraySequence asSlices() { - return new SlicingElementSequence(ndArray, dimensionIdx); + return new SlicingElementSequence(ndArray, startCoords); } private class SequenceIterator implements Iterator { @@ -71,11 +75,11 @@ public U next() { return element; } - private final PositionIterator positionIterator = PositionIterator.create(ndArray.dimensions(), dimensionIdx); + private final PositionIterator positionIterator = PositionIterator.create(ndArray.dimensions(), startCoords); } private final AbstractNdArray ndArray; - private final int dimensionIdx; + private final long[] startCoords; private final U element; private final DataBufferWindow elementWindow; } diff --git a/ndarray/src/main/java/org/tensorflow/ndarray/impl/sequence/IndexedSequentialPositionIterator.java b/ndarray/src/main/java/org/tensorflow/ndarray/impl/sequence/IndexedSequentialPositionIterator.java index 80b3de6..39d2f67 100644 --- a/ndarray/src/main/java/org/tensorflow/ndarray/impl/sequence/IndexedSequentialPositionIterator.java +++ b/ndarray/src/main/java/org/tensorflow/ndarray/impl/sequence/IndexedSequentialPositionIterator.java @@ -17,6 +17,8 @@ package org.tensorflow.ndarray.impl.sequence; +import java.util.Arrays; + import org.tensorflow.ndarray.impl.dimension.DimensionalSpace; class IndexedSequentialPositionIterator extends SequentialPositionIterator implements IndexedPositionIterator { @@ -24,28 +26,28 @@ class IndexedSequentialPositionIterator extends SequentialPositionIterator imple @Override public void forEachIndexed(CoordsLongConsumer consumer) { while (hasNext()) { - consumer.consume(coords, nextLong()); - incrementCoords(); + consumer.consume(coords, super.nextLong()); + dimensions.increment(coords); } } - private void incrementCoords() { - for (int i = coords.length - 1; i >= 0; --i) { - if (coords[i] < shape[i] - 1) { - coords[i] += 1L; - return; - } - coords[i] = 0L; - } + @Override + public long nextLong() { + long tmp = super.nextLong(); + dimensions.increment(coords); + return tmp; } IndexedSequentialPositionIterator(DimensionalSpace dimensions, int dimensionIdx) { - super(dimensions, dimensionIdx); - this.shape = dimensions.shape().asArray(); - this.coords = new long[dimensionIdx + 1]; - //this.coordsIncrementor = new CoordinatesIncrementor(dimensions.shape().asArray(), dimensionIdx); + this(dimensions, new long[dimensionIdx + 1]); + } + + IndexedSequentialPositionIterator(DimensionalSpace dimensions, long[] coords) { + super(dimensions, coords); + this.dimensions = dimensions; + this.coords = Arrays.copyOf(coords, coords.length); } - private final long[] shape; + private final DimensionalSpace dimensions; private final long[] coords; } diff --git a/ndarray/src/main/java/org/tensorflow/ndarray/impl/sequence/NdPositionIterator.java b/ndarray/src/main/java/org/tensorflow/ndarray/impl/sequence/NdPositionIterator.java index 789474c..8f01d09 100644 --- a/ndarray/src/main/java/org/tensorflow/ndarray/impl/sequence/NdPositionIterator.java +++ b/ndarray/src/main/java/org/tensorflow/ndarray/impl/sequence/NdPositionIterator.java @@ -17,6 +17,7 @@ package org.tensorflow.ndarray.impl.sequence; +import java.util.Arrays; import java.util.NoSuchElementException; import org.tensorflow.ndarray.impl.dimension.DimensionalSpace; @@ -33,7 +34,7 @@ public long nextLong() { throw new NoSuchElementException(); } long position = dimensions.positionOf(coords); - increment(); + incrementCoords(); return position; } @@ -41,28 +42,23 @@ public long nextLong() { public void forEachIndexed(CoordsLongConsumer consumer) { while (hasNext()) { consumer.consume(coords, dimensions.positionOf(coords)); - increment(); + incrementCoords(); } } - private void increment() { - if (!increment(coords, dimensions)) { + private void incrementCoords() { + if (!dimensions.increment(coords)) { coords = null; } } - static boolean increment(long[] coords, DimensionalSpace dimensions) { - for (int i = coords.length - 1; i >= 0; --i) { - if ((coords[i] = (coords[i] + 1) % dimensions.get(i).numElements()) > 0) { - return true; - } - } - return false; + NdPositionIterator(DimensionalSpace dimensions, int dimensionIdx) { + this(dimensions, new long[dimensionIdx + 1]); } - NdPositionIterator(DimensionalSpace dimensions, int dimensionIdx) { + NdPositionIterator(DimensionalSpace dimensions, long[] coords) { this.dimensions = dimensions; - this.coords = new long[dimensionIdx + 1]; + this.coords = Arrays.copyOf(coords, coords.length); } private final DimensionalSpace dimensions; diff --git a/ndarray/src/main/java/org/tensorflow/ndarray/impl/sequence/PositionIterator.java b/ndarray/src/main/java/org/tensorflow/ndarray/impl/sequence/PositionIterator.java index 83ed940..8505112 100644 --- a/ndarray/src/main/java/org/tensorflow/ndarray/impl/sequence/PositionIterator.java +++ b/ndarray/src/main/java/org/tensorflow/ndarray/impl/sequence/PositionIterator.java @@ -17,9 +17,10 @@ package org.tensorflow.ndarray.impl.sequence; -import java.util.PrimitiveIterator; import org.tensorflow.ndarray.impl.dimension.DimensionalSpace; +import java.util.PrimitiveIterator; + public interface PositionIterator extends PrimitiveIterator.OfLong { static PositionIterator create(DimensionalSpace dimensions, int dimensionIdx) { @@ -29,6 +30,13 @@ static PositionIterator create(DimensionalSpace dimensions, int dimensionIdx) { return new SequentialPositionIterator(dimensions, dimensionIdx); } + static PositionIterator create(DimensionalSpace dimensions, long... startCoords) { + if (dimensions.isSegmented()) { + return new NdPositionIterator(dimensions, startCoords); + } + return new SequentialPositionIterator(dimensions, startCoords); + } + static IndexedPositionIterator createIndexed(DimensionalSpace dimensions, int dimensionIdx) { if (dimensions.isSegmented()) { return new NdPositionIterator(dimensions, dimensionIdx); @@ -36,6 +44,13 @@ static IndexedPositionIterator createIndexed(DimensionalSpace dimensions, int di return new IndexedSequentialPositionIterator(dimensions, dimensionIdx); } + static IndexedPositionIterator createIndexed(DimensionalSpace dimensions, long... startCoords) { + if (dimensions.isSegmented()) { + return new NdPositionIterator(dimensions, startCoords); + } + return new IndexedSequentialPositionIterator(dimensions, startCoords); + } + static PositionIterator sequence(long stride, long end) { return new SequentialPositionIterator(stride, end); } diff --git a/ndarray/src/main/java/org/tensorflow/ndarray/impl/sequence/SequentialPositionIterator.java b/ndarray/src/main/java/org/tensorflow/ndarray/impl/sequence/SequentialPositionIterator.java index 65c6fc9..69caf17 100644 --- a/ndarray/src/main/java/org/tensorflow/ndarray/impl/sequence/SequentialPositionIterator.java +++ b/ndarray/src/main/java/org/tensorflow/ndarray/impl/sequence/SequentialPositionIterator.java @@ -17,14 +17,15 @@ package org.tensorflow.ndarray.impl.sequence; -import java.util.NoSuchElementException; import org.tensorflow.ndarray.impl.dimension.DimensionalSpace; +import java.util.NoSuchElementException; + class SequentialPositionIterator implements PositionIterator { @Override public boolean hasNext() { - return index < end; + return pos < end; } @Override @@ -32,7 +33,7 @@ public long nextLong() { if (!hasNext()) { throw new NoSuchElementException(); } - return stride * index++; + return stride * pos++; } SequentialPositionIterator(DimensionalSpace dimensions, int dimensionIdx) { @@ -42,6 +43,12 @@ public long nextLong() { } this.stride = dimensions.get(dimensionIdx).elementSize(); this.end = size; + this.pos = 0L; + } + + SequentialPositionIterator(DimensionalSpace dimensions, long[] coords) { + this(dimensions, coords.length - 1); + this.pos = dimensions.positionOf(coords) / stride; } SequentialPositionIterator(long stride, long end) { @@ -51,5 +58,5 @@ public long nextLong() { private final long stride; private final long end; - private long index; + private long pos; } diff --git a/ndarray/src/main/java/org/tensorflow/ndarray/impl/sequence/SlicingElementSequence.java b/ndarray/src/main/java/org/tensorflow/ndarray/impl/sequence/SlicingElementSequence.java index 6fe8398..6d11968 100644 --- a/ndarray/src/main/java/org/tensorflow/ndarray/impl/sequence/SlicingElementSequence.java +++ b/ndarray/src/main/java/org/tensorflow/ndarray/impl/sequence/SlicingElementSequence.java @@ -33,18 +33,26 @@ public final class SlicingElementSequence> implements NdArraySequence { public SlicingElementSequence(AbstractNdArray ndArray, int dimensionIdx) { - this(ndArray, dimensionIdx, ndArray.dimensions().from(dimensionIdx + 1)); + this(ndArray, new long[dimensionIdx + 1]); + } + + public SlicingElementSequence(AbstractNdArray ndArray, long[] startCoords) { + this(ndArray, startCoords, ndArray.dimensions().from(startCoords.length)); } public SlicingElementSequence(AbstractNdArray ndArray, int dimensionIdx, DimensionalSpace elementDimensions) { + this(ndArray, new long[dimensionIdx + 1], elementDimensions); + } + + public SlicingElementSequence(AbstractNdArray ndArray, long[] startCoords, DimensionalSpace elementDimensions) { this.ndArray = ndArray; - this.dimensionIdx = dimensionIdx; + this.startCoords = startCoords; this.elementDimensions = elementDimensions; } @Override public Iterator iterator() { - PositionIterator positionIterator = PositionIterator.create(ndArray.dimensions(), dimensionIdx); + PositionIterator positionIterator = PositionIterator.create(ndArray.dimensions(), startCoords); return new Iterator() { @Override @@ -61,7 +69,7 @@ public U next() { @Override public void forEachIndexed(BiConsumer consumer) { - PositionIterator.createIndexed(ndArray.dimensions(), dimensionIdx).forEachIndexed((long[] coords, long position) -> + PositionIterator.createIndexed(ndArray.dimensions(), startCoords).forEachIndexed((long[] coords, long position) -> consumer.accept(coords, ndArray.slice(position, elementDimensions)) ); } @@ -72,6 +80,6 @@ public NdArraySequence asSlices() { } private final AbstractNdArray ndArray; - private final int dimensionIdx; + private final long[] startCoords; private final DimensionalSpace elementDimensions; } diff --git a/ndarray/src/main/java/org/tensorflow/ndarray/impl/sparse/DoubleSparseNdArray.java b/ndarray/src/main/java/org/tensorflow/ndarray/impl/sparse/DoubleSparseNdArray.java index 07a6d2a..4dba5ac 100644 --- a/ndarray/src/main/java/org/tensorflow/ndarray/impl/sparse/DoubleSparseNdArray.java +++ b/ndarray/src/main/java/org/tensorflow/ndarray/impl/sparse/DoubleSparseNdArray.java @@ -73,6 +73,24 @@ public class DoubleSparseNdArray extends AbstractSparseNdArray, V extends AbstractSparseNdArray> extends AbstractNdArrayInitializer implements BaseNdArrayInitializer { + + @Override + public Scalars byScalars(long... coordinates) { + return new ScalarsImpl(coordinates); + } + + @Override + public Elements byElements(int dimensionIdx, long... coordinates) { + return new ElementsImpl(dimensionIdx, coordinates); + } + + class ScalarsImpl implements Scalars { + + @Override + public Scalars to(long... coordinates) { + jumpTo(coordinates); + valueCoords = Arrays.copyOf(coordinates, array.shape().numDimensions()); + return this; + } + + @Override + public Scalars put(T value) { + if (value == null) { + throw new IllegalArgumentException("Scalar cannot be null"); + } + addValue(value); + next(); + return this; + } + + protected ScalarsImpl(long[] coordinates) { + resetTo(validateRankCoords(0, coordinates)); + } + } + + class VectorsImpl implements Vectors { + + @Override + public Vectors to(long... coordinates) { + jumpTo(coordinates); + valueCoords = Arrays.copyOf(coordinates, array.shape().numDimensions()); + return this; + } + + @Override + public Vectors put(Collection values) { + validateVectorLength(values.size()); + for (T value : values) { + addValue(value); + } + next(); + return this; + } + + protected VectorsImpl(long[] coordinates) { + resetTo(validateRankCoords(1, coordinates)); + } + } + + class ElementsImpl implements Elements { + + @Override + public Elements to(long... coordinates) { + jumpTo(coordinates); + valueCoords = Arrays.copyOf(coordinates, array.shape().numDimensions()); + return this; + } + + @Override + public Elements put(NdArray values) { + if (values.rank() != elementRank) { + throw new IllegalArgumentException("Values must be of element rank " + elementRank); + } + values.scalars().forEach(s -> { + addValue(s.getObject()); + }); + next(); + return this; + } + + protected ElementsImpl(int dimensionIdx, long[] coordinates) { + resetTo(validateDimensionCoords(dimensionIdx, coordinates)); + elementRank = array.shape().numDimensions() - dimensionIdx - 1; + } + + private final int elementRank; + } + + protected long valueCount = 0; + + protected long[] valueCoords; + + protected void addValue(T value) { + if (value != array.getDefaultValue()) { + array.getValues().setObject(value, valueCount); + array.getIndices().set(NdArrays.vectorOf(valueCoords), valueCount++); + } + array.dimensions().increment(valueCoords); + } + + BaseSparseNdArrayInitializer(V array) { + super(array); + this.valueCoords = new long[array.shape().numDimensions()]; + } +} diff --git a/ndarray/src/main/java/org/tensorflow/ndarray/impl/sparse/initializer/DoubleSparseNdArrayInitializer.java b/ndarray/src/main/java/org/tensorflow/ndarray/impl/sparse/initializer/DoubleSparseNdArrayInitializer.java new file mode 100644 index 0000000..f7dcb84 --- /dev/null +++ b/ndarray/src/main/java/org/tensorflow/ndarray/impl/sparse/initializer/DoubleSparseNdArrayInitializer.java @@ -0,0 +1,100 @@ +/* + Copyright 2022 The TensorFlow Authors. All Rights Reserved. + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. + ======================================================================= + */ +package org.tensorflow.ndarray.impl.sparse.initializer; + +import org.tensorflow.ndarray.DoubleNdArray; +import org.tensorflow.ndarray.NdArrays; +import org.tensorflow.ndarray.impl.sparse.DoubleSparseNdArray; +import org.tensorflow.ndarray.initializer.DoubleNdArrayInitializer; + +import java.util.Collection; + +public class DoubleSparseNdArrayInitializer extends BaseSparseNdArrayInitializer implements DoubleNdArrayInitializer { + + public DoubleSparseNdArrayInitializer(DoubleSparseNdArray array) { + super(array); + } + + @Override + public DoubleNdArrayInitializer.Scalars byScalars(long... coordinates) { + return new DoubleScalarsImpl(coordinates); + } + + @Override + public DoubleNdArrayInitializer.Vectors byVectors(long... coordinates) { + return new DoubleVectorsImpl(coordinates); + } + + private class DoubleScalarsImpl extends ScalarsImpl implements DoubleNdArrayInitializer.Scalars { + + @Override + public DoubleNdArrayInitializer.Scalars to(long... coordinates) { + return (DoubleScalarsImpl) super.to(coordinates); + } + + @Override + public DoubleNdArrayInitializer.Scalars put(Double value) { + return (DoubleScalarsImpl) super.put(value); + } + + @Override + public DoubleNdArrayInitializer.Scalars put(double scalar) { + addDoubleValue(scalar); + next(); + return this; + } + + private DoubleScalarsImpl(long[] coordinates) { + super(coordinates); + } + } + + private class DoubleVectorsImpl extends VectorsImpl implements DoubleNdArrayInitializer.Vectors { + + @Override + public DoubleNdArrayInitializer.Vectors to(long... coordinates) { + return (DoubleVectorsImpl) super.to(coordinates); + } + + @Override + public DoubleNdArrayInitializer.Vectors put(Collection values) { + return (DoubleVectorsImpl) super.put(values); + } + + @Override + public DoubleNdArrayInitializer.Vectors put(double... values) { + validateVectorLength(values.length); + for (double value : values) { + addDoubleValue(value); + } + next(); + return this; + } + + private DoubleVectorsImpl(long[] coordinates) { + super(coordinates); + } + } + + private void addDoubleValue(double value) { + if (value != array.getDefaultValue()) { + array.getValues().setDouble(value, valueCount); + array.getIndices().set(NdArrays.vectorOf(valueCoords), valueCount++); + } + array.dimensions().increment(valueCoords); + } +} diff --git a/ndarray/src/main/java/org/tensorflow/ndarray/impl/sparse/initializer/SparseNdArrayInitializer.java b/ndarray/src/main/java/org/tensorflow/ndarray/impl/sparse/initializer/SparseNdArrayInitializer.java new file mode 100644 index 0000000..59182ac --- /dev/null +++ b/ndarray/src/main/java/org/tensorflow/ndarray/impl/sparse/initializer/SparseNdArrayInitializer.java @@ -0,0 +1,62 @@ +/* + Copyright 2022 The TensorFlow Authors. All Rights Reserved. + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. + ======================================================================= + */ +package org.tensorflow.ndarray.impl.sparse.initializer; + +import org.tensorflow.ndarray.NdArray; +import org.tensorflow.ndarray.impl.sparse.AbstractSparseNdArray; +import org.tensorflow.ndarray.initializer.NdArrayInitializer; + +import java.util.Collection; + +public class SparseNdArrayInitializer, V extends AbstractSparseNdArray> extends BaseSparseNdArrayInitializer implements NdArrayInitializer { + + public SparseNdArrayInitializer(V array) { + super(array); + } + + @Override + public NdArrayInitializer.Vectors byVectors(long... coordinates) { + return new ObjectVectorsImpl(coordinates); + } + + class ObjectVectorsImpl extends VectorsImpl implements NdArrayInitializer.Vectors { + + @Override + public NdArrayInitializer.Vectors to(long... coordinates) { + return (ObjectVectorsImpl) super.to(coordinates); + } + + @Override + public NdArrayInitializer.Vectors put(Collection values) { + return (ObjectVectorsImpl) super.put(values); + } + + @Override + public NdArrayInitializer.Vectors put(T... values) { + validateVectorLength(values.length); + for (T value : values) { + addValue(value); + } + next(); + return this; + } + + ObjectVectorsImpl(long[] coordinates) { + super(coordinates); + } + } +} diff --git a/ndarray/src/main/java/org/tensorflow/ndarray/initializer/BaseNdArrayInitializer.java b/ndarray/src/main/java/org/tensorflow/ndarray/initializer/BaseNdArrayInitializer.java new file mode 100644 index 0000000..b999ff1 --- /dev/null +++ b/ndarray/src/main/java/org/tensorflow/ndarray/initializer/BaseNdArrayInitializer.java @@ -0,0 +1,223 @@ +/* + Copyright 2022 The TensorFlow Authors. All Rights Reserved. + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. + ======================================================================= + */ +package org.tensorflow.ndarray.initializer; + +import org.tensorflow.ndarray.NdArray; +import org.tensorflow.ndarray.buffer.DataBuffer; + +import java.util.Collection; + +/** + * Interface for initializing the data of a {@link NdArray} that has just been allocated. + * + *

The initializer API focuses on relative per-element initialization of the data of a newly allocated + * NdArray, which is more idiomatic than using output methods such as + * {@link NdArray#write(DataBuffer)} or {@link NdArray#copyTo(NdArray)}.

+ * + *

It also allows the initialization of {@link org.tensorflow.ndarray.SparseNdArray sparse arrays} before they + * become read-only.

+ * + * @param the type of data of the {@link NdArray} to initialize + */ +public interface BaseNdArrayInitializer { + + /** + * An API for initializing an {@link NdArray} using scalar values + * + * @param the type of data of the {@link NdArray} to initialize + */ + interface Scalars { + + /** + * Reset the position of the initializer in the NdArray so that the next values provided are + * written starting from the given {@code coordinates}. + * + *

Note that it is not possible to move backward within the array, {@code coordinates} must be equal or greater + * than the actual position of the initializer.

+ * + * @param coordinates position in the array + * @return this object + * @throws IllegalArgumentException if {@code coordinates} are empty or are not one of a scalar + */ + Scalars to(long... coordinates); + + /** + * Sets the next scalar value in the array. + * + * @param value next scalar value + * @return this object + */ + Scalars put(T value); + } + + /** + * An API for initializing an {@link NdArray} using vectors. + * + * @param the type of data of the {@link NdArray} to initialize + */ + interface Vectors { + + /** + * Reset the position of the initializer in the NdArray so that the next vectors provided are written + * starting from the given {@code coordinates}. + * + *

Note that it is not possible to move backward within the array, {@code coordinates} must be equal or greater + * than the actual position of the initializer.

+ * + * @param coordinates position in the array + * @return this object + * @throws IllegalArgumentException if {@code coordinates} are empty or are not one of a vector + */ + Vectors to(long... coordinates); + + /** + * Sets the next vector values in the array. + * + * @param values next vector values + * @return this object + * @throws IllegalArgumentException if {@code vector.length > array.shape().get(-1)} + */ + Vectors put(Collection values); + } + + /** + * An API for initializing an {@link NdArray} using n-dimensional elements (sub-arrays). + * + * @param the type of data of the {@link NdArray} to initialize + */ + interface Elements { + + /** + * Reset the position of the initializer in the NdArray so that the next elements provided are written + * starting from the given {@code coordinates}. + * + *

Note that it is not possible to move backward within the array, {@code coordinates} must be equal or greater + * than the actual position of the initializer.

+ * + * @param coordinates position in the array + * @return this object + * @throws IllegalArgumentException if {@code coordinates} are empty or are of a different dimension + */ + Elements to(long... coordinates); + + /** + * Sets the next element values in the array. + * + * @param values array containing the next element values + * @return this object + * @throws IllegalArgumentException if {@code element} is null or of the wrong rank + */ + Elements put(NdArray values); + } + + /** + * Per-scalar initialization of an {@link NdArray}. + * + *

Scalar initialization writes sequentially to an NdArray each individual values provided. Position + * can be reset to any scalar, across all dimensions.

+ * + *

If no {@code coordinates} are provided, the start position is the first scalar of this array.

+ * + * Example of usage: + *
{@code
+   *    NdArray array = NdArrays.ofObjects(String.class, Shape.of(3, 2), initializer -> {
+   *        initializer.byScalars()
+   *          .put("Cat")
+   *          .put("Dog")
+   *          .put("House")
+   *          .to(2, 1)
+   *          .put("Apple");
+   *    });
+   *    // -> [["Cat", "Dog"], ["House", null], [null, "Apple"]]
+   * }
+ * + * @param coordinates position of a scalar in the array to start initialization from, none for first scalar + * @return a {@link Scalars} instance + * @throws IllegalArgumentException if {@code coordinates} are set but are not one of a scalar + */ + Scalars byScalars(long... coordinates); + + /** + * Per-vector initialization of an {@link NdArray}. + * + *

Vector initialization writes sequentially provided values to vectors at the dimension n - 1 + * of an NdArray of rank n. The NdArray must therefore be of rank + * {@code > 0} (non-scalar).

+ * + *

Like in standard Java multidimensional arrays, it is possible to initialize partially a vector + * (i.e. having a number of values {@code < array.shape().get(-1)}). + * In such case, only the first values of the vector in the array will be initialized and the remaining will be + * left untouched (mostly defaulting to 0, depending on the type of buffer used to create the NdArray). + *

+ * + *

If no {@code coordinates} are provided, the start position is the the first vector of this array.

+ * + * Example of usage: + *
{@code
+   *    NdArray array = NdArrays.ofObjects(String.class, Shape.of(3, 2), initializer -> {
+   *        initializer.byVectors()
+   *          .put("Cat", "Dog")
+   *          .put("House") // partial initialization
+   *          .to(2)
+   *          .put("Orange", "Apple");
+   *    });
+   *    // -> [["Cat", "Dog"], ["House", null], ["Orange", "Apple"]]
+   * }
+ * + * @param coordinates position of a vector in the array to start initialization from, none for first vector + * @return a {@link Vectors} instance + * @throws IllegalArgumentException if array is of rank-0 or if {@code coordinates} are set but are not one of a vector + */ + Vectors byVectors(long... coordinates); + + /** + * Per-element initialization of an {@link NdArray}. + * + *

Element initialization writes sequentially values of provided arrays to elements at the dimension + * dimensionIdx of an NdArray. The provided arrays must be all the same rank, which matches + * the rank of the elements of the NdArray elements at this dimension.

+ * + *

If no {@code coordinates} are provided, the start position is the first element of dimension + * dimensionIdx of the array.

+ * + * Example of usage: + *
{@code
+   *    NdArray matrix = StdArrays.ndCopyOf(new String[][] {
+   *      { "Cat", "Apple" }, { "Dog", "Orange" }
+   *    });
+   *
+   *    NdArray array = NdArrays.ofObjects(String.class, Shape.of(4, 2, 2), initializer -> {
+   *        initializer.byElements(0)
+   *          .put(matrix)
+   *          .put(matrix)
+   *          .to(3)
+   *          .put(matrix);
+   *    });
+   *    // -> [[["Cat", "Apple"], ["Dog", "Orange"]],
+   *    //     [["Cat", "Apple"], ["Dog", "Orange"]],
+   *    //     [[null, null], [null, null]],
+   *    //     [["Cat", "Apple"], ["Dog", "Orange"]]]
+   * }
+ * + * @param dimensionIdx the index of the dimension being initialized + * @param coordinates position in the array to start from + * @return a {@link Elements} instance + * @throws IllegalArgumentException if {@code coordinates} are set but are not one of an element of the array or + * if {@code dimensionIdx >= array.shape().numDimensions()} + */ + Elements byElements(int dimensionIdx, long... coordinates); +} diff --git a/ndarray/src/main/java/org/tensorflow/ndarray/initializer/DoubleNdArrayInitializer.java b/ndarray/src/main/java/org/tensorflow/ndarray/initializer/DoubleNdArrayInitializer.java new file mode 100644 index 0000000..04693f5 --- /dev/null +++ b/ndarray/src/main/java/org/tensorflow/ndarray/initializer/DoubleNdArrayInitializer.java @@ -0,0 +1,76 @@ +/* + Copyright 2022 The TensorFlow Authors. All Rights Reserved. + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. + ======================================================================= + */ +package org.tensorflow.ndarray.initializer; + +import org.tensorflow.ndarray.DoubleNdArray; + +import java.util.Collection; + +/** + * Specialization of the {@link BaseNdArrayInitializer} API for initializing arrays of doubles. + * + * @see BaseNdArrayInitializer + */ +public interface DoubleNdArrayInitializer extends BaseNdArrayInitializer { + + /** + * An API for initializing an {@link DoubleNdArray} using scalar values + */ + interface Scalars extends BaseNdArrayInitializer.Scalars { + + @Override + Scalars to(long... coordinates); + + @Override + Scalars put(Double value); + + /** + * Set the next double value in the array. + * + * @param value next scalar value + * @return this object + */ + Scalars put(double value); + } + + /** + * An API for initializing an {@link DoubleNdArray} using vectors. + */ + interface Vectors extends BaseNdArrayInitializer.Vectors { + + @Override + Vectors to(long... coordinates); + + @Override + Vectors put(Collection values); + + /** + * Set the next vector double values in the array. + * + * @param values next vector values + * @return this object + * @throws IllegalArgumentException if {@code vector.length > array.shape().get(-1)} + */ + Vectors put(double... values); + } + + @Override + Scalars byScalars(long... coordinates); + + @Override + Vectors byVectors(long... coordinates); +} diff --git a/ndarray/src/main/java/org/tensorflow/ndarray/initializer/NdArrayInitializer.java b/ndarray/src/main/java/org/tensorflow/ndarray/initializer/NdArrayInitializer.java new file mode 100644 index 0000000..60346b0 --- /dev/null +++ b/ndarray/src/main/java/org/tensorflow/ndarray/initializer/NdArrayInitializer.java @@ -0,0 +1,55 @@ +/* + Copyright 2022 The TensorFlow Authors. All Rights Reserved. + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. + ======================================================================= + */ +package org.tensorflow.ndarray.initializer; + +import java.util.Collection; + +/** + * Specialization of the {@link BaseNdArrayInitializer} API for initializing arrays of objects. + * + * @see BaseNdArrayInitializer + * @param type of objects to initialize + */ +public interface NdArrayInitializer extends BaseNdArrayInitializer { + + /** + * {@inheritDoc} + */ + interface Vectors extends BaseNdArrayInitializer.Vectors { + + @Override + Vectors to(long... coordinates); + + @Override + Vectors put(Collection values); + + /** + * Set the next vector values in the array. + * + * @param values next vector values + * @return this object + * @throws IllegalArgumentException if {@code vector.length > array.shape().get(-1)} + */ + Vectors put(T... values); + } + + /** + * {@inheritDoc} + */ + @Override + Vectors byVectors(long... coordinates); +} diff --git a/ndarray/src/test/java/org/tensorflow/ndarray/impl/dense/initializer/DoubleDenseNdArrayInitializerTest.java b/ndarray/src/test/java/org/tensorflow/ndarray/impl/dense/initializer/DoubleDenseNdArrayInitializerTest.java new file mode 100644 index 0000000..ef2c5a8 --- /dev/null +++ b/ndarray/src/test/java/org/tensorflow/ndarray/impl/dense/initializer/DoubleDenseNdArrayInitializerTest.java @@ -0,0 +1,33 @@ +/* + Copyright 2022 The TensorFlow Authors. All Rights Reserved. + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. + ======================================================================= + */ +package org.tensorflow.ndarray.impl.dense.initializer; + +import org.tensorflow.ndarray.DoubleNdArray; +import org.tensorflow.ndarray.NdArrays; +import org.tensorflow.ndarray.Shape; +import org.tensorflow.ndarray.initializer.DoubleNdArrayInitializer; +import org.tensorflow.ndarray.impl.initializer.DoubleNdArrayInitializerTestBase; + +import java.util.function.Consumer; + +public class DoubleDenseNdArrayInitializerTest extends DoubleNdArrayInitializerTestBase { + + @Override + protected DoubleNdArray newArray(Shape shape, long numValues, Consumer init) { + return NdArrays.ofDoubles(shape, init); + } +} diff --git a/ndarray/src/test/java/org/tensorflow/ndarray/impl/dense/initializer/StringDenseNdArrayInitializerTest.java b/ndarray/src/test/java/org/tensorflow/ndarray/impl/dense/initializer/StringDenseNdArrayInitializerTest.java new file mode 100644 index 0000000..2e9012a --- /dev/null +++ b/ndarray/src/test/java/org/tensorflow/ndarray/impl/dense/initializer/StringDenseNdArrayInitializerTest.java @@ -0,0 +1,33 @@ +/* + Copyright 2022 The TensorFlow Authors. All Rights Reserved. + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. + ======================================================================= + */ +package org.tensorflow.ndarray.impl.dense.initializer; + +import org.tensorflow.ndarray.NdArray; +import org.tensorflow.ndarray.NdArrays; +import org.tensorflow.ndarray.Shape; +import org.tensorflow.ndarray.initializer.NdArrayInitializer; +import org.tensorflow.ndarray.impl.initializer.StringNdArrayInitializerTestBase; + +import java.util.function.Consumer; + +public class StringDenseNdArrayInitializerTest extends StringNdArrayInitializerTestBase { + + @Override + protected NdArray newArray(Shape shape, long numValues, Consumer> init) { + return NdArrays.ofObjects(String.class, shape, init); + } +} diff --git a/ndarray/src/test/java/org/tensorflow/ndarray/impl/initializer/AbstractNdArrayInitializerTest.java b/ndarray/src/test/java/org/tensorflow/ndarray/impl/initializer/AbstractNdArrayInitializerTest.java new file mode 100644 index 0000000..38a7ac1 --- /dev/null +++ b/ndarray/src/test/java/org/tensorflow/ndarray/impl/initializer/AbstractNdArrayInitializerTest.java @@ -0,0 +1,68 @@ +package org.tensorflow.ndarray.impl.initializer; + +import org.junit.jupiter.api.Test; + +import static org.junit.jupiter.api.Assertions.assertThrows; + +public class AbstractNdArrayInitializerTest { + + @Test + public void newCoordsMovingForwardAreValid() { + AbstractNdArrayInitializer.validateNewCoords(new long[]{0, 0, 0}, new long[]{1, 0, 0}); + AbstractNdArrayInitializer.validateNewCoords(new long[]{0, 0, 0}, new long[]{0, 1, 0}); + AbstractNdArrayInitializer.validateNewCoords(new long[]{0, 0, 0}, new long[]{0, 0, 1}); + AbstractNdArrayInitializer.validateNewCoords(new long[]{0, 0, 0}, new long[]{1, 0, 1}); + AbstractNdArrayInitializer.validateNewCoords(new long[]{1, 0, 0}, new long[]{1, 0, 1}); + AbstractNdArrayInitializer.validateNewCoords(new long[]{1, 1, 0}, new long[]{1, 2, 0}); + AbstractNdArrayInitializer.validateNewCoords(new long[]{1, 1, 0}, new long[]{1, 1, 1}); + } + + @Test + public void newCoordsOfLowerRankMovingForwardAreValid() { + AbstractNdArrayInitializer.validateNewCoords(new long[]{0, 0, 0}, new long[]{1, 0}); + AbstractNdArrayInitializer.validateNewCoords(new long[]{0, 0}, new long[]{1}); + AbstractNdArrayInitializer.validateNewCoords(new long[]{1, 1, 0}, new long[]{1, 2}); + } + + @Test + public void newCoordsOfHigherRankMovingForwardAreValid() { + AbstractNdArrayInitializer.validateNewCoords(new long[]{0, 0}, new long[]{1, 0, 0}); + AbstractNdArrayInitializer.validateNewCoords(new long[]{0, 0}, new long[]{0, 1, 0}); + AbstractNdArrayInitializer.validateNewCoords(new long[]{1}, new long[]{1, 2}); + AbstractNdArrayInitializer.validateNewCoords(new long[]{1}, new long[]{2, 0}); + } + + @Test + public void newCoordsEqualsToActualAreValid() { + AbstractNdArrayInitializer.validateNewCoords(new long[]{1, 0, 0}, new long[]{1, 0, 0}); + AbstractNdArrayInitializer.validateNewCoords(new long[]{1, 0, 0}, new long[]{1, 0}); + AbstractNdArrayInitializer.validateNewCoords(new long[]{1, 0, 0}, new long[]{1}); + AbstractNdArrayInitializer.validateNewCoords(new long[]{1}, new long[]{1, 0}); + AbstractNdArrayInitializer.validateNewCoords(new long[]{0}, new long[]{0, 0, 0}); + } + + @Test + public void newCoordsMovingBackwardAreInvalid() { + assertThrows(IllegalArgumentException.class, () -> AbstractNdArrayInitializer.validateNewCoords(new long[]{1, 0, 0}, new long[]{0, 0, 0})); + assertThrows(IllegalArgumentException.class, () -> AbstractNdArrayInitializer.validateNewCoords(new long[]{0, 1, 0}, new long[]{0, 0, 0})); + assertThrows(IllegalArgumentException.class, () -> AbstractNdArrayInitializer.validateNewCoords(new long[]{0, 0, 1}, new long[]{0, 0, 0})); + assertThrows(IllegalArgumentException.class, () -> AbstractNdArrayInitializer.validateNewCoords(new long[]{1, 0, 1}, new long[]{0, 0, 0})); + assertThrows(IllegalArgumentException.class, () -> AbstractNdArrayInitializer.validateNewCoords(new long[]{1, 0, 1}, new long[]{1, 0, 0})); + assertThrows(IllegalArgumentException.class, () -> AbstractNdArrayInitializer.validateNewCoords(new long[]{1, 2, 0}, new long[]{1, 1, 0})); + assertThrows(IllegalArgumentException.class, () -> AbstractNdArrayInitializer.validateNewCoords(new long[]{1, 1, 1}, new long[]{1, 1, 0})); + } + + @Test + public void newCoordsLowerRankMovingBackwardAreInvalid() { + assertThrows(IllegalArgumentException.class, () -> AbstractNdArrayInitializer.validateNewCoords(new long[]{1, 0, 0}, new long[]{0, 0})); + assertThrows(IllegalArgumentException.class, () -> AbstractNdArrayInitializer.validateNewCoords(new long[]{0, 1, 0}, new long[]{0})); + assertThrows(IllegalArgumentException.class, () -> AbstractNdArrayInitializer.validateNewCoords(new long[]{2, 2, 0}, new long[]{2})); + } + + @Test + public void newCoordsHigherRankMovingBackwardAreInvalid() { + assertThrows(IllegalArgumentException.class, () -> AbstractNdArrayInitializer.validateNewCoords(new long[]{1}, new long[]{0, 0})); + assertThrows(IllegalArgumentException.class, () -> AbstractNdArrayInitializer.validateNewCoords(new long[]{0, 1}, new long[]{0, 0, 0})); + assertThrows(IllegalArgumentException.class, () -> AbstractNdArrayInitializer.validateNewCoords(new long[]{2, 2}, new long[]{2, 0, 0})); + } +} diff --git a/ndarray/src/test/java/org/tensorflow/ndarray/impl/initializer/DoubleNdArrayInitializerTestBase.java b/ndarray/src/test/java/org/tensorflow/ndarray/impl/initializer/DoubleNdArrayInitializerTestBase.java new file mode 100644 index 0000000..48d59c5 --- /dev/null +++ b/ndarray/src/test/java/org/tensorflow/ndarray/impl/initializer/DoubleNdArrayInitializerTestBase.java @@ -0,0 +1,148 @@ +/* + Copyright 2022 The TensorFlow Authors. All Rights Reserved. + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. + ======================================================================= + */ +package org.tensorflow.ndarray.impl.initializer; + +import org.junit.jupiter.api.Test; +import org.tensorflow.ndarray.DoubleNdArray; +import org.tensorflow.ndarray.NdArrays; +import org.tensorflow.ndarray.Shape; +import org.tensorflow.ndarray.StdArrays; +import org.tensorflow.ndarray.initializer.DoubleNdArrayInitializer; + +import java.util.function.Consumer; + +import static org.junit.jupiter.api.Assertions.assertEquals; + +public abstract class DoubleNdArrayInitializerTestBase { + + protected abstract DoubleNdArray newArray(Shape shape, long numValues, Consumer init); + + @Test + public void initializeNdArrayByScalars() { + DoubleNdArray array = newArray(Shape.of(3, 2, 3), 15, init -> { + init + .byScalars() + .put(0.0) + .put(0.1) + .put(0.2) + .put(0.3) + .put(0.4) + .put(0.5) + .put(1.0) + .put(1.1) + .put(1.2) + .to(2, 0, 0) + .put(2.0) + .put(2.1) + .put(2.2) + .put(2.3) + .put(2.4) + .put(2.5); + }); + + assertEquals(StdArrays.ndCopyOf(new double[][][]{ + {{0.0, 0.1, 0.2}, {0.3, 0.4, 0.5}}, + {{1.0, 1.1, 1.2}, {0.0, 0.0, 0.0}}, + {{2.0, 2.1, 2.2}, {2.3, 2.4, 2.5}} + }), array); + + array = newArray(Shape.of(3, 2), 4, init -> { + init + .byScalars() + .put(10.0) + .put(20.0) + .put(30.0) + .to(2, 1) + .put(40.0); + }); + + assertEquals(StdArrays.ndCopyOf(new double[][]{{10.0, 20.0}, {30.0, 0.0}, {0.0, 40.0}}), array); + } + + @Test + public void initializeNdArrayByVectors() { + DoubleNdArray array = newArray(Shape.of(3, 2, 3), 15, init -> { + init + .byVectors() + .put(0.0, 0.1, 0.2) + .put(0.3, 0.4, 0.5) + .put(1.0, 1.1, 1.2) + .to(2, 0) + .put(2.0, 2.1, 2.2) + .put(2.3, 2.4, 2.5); + }); + + assertEquals(StdArrays.ndCopyOf(new double[][][]{ + {{0.0, 0.1, 0.2}, {0.3, 0.4, 0.5}}, + {{1.0, 1.1, 1.2}, {0.0, 0.0, 0.0}}, + {{2.0, 2.1, 2.2}, {2.3, 2.4, 2.5}} + }), array); + + array = newArray(Shape.of(3, 2), 5, init -> { + init + .byVectors() + .put(10.0, 20.0) + .put(30.0) + .to(2) + .put(40.0, 50.0); + }); + + assertEquals(StdArrays.ndCopyOf(new double[][]{{10.0, 20.0}, {30.0, 0.0}, {40.0, 50.0}}), array); + } + + @Test + public void initializeNdArrayByElements() { + DoubleNdArray array = newArray(Shape.of(3, 2, 3), 12, init -> { + init + .byElements(0) + .put(StdArrays.ndCopyOf(new double[][]{ + {0.0, 0.1, 0.2}, + {0.3, 0.4, 0.5} + })) + .to(2) + .put(StdArrays.ndCopyOf(new double[][]{ + {2.0, 2.1, 2.2}, + {2.3, 2.4, 2.5} + })); + }); + + assertEquals(StdArrays.ndCopyOf(new double[][][]{ + {{0.0, 0.1, 0.2}, {0.3, 0.4, 0.5}}, + {{0.0, 0.0, 0.0}, {0.0, 0.0, 0.0}}, + {{2.0, 2.1, 2.2}, {2.3, 2.4, 2.5}} + }), array); + + DoubleNdArray vector = NdArrays.vectorOf(10.0, 20.0); + + array = newArray(Shape.of(4, 2, 2), 8, init -> { + init + .byElements(1) + .put(vector) + .put(vector) + .put(vector) + .to(3, 1) + .put(vector); + }); + + assertEquals(StdArrays.ndCopyOf(new double[][][]{ + {{10.0, 20.0}, {10.0, 20.0}}, + {{10.0, 20.0}, {0.0, 0.0}}, + {{0.0, 0.0}, {0.0, 0.0}}, + {{0.0, 0.0}, {10.0, 20.0}} + }), array); + } +} diff --git a/ndarray/src/test/java/org/tensorflow/ndarray/impl/initializer/StringNdArrayInitializerTestBase.java b/ndarray/src/test/java/org/tensorflow/ndarray/impl/initializer/StringNdArrayInitializerTestBase.java new file mode 100644 index 0000000..1bfd7dd --- /dev/null +++ b/ndarray/src/test/java/org/tensorflow/ndarray/impl/initializer/StringNdArrayInitializerTestBase.java @@ -0,0 +1,146 @@ +/* + Copyright 2022 The TensorFlow Authors. All Rights Reserved. + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. + ======================================================================= + */ +package org.tensorflow.ndarray.impl.initializer; + +import org.junit.jupiter.api.Test; +import org.tensorflow.ndarray.NdArray; +import org.tensorflow.ndarray.NdArrays; +import org.tensorflow.ndarray.Shape; +import org.tensorflow.ndarray.StdArrays; +import org.tensorflow.ndarray.initializer.NdArrayInitializer; + +import java.util.function.Consumer; + +import static org.junit.jupiter.api.Assertions.assertEquals; + +public abstract class StringNdArrayInitializerTestBase { + + protected abstract NdArray newArray(Shape shape, long numValues, Consumer> init); + + @Test + public void initializeNdArrayByScalars() { + var array = newArray(Shape.of(3, 2, 3), 15, init -> { + init + .byScalars() + .put("0.0") + .put("0.1") + .put("0.2") + .put("0.3") + .put("0.4") + .put("0.5") + .put("1.0") + .put("1.1") + .put("1.2") + .to(2, 0, 0) + .put("2.0") + .put("2.1") + .put("2.2") + .put("2.3") + .put("2.4") + .put("2.5"); + }); + + assertEquals(StdArrays.ndCopyOf(new String[][][]{ + {{"0.0", "0.1", "0.2"}, {"0.3", "0.4", "0.5"}}, + {{"1.0", "1.1", "1.2"}, {null, null, null}}, + {{"2.0", "2.1", "2.2"}, {"2.3", "2.4", "2.5"}} + }), array); + + array = newArray(Shape.of(3, 2), 4, init -> { + init + .byScalars() + .put("10.0") + .put("20.0") + .put("30.0") + .to(2, 1) + .put("40.0"); + }); + + assertEquals(StdArrays.ndCopyOf(new String[][]{{"10.0", "20.0"}, {"30.0", null}, {null, "40.0"}}), array); + } + + @Test + public void initializeNdArrayByVectors() { + var array = newArray(Shape.of(3, 2, 3), 15, init -> { + init + .byVectors() + .put("0.0", "0.1", "0.2").put("0.3", "0.4", "0.5") + .put("1.0", "1.1", "1.2") + .to(2, 0) + .put("2.0", "2.1", "2.2").put("2.3", "2.4", "2.5"); + }); + + assertEquals(StdArrays.ndCopyOf(new String[][][]{ + {{"0.0", "0.1", "0.2"}, {"0.3", "0.4", "0.5"}}, + {{"1.0", "1.1", "1.2"}, {null, null, null}}, + {{"2.0", "2.1", "2.2"}, {"2.3", "2.4", "2.5"}} + }), array); + + array = newArray(Shape.of(3, 2), 5, init -> { + init + .byVectors() + .put("10.0", "20.0") + .put("30.0") + .to(2) + .put("40.0", "50.0"); + }); + + assertEquals(StdArrays.ndCopyOf(new String[][]{{"10.0", "20.0"}, {"30.0", null}, {"40.0", "50.0"}}), array); + } + + @Test + public void initializeNdArrayByElements() { + var array = newArray(Shape.of(3, 2, 3), 12, init -> { + init + .byElements(0) + .put(StdArrays.ndCopyOf(new String[][]{ + {"0.0", "0.1", "0.2"}, + {"0.3", "0.4", "0.5"} + })) + .to(2) + .put(StdArrays.ndCopyOf(new String[][]{ + {"2.0", "2.1", "2.2"}, + {"2.3", "2.4", "2.5"} + })); + }); + + assertEquals(StdArrays.ndCopyOf(new String[][][]{ + {{"0.0", "0.1", "0.2"}, {"0.3", "0.4", "0.5"}}, + {{null, null, null}, {null, null, null}}, + {{"2.0", "2.1", "2.2"}, {"2.3", "2.4", "2.5"}} + }), array); + + var vector = NdArrays.vectorOfObjects("10.0", "20.0"); + + array = newArray(Shape.of(4, 2, 2), 8, init -> { + init + .byElements(1) + .put(vector) + .put(vector) + .put(vector) + .to(3, 1) + .put(vector); + }); + + assertEquals(StdArrays.ndCopyOf(new String[][][]{ + {{"10.0", "20.0"}, {"10.0", "20.0"}}, + {{"10.0", "20.0"}, {null, null}}, + {{null, null}, {null, null}}, + {{null, null}, {"10.0", "20.0"}} + }), array); + } +} diff --git a/ndarray/src/test/java/org/tensorflow/ndarray/impl/sequence/ElementSequenceTest.java b/ndarray/src/test/java/org/tensorflow/ndarray/impl/sequence/ElementSequenceTest.java index bad7840..3ad462d 100644 --- a/ndarray/src/test/java/org/tensorflow/ndarray/impl/sequence/ElementSequenceTest.java +++ b/ndarray/src/test/java/org/tensorflow/ndarray/impl/sequence/ElementSequenceTest.java @@ -98,7 +98,7 @@ public void slicingElementSequenceReturnsUniqueInstances() { public void fastElementSequenceReturnsSameInstance() { IntNdArray array = NdArrays.ofInts(Shape.of(2, 3, 2)); IntNdArray element = array.get(0); - NdArraySequence sequence = new FastElementSequence( + NdArraySequence sequence = new FastElementSequence( (AbstractNdArray) array, 1, element, mockDataBufferWindow(2)); sequence.forEach(e -> { if (e != element) { diff --git a/ndarray/src/test/java/org/tensorflow/ndarray/impl/sparse/initializer/DoubleSparseNdArrayInitializerTest.java b/ndarray/src/test/java/org/tensorflow/ndarray/impl/sparse/initializer/DoubleSparseNdArrayInitializerTest.java new file mode 100644 index 0000000..b6b57cd --- /dev/null +++ b/ndarray/src/test/java/org/tensorflow/ndarray/impl/sparse/initializer/DoubleSparseNdArrayInitializerTest.java @@ -0,0 +1,33 @@ +/* + Copyright 2022 The TensorFlow Authors. All Rights Reserved. + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. + ======================================================================= + */ +package org.tensorflow.ndarray.impl.sparse.initializer; + +import org.tensorflow.ndarray.DoubleNdArray; +import org.tensorflow.ndarray.NdArrays; +import org.tensorflow.ndarray.Shape; +import org.tensorflow.ndarray.initializer.DoubleNdArrayInitializer; +import org.tensorflow.ndarray.impl.initializer.DoubleNdArrayInitializerTestBase; + +import java.util.function.Consumer; + +public class DoubleSparseNdArrayInitializerTest extends DoubleNdArrayInitializerTestBase { + + @Override + protected DoubleNdArray newArray(Shape shape, long numValues, Consumer init) { + return NdArrays.sparseOfDoubles(shape, numValues, init); + } +} diff --git a/ndarray/src/test/java/org/tensorflow/ndarray/impl/sparse/initializer/StringSparseNdArrayInitializerTest.java b/ndarray/src/test/java/org/tensorflow/ndarray/impl/sparse/initializer/StringSparseNdArrayInitializerTest.java new file mode 100644 index 0000000..0f2e8fa --- /dev/null +++ b/ndarray/src/test/java/org/tensorflow/ndarray/impl/sparse/initializer/StringSparseNdArrayInitializerTest.java @@ -0,0 +1,33 @@ +/* + Copyright 2022 The TensorFlow Authors. All Rights Reserved. + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. + ======================================================================= + */ +package org.tensorflow.ndarray.impl.sparse.initializer; + +import org.tensorflow.ndarray.NdArray; +import org.tensorflow.ndarray.NdArrays; +import org.tensorflow.ndarray.Shape; +import org.tensorflow.ndarray.initializer.NdArrayInitializer; +import org.tensorflow.ndarray.impl.initializer.StringNdArrayInitializerTestBase; + +import java.util.function.Consumer; + +public class StringSparseNdArrayInitializerTest extends StringNdArrayInitializerTestBase { + + @Override + protected NdArray newArray(Shape shape, long numValues, Consumer> init) { + return NdArrays.sparseOfObjects(String.class, shape, numValues, init); + } +}