From 4ad79968b3feac7df48db79c803cf1a1559484ca Mon Sep 17 00:00:00 2001 From: karllessard Date: Sun, 4 Dec 2022 22:50:11 -0500 Subject: [PATCH 1/3] Skeleton for the NdArray hydration API --- .../java/org/tensorflow/ndarray/NdArrays.java | 26 ++++ .../hydrator/DoubleNdArrayHydrator.java | 30 ++++ .../ndarray/hydrator/NdArrayHydrator.java | 35 +++++ .../impl/dense/AbstractDenseNdArray.java | 4 +- .../impl/dense/BooleanDenseNdArray.java | 10 +- .../ndarray/impl/dense/ByteDenseNdArray.java | 10 +- .../ndarray/impl/dense/DenseNdArray.java | 2 +- .../impl/dense/DoubleDenseNdArray.java | 10 +- .../ndarray/impl/dense/FloatDenseNdArray.java | 10 +- .../ndarray/impl/dense/IntDenseNdArray.java | 10 +- .../ndarray/impl/dense/LongDenseNdArray.java | 10 +- .../ndarray/impl/dense/ShortDenseNdArray.java | 10 +- .../dense/hydrator/DenseNdArrayHydrator.java | 130 ++++++++++++++++ .../hydrator/DoubleDenseNdArrayHydrator.java | 66 ++++++++ .../impl/sequence/CoordinatesIncrementor.java | 20 ++- .../IndexedSequentialPositionIterator.java | 3 +- .../impl/sequence/NdPositionIterator.java | 5 + .../impl/sequence/PositionIterator.java | 11 ++ .../sequence/SequentialPositionIterator.java | 13 +- .../impl/sparse/AbstractSparseNdArray.java | 11 +- .../impl/sparse/DoubleSparseNdArray.java | 18 +++ .../hydrator/DoubleSparseNdArrayHydrator.java | 75 +++++++++ .../hydrator/SparseNdArrayHydrator.java | 146 ++++++++++++++++++ .../hydrator/DenseNdArrayHydratorTest.java | 87 +++++++++++ .../hydrator/SparseNdArrayHydratorTest.java | 85 ++++++++++ 25 files changed, 783 insertions(+), 54 deletions(-) create mode 100644 ndarray/src/main/java/org/tensorflow/ndarray/hydrator/DoubleNdArrayHydrator.java create mode 100644 ndarray/src/main/java/org/tensorflow/ndarray/hydrator/NdArrayHydrator.java create mode 100644 ndarray/src/main/java/org/tensorflow/ndarray/impl/dense/hydrator/DenseNdArrayHydrator.java create mode 100644 ndarray/src/main/java/org/tensorflow/ndarray/impl/dense/hydrator/DoubleDenseNdArrayHydrator.java create mode 100644 ndarray/src/main/java/org/tensorflow/ndarray/impl/sparse/hydrator/DoubleSparseNdArrayHydrator.java create mode 100644 ndarray/src/main/java/org/tensorflow/ndarray/impl/sparse/hydrator/SparseNdArrayHydrator.java create mode 100644 ndarray/src/test/java/org/tensorflow/ndarray/impl/dense/hydrator/DenseNdArrayHydratorTest.java create mode 100644 ndarray/src/test/java/org/tensorflow/ndarray/impl/sparse/hydrator/SparseNdArrayHydratorTest.java diff --git a/ndarray/src/main/java/org/tensorflow/ndarray/NdArrays.java b/ndarray/src/main/java/org/tensorflow/ndarray/NdArrays.java index d79b781..bf75e8f 100644 --- a/ndarray/src/main/java/org/tensorflow/ndarray/NdArrays.java +++ b/ndarray/src/main/java/org/tensorflow/ndarray/NdArrays.java @@ -16,6 +16,7 @@ */ package org.tensorflow.ndarray; +import java.util.function.Consumer; import org.tensorflow.ndarray.buffer.BooleanDataBuffer; import org.tensorflow.ndarray.buffer.ByteDataBuffer; import org.tensorflow.ndarray.buffer.DataBuffer; @@ -25,6 +26,7 @@ import org.tensorflow.ndarray.buffer.IntDataBuffer; import org.tensorflow.ndarray.buffer.LongDataBuffer; import org.tensorflow.ndarray.buffer.ShortDataBuffer; +import org.tensorflow.ndarray.hydrator.DoubleNdArrayHydrator; import org.tensorflow.ndarray.impl.dense.BooleanDenseNdArray; import org.tensorflow.ndarray.impl.dense.ByteDenseNdArray; import org.tensorflow.ndarray.impl.dense.DenseNdArray; @@ -33,6 +35,7 @@ import org.tensorflow.ndarray.impl.dense.IntDenseNdArray; import org.tensorflow.ndarray.impl.dense.LongDenseNdArray; import org.tensorflow.ndarray.impl.dense.ShortDenseNdArray; +import org.tensorflow.ndarray.impl.dense.hydrator.DoubleDenseNdArrayHydrator; import org.tensorflow.ndarray.impl.dimension.DimensionalSpace; import org.tensorflow.ndarray.impl.sparse.BooleanSparseNdArray; import org.tensorflow.ndarray.impl.sparse.ByteSparseNdArray; @@ -41,6 +44,7 @@ import org.tensorflow.ndarray.impl.sparse.IntSparseNdArray; import org.tensorflow.ndarray.impl.sparse.LongSparseNdArray; import org.tensorflow.ndarray.impl.sparse.ShortSparseNdArray; +import org.tensorflow.ndarray.impl.sparse.hydrator.DoubleSparseNdArrayHydrator; /** Utility class for instantiating {@link NdArray} objects. */ public final class NdArrays { @@ -555,6 +559,20 @@ public static DoubleNdArray ofDoubles(Shape shape) { return wrap(shape, DataBuffers.ofDoubles(shape.size())); } + /** + * Creates an N-dimensional array of doubles of the given shape, with data hydration + * + * @param shape shape of the array + * @param hydrate initialize the data of the created array, using a hydrator + * @return new double N-dimensional array + * @throws IllegalArgumentException if shape is null or has unknown dimensions + */ + public static DoubleNdArray ofDoubles(Shape shape, Consumer hydrate) { + DoubleDenseNdArray array = (DoubleDenseNdArray)ofDoubles(shape); + hydrate.accept(new DoubleDenseNdArrayHydrator(array)); + return array; + } + /** * Wraps a buffer in a double N-dimensional array of a given shape. * @@ -568,6 +586,14 @@ public static DoubleNdArray wrap(Shape shape, DoubleDataBuffer buffer) { return DoubleDenseNdArray.create(buffer, shape); } + public static DoubleSparseNdArray sparseOfDoubles(long numValues, Shape shape, Consumer hydrate) { + LongNdArray indices = ofLongs(Shape.of(numValues, shape.numDimensions())); + DoubleNdArray values = ofDoubles(Shape.of(numValues)); + DoubleSparseNdArray array = DoubleSparseNdArray.create(indices, values, DimensionalSpace.create(shape)); + hydrate.accept(new DoubleSparseNdArrayHydrator(array)); + return array; + } + /** * Creates a Sparse array of double values with a default value of zero * diff --git a/ndarray/src/main/java/org/tensorflow/ndarray/hydrator/DoubleNdArrayHydrator.java b/ndarray/src/main/java/org/tensorflow/ndarray/hydrator/DoubleNdArrayHydrator.java new file mode 100644 index 0000000..cac8472 --- /dev/null +++ b/ndarray/src/main/java/org/tensorflow/ndarray/hydrator/DoubleNdArrayHydrator.java @@ -0,0 +1,30 @@ +package org.tensorflow.ndarray.hydrator; + +import org.tensorflow.ndarray.DoubleNdArray; +import org.tensorflow.ndarray.NdArray; +import org.tensorflow.ndarray.Shape; + +public interface DoubleNdArrayHydrator extends NdArrayHydrator { + + interface Scalars extends NdArrayHydrator.Scalars { + + @Override + Scalars at(long... coordinates); + + Scalars put(double scalar); + } + + interface Vectors extends NdArrayHydrator.Vectors { + + @Override + Vectors at(long... coordinates); + + Vectors put(double... vector); + } + + @Override + Scalars byScalars(long... coordinates); + + @Override + Vectors byVectors(long... coordinates); +} diff --git a/ndarray/src/main/java/org/tensorflow/ndarray/hydrator/NdArrayHydrator.java b/ndarray/src/main/java/org/tensorflow/ndarray/hydrator/NdArrayHydrator.java new file mode 100644 index 0000000..6ecb398 --- /dev/null +++ b/ndarray/src/main/java/org/tensorflow/ndarray/hydrator/NdArrayHydrator.java @@ -0,0 +1,35 @@ +package org.tensorflow.ndarray.hydrator; + +import org.tensorflow.ndarray.NdArray; +import org.tensorflow.ndarray.Shape; +import org.tensorflow.ndarray.buffer.ByteDataBuffer; + +public interface NdArrayHydrator { + + interface Scalars { + + > U at(long... coordinates); + + > U putObject(T scalar); + } + + interface Vectors { + + > U at(long... coordinates); + + > U putObjects(T... vector); + } + + interface Elements { + + > U at(long... coordinates); + + > U put(NdArray vector); + } + + > U byScalars(long... coordinates); + + > U byVectors(long... coordinates); + + > U byElements(long... coordinates); +} diff --git a/ndarray/src/main/java/org/tensorflow/ndarray/impl/dense/AbstractDenseNdArray.java b/ndarray/src/main/java/org/tensorflow/ndarray/impl/dense/AbstractDenseNdArray.java index 30af952..9fc353c 100644 --- a/ndarray/src/main/java/org/tensorflow/ndarray/impl/dense/AbstractDenseNdArray.java +++ b/ndarray/src/main/java/org/tensorflow/ndarray/impl/dense/AbstractDenseNdArray.java @@ -31,6 +31,8 @@ @SuppressWarnings("unchecked") public abstract class AbstractDenseNdArray> extends AbstractNdArray { + abstract public DataBuffer buffer(); + @Override public NdArraySequence elements(int dimensionIdx) { if (dimensionIdx >= shape().numDimensions()) { @@ -145,8 +147,6 @@ protected AbstractDenseNdArray(DimensionalSpace dimensions) { super(dimensions); } - abstract protected DataBuffer buffer(); - abstract U instantiate(DataBuffer buffer, DimensionalSpace dimensions); long positionOf(long[] coords, boolean isValue) { diff --git a/ndarray/src/main/java/org/tensorflow/ndarray/impl/dense/BooleanDenseNdArray.java b/ndarray/src/main/java/org/tensorflow/ndarray/impl/dense/BooleanDenseNdArray.java index 0764146..a3caf40 100644 --- a/ndarray/src/main/java/org/tensorflow/ndarray/impl/dense/BooleanDenseNdArray.java +++ b/ndarray/src/main/java/org/tensorflow/ndarray/impl/dense/BooleanDenseNdArray.java @@ -31,6 +31,11 @@ public static BooleanNdArray create(BooleanDataBuffer buffer, Shape shape) { return new BooleanDenseNdArray(buffer, shape); } + @Override + public BooleanDataBuffer buffer() { + return buffer; + } + @Override public boolean getBoolean(long... indices) { return buffer.getBoolean(positionOf(indices, true)); @@ -77,11 +82,6 @@ BooleanDenseNdArray instantiate(DataBuffer buffer, DimensionalSpace dim return new BooleanDenseNdArray((BooleanDataBuffer)buffer, dimensions); } - @Override - protected BooleanDataBuffer buffer() { - return buffer; - } - private final BooleanDataBuffer buffer; private BooleanDenseNdArray(BooleanDataBuffer buffer, DimensionalSpace dimensions) { diff --git a/ndarray/src/main/java/org/tensorflow/ndarray/impl/dense/ByteDenseNdArray.java b/ndarray/src/main/java/org/tensorflow/ndarray/impl/dense/ByteDenseNdArray.java index 172432b..fa3b722 100644 --- a/ndarray/src/main/java/org/tensorflow/ndarray/impl/dense/ByteDenseNdArray.java +++ b/ndarray/src/main/java/org/tensorflow/ndarray/impl/dense/ByteDenseNdArray.java @@ -31,6 +31,11 @@ public static ByteNdArray create(ByteDataBuffer buffer, Shape shape) { return new ByteDenseNdArray(buffer, shape); } + @Override + public ByteDataBuffer buffer() { + return buffer; + } + @Override public byte getByte(long... indices) { return buffer.getByte(positionOf(indices, true)); @@ -77,11 +82,6 @@ ByteDenseNdArray instantiate(DataBuffer buffer, DimensionalSpace dimension return new ByteDenseNdArray((ByteDataBuffer)buffer, dimensions); } - @Override - protected ByteDataBuffer buffer() { - return buffer; - } - private final ByteDataBuffer buffer; private ByteDenseNdArray(ByteDataBuffer buffer, DimensionalSpace dimensions) { diff --git a/ndarray/src/main/java/org/tensorflow/ndarray/impl/dense/DenseNdArray.java b/ndarray/src/main/java/org/tensorflow/ndarray/impl/dense/DenseNdArray.java index 819d95d..54d337b 100644 --- a/ndarray/src/main/java/org/tensorflow/ndarray/impl/dense/DenseNdArray.java +++ b/ndarray/src/main/java/org/tensorflow/ndarray/impl/dense/DenseNdArray.java @@ -50,7 +50,7 @@ DenseNdArray instantiate(DataBuffer buffer, DimensionalSpace dimensions) { } @Override - protected DataBuffer buffer() { + public DataBuffer buffer() { return buffer; } diff --git a/ndarray/src/main/java/org/tensorflow/ndarray/impl/dense/DoubleDenseNdArray.java b/ndarray/src/main/java/org/tensorflow/ndarray/impl/dense/DoubleDenseNdArray.java index f54b8d0..d30c350 100644 --- a/ndarray/src/main/java/org/tensorflow/ndarray/impl/dense/DoubleDenseNdArray.java +++ b/ndarray/src/main/java/org/tensorflow/ndarray/impl/dense/DoubleDenseNdArray.java @@ -31,6 +31,11 @@ public static DoubleNdArray create(DoubleDataBuffer buffer, Shape shape) { return new DoubleDenseNdArray(buffer, shape); } + @Override + public DoubleDataBuffer buffer() { + return buffer; + } + @Override public double getDouble(long... indices) { return buffer.getDouble(positionOf(indices, true)); @@ -77,11 +82,6 @@ DoubleDenseNdArray instantiate(DataBuffer buffer, DimensionalSpace dimen return new DoubleDenseNdArray((DoubleDataBuffer)buffer, dimensions); } - @Override - protected DoubleDataBuffer buffer() { - return buffer; - } - private final DoubleDataBuffer buffer; private DoubleDenseNdArray(DoubleDataBuffer buffer, DimensionalSpace dimensions) { diff --git a/ndarray/src/main/java/org/tensorflow/ndarray/impl/dense/FloatDenseNdArray.java b/ndarray/src/main/java/org/tensorflow/ndarray/impl/dense/FloatDenseNdArray.java index 196b5ef..b164211 100644 --- a/ndarray/src/main/java/org/tensorflow/ndarray/impl/dense/FloatDenseNdArray.java +++ b/ndarray/src/main/java/org/tensorflow/ndarray/impl/dense/FloatDenseNdArray.java @@ -31,6 +31,11 @@ public static FloatNdArray create(FloatDataBuffer buffer, Shape shape) { return new FloatDenseNdArray(buffer, shape); } + @Override + public FloatDataBuffer buffer() { + return buffer; + } + @Override public float getFloat(long... indices) { return buffer.getFloat(positionOf(indices, true)); @@ -77,11 +82,6 @@ FloatDenseNdArray instantiate(DataBuffer buffer, DimensionalSpace dimensi return new FloatDenseNdArray((FloatDataBuffer) buffer, dimensions); } - @Override - public FloatDataBuffer buffer() { - return buffer; - } - private final FloatDataBuffer buffer; private FloatDenseNdArray(FloatDataBuffer buffer, DimensionalSpace dimensions) { diff --git a/ndarray/src/main/java/org/tensorflow/ndarray/impl/dense/IntDenseNdArray.java b/ndarray/src/main/java/org/tensorflow/ndarray/impl/dense/IntDenseNdArray.java index a7af498..3cbd15e 100644 --- a/ndarray/src/main/java/org/tensorflow/ndarray/impl/dense/IntDenseNdArray.java +++ b/ndarray/src/main/java/org/tensorflow/ndarray/impl/dense/IntDenseNdArray.java @@ -31,6 +31,11 @@ public static IntNdArray create(IntDataBuffer buffer, Shape shape) { return new IntDenseNdArray(buffer, shape); } + @Override + public IntDataBuffer buffer() { + return buffer; + } + @Override public int getInt(long... indices) { return buffer.getInt(positionOf(indices, true)); @@ -77,11 +82,6 @@ IntDenseNdArray instantiate(DataBuffer buffer, DimensionalSpace dimensi return new IntDenseNdArray((IntDataBuffer)buffer, dimensions); } - @Override - protected IntDataBuffer buffer() { - return buffer; - } - private final IntDataBuffer buffer; private IntDenseNdArray(IntDataBuffer buffer, DimensionalSpace dimensions) { diff --git a/ndarray/src/main/java/org/tensorflow/ndarray/impl/dense/LongDenseNdArray.java b/ndarray/src/main/java/org/tensorflow/ndarray/impl/dense/LongDenseNdArray.java index cd56dad..8c33528 100644 --- a/ndarray/src/main/java/org/tensorflow/ndarray/impl/dense/LongDenseNdArray.java +++ b/ndarray/src/main/java/org/tensorflow/ndarray/impl/dense/LongDenseNdArray.java @@ -31,6 +31,11 @@ public static LongNdArray create(LongDataBuffer buffer, Shape shape) { return new LongDenseNdArray(buffer, shape); } + @Override + public LongDataBuffer buffer() { + return buffer; + } + @Override public long getLong(long... indices) { return buffer.getLong(positionOf(indices, true)); @@ -77,11 +82,6 @@ LongDenseNdArray instantiate(DataBuffer buffer, DimensionalSpace dimension return new LongDenseNdArray((LongDataBuffer)buffer, dimensions); } - @Override - protected LongDataBuffer buffer() { - return buffer; - } - private final LongDataBuffer buffer; private LongDenseNdArray(LongDataBuffer buffer, DimensionalSpace dimensions) { diff --git a/ndarray/src/main/java/org/tensorflow/ndarray/impl/dense/ShortDenseNdArray.java b/ndarray/src/main/java/org/tensorflow/ndarray/impl/dense/ShortDenseNdArray.java index 291c01a..a44a81a 100644 --- a/ndarray/src/main/java/org/tensorflow/ndarray/impl/dense/ShortDenseNdArray.java +++ b/ndarray/src/main/java/org/tensorflow/ndarray/impl/dense/ShortDenseNdArray.java @@ -31,6 +31,11 @@ public static ShortNdArray create(ShortDataBuffer buffer, Shape shape) { return new ShortDenseNdArray(buffer, shape); } + @Override + public ShortDataBuffer buffer() { + return buffer; + } + @Override public short getShort(long... indices) { return buffer.getShort(positionOf(indices, true)); @@ -77,11 +82,6 @@ ShortDenseNdArray instantiate(DataBuffer buffer, DimensionalSpace dimensi return new ShortDenseNdArray((ShortDataBuffer)buffer, dimensions); } - @Override - protected ShortDataBuffer buffer() { - return buffer; - } - private final ShortDataBuffer buffer; private ShortDenseNdArray(ShortDataBuffer buffer, DimensionalSpace dimensions) { diff --git a/ndarray/src/main/java/org/tensorflow/ndarray/impl/dense/hydrator/DenseNdArrayHydrator.java b/ndarray/src/main/java/org/tensorflow/ndarray/impl/dense/hydrator/DenseNdArrayHydrator.java new file mode 100644 index 0000000..f764b91 --- /dev/null +++ b/ndarray/src/main/java/org/tensorflow/ndarray/impl/dense/hydrator/DenseNdArrayHydrator.java @@ -0,0 +1,130 @@ +package org.tensorflow.ndarray.impl.dense.hydrator; + +import java.util.Arrays; +import org.tensorflow.ndarray.NdArray; +import org.tensorflow.ndarray.buffer.DataBuffer; +import org.tensorflow.ndarray.hydrator.NdArrayHydrator; +import org.tensorflow.ndarray.impl.dense.AbstractDenseNdArray; +import org.tensorflow.ndarray.impl.sequence.CoordinatesIncrementor; +import org.tensorflow.ndarray.impl.sequence.PositionIterator; + +class DenseNdArrayHydrator implements NdArrayHydrator { + + public DenseNdArrayHydrator(AbstractDenseNdArray array) { + this.denseArray = array; + } + + @Override + public Scalars byScalars(long... coordinates) { + return new ScalarsImpl(coordinates); + } + + @Override + public Vectors byVectors(long... coordinates) { + return new VectorsImpl(coordinates); + } + + @Override + public Elements byElements(long... coordinates) { + return new ElementsImpl(coordinates); + } + + protected class ScalarsImpl implements Scalars { + + @Override + public > U at(long... coordinates) { + if (coordinates == null || coordinates.length != denseArray.shape().numDimensions()) { + throw new IllegalArgumentException(Arrays.toString(coordinates) + " are not valid scalar coordinates for an array of shape " + denseArray + .shape()); + } + positionIterator = PositionIterator.create(denseArray.dimensions(), coordinates); + return (U) this; + } + + @Override + public > U putObject(T scalar) { + buffer().setObject(scalar, positionIterator.next()); + return (U) this; + } + + protected ScalarsImpl(long[] coords) { + if (coords == null || coords.length == 0) { + positionIterator = PositionIterator.create(denseArray.dimensions(), denseArray.shape().numDimensions() - 1); + } else { + at(coords); + } + } + + protected PositionIterator positionIterator; + } + + protected class VectorsImpl implements Vectors { + + @Override + public > U at(long... coordinates) { + if (coordinates == null || coordinates.length != denseArray.shape().numDimensions() - 1) { + throw new IllegalArgumentException(Arrays.toString(coordinates) + " are not valid vector coordinates for an array of shape " + denseArray + .shape()); + } + positionIterator = PositionIterator.create(denseArray.dimensions(), coordinates); + return (U) this; + } + + @Override + public > U putObjects(T... vector) { + if (vector == null || vector.length > denseArray.shape().get(-1)) { + throw new IllegalArgumentException("Vector should not be null nor exceed " + denseArray.shape().get(-1) + " elements"); + } + buffer().offset(positionIterator.next()).write(vector); + return (U) this; + } + + protected VectorsImpl(long[] coords) { + if (denseArray.shape().numDimensions() < 1) { + throw new IllegalArgumentException("Cannot hydrate a scalar with vectors"); + } + if (coords == null || coords.length == 0) { + positionIterator = PositionIterator.create(denseArray.dimensions(), denseArray.shape().numDimensions() - 2); + } else { + at(coords); + } + } + + protected PositionIterator positionIterator; + } + + protected class ElementsImpl implements Elements { + + @Override + public > U at(long... coordinates) { + if (coordinates == null || coordinates.length == 0 || coordinates.length > denseArray.shape().numDimensions()) { + throw new IllegalArgumentException(Arrays.toString(coordinates) + " are not valid coordinates for an array of shape " + denseArray + .shape()); + } + this.coordinates = new CoordinatesIncrementor(denseArray.shape().asArray(), coordinates); + return (U) this; + } + + @Override + public > U put(NdArray array) { + array.copyTo(denseArray.get(coordinates.coords)); // FIXME use sequence instead? + return (U) this; + } + + protected ElementsImpl(long[] coords) { + if (coords == null || coords.length == 0) { + this.coordinates = new CoordinatesIncrementor(denseArray.shape().asArray(), 0); + } else { + at(coords); + } + } + + protected CoordinatesIncrementor coordinates; + } + + protected final AbstractDenseNdArray denseArray; + + protected > U buffer() { + return (U) denseArray.buffer(); + } +} diff --git a/ndarray/src/main/java/org/tensorflow/ndarray/impl/dense/hydrator/DoubleDenseNdArrayHydrator.java b/ndarray/src/main/java/org/tensorflow/ndarray/impl/dense/hydrator/DoubleDenseNdArrayHydrator.java new file mode 100644 index 0000000..881ea64 --- /dev/null +++ b/ndarray/src/main/java/org/tensorflow/ndarray/impl/dense/hydrator/DoubleDenseNdArrayHydrator.java @@ -0,0 +1,66 @@ +package org.tensorflow.ndarray.impl.dense.hydrator; + +import org.tensorflow.ndarray.buffer.DoubleDataBuffer; +import org.tensorflow.ndarray.hydrator.DoubleNdArrayHydrator; +import org.tensorflow.ndarray.impl.dense.DoubleDenseNdArray; + +public class DoubleDenseNdArrayHydrator extends DenseNdArrayHydrator implements DoubleNdArrayHydrator { + + public DoubleDenseNdArrayHydrator(DoubleDenseNdArray array) { + super(array); + } + + @Override + public DoubleNdArrayHydrator.Scalars byScalars(long... coordinates) { + return new ScalarsImpl(coordinates); + } + + @Override + public DoubleNdArrayHydrator.Vectors byVectors(long... coordinates) { + return new VectorsImpl(coordinates); + } + + @Override + protected DoubleDataBuffer buffer() { + return super.buffer(); + } + + private class ScalarsImpl extends DenseNdArrayHydrator.ScalarsImpl implements DoubleNdArrayHydrator.Scalars { + + @Override + public DoubleNdArrayHydrator.Scalars at(long... coordinates) { + return super.at(coordinates); + } + + @Override + public DoubleNdArrayHydrator.Scalars put(double scalar) { + buffer().setDouble(scalar, positionIterator.next()); + return this; + } + + private ScalarsImpl(long[] coords) { + super(coords); + } + } + + private class VectorsImpl extends DenseNdArrayHydrator.VectorsImpl implements DoubleNdArrayHydrator.Vectors { + + @Override + public DoubleNdArrayHydrator.Vectors at(long... coordinates) { + return super.at(coordinates); + } + + @Override + public DoubleNdArrayHydrator.Vectors put(double... vector) { + if (vector == null || vector.length > denseArray.shape().get(-1)) { + throw new IllegalArgumentException("Vector should not be null nor exceed " + denseArray.shape().get(-1) + " elements"); + } + buffer().offset(positionIterator.next()).write(vector); + return this; + } + + private VectorsImpl(long[] coords) { + super(coords); + } + } +} diff --git a/ndarray/src/main/java/org/tensorflow/ndarray/impl/sequence/CoordinatesIncrementor.java b/ndarray/src/main/java/org/tensorflow/ndarray/impl/sequence/CoordinatesIncrementor.java index 8c9c9f8..a020526 100644 --- a/ndarray/src/main/java/org/tensorflow/ndarray/impl/sequence/CoordinatesIncrementor.java +++ b/ndarray/src/main/java/org/tensorflow/ndarray/impl/sequence/CoordinatesIncrementor.java @@ -17,9 +17,11 @@ package org.tensorflow.ndarray.impl.sequence; -final class CoordinatesIncrementor { +import java.util.Arrays; - boolean increment() { +public final class CoordinatesIncrementor { + + public boolean increment() { for (int i = coords.length - 1; i >= 0; --i) { if ((coords[i] = (coords[i] + 1) % shape[i]) > 0) { return true; @@ -28,11 +30,19 @@ boolean increment() { return false; } - CoordinatesIncrementor(long[] shape, int dimensionIdx) { + public CoordinatesIncrementor(long[] shape, int dimensionIdx) { this.shape = shape; this.coords = new long[dimensionIdx + 1]; } - final long[] shape; - final long[] coords; + public CoordinatesIncrementor(long[] shape, long[] coords) { + if (coords.length == 0 || coords.length > shape.length) { + throw new IllegalArgumentException(); + } + this.shape = shape; + this.coords = Arrays.copyOf(coords, coords.length); + } + + public final long[] shape; + public final long[] coords; } diff --git a/ndarray/src/main/java/org/tensorflow/ndarray/impl/sequence/IndexedSequentialPositionIterator.java b/ndarray/src/main/java/org/tensorflow/ndarray/impl/sequence/IndexedSequentialPositionIterator.java index 80b3de6..4bb84ed 100644 --- a/ndarray/src/main/java/org/tensorflow/ndarray/impl/sequence/IndexedSequentialPositionIterator.java +++ b/ndarray/src/main/java/org/tensorflow/ndarray/impl/sequence/IndexedSequentialPositionIterator.java @@ -43,9 +43,8 @@ private void incrementCoords() { super(dimensions, dimensionIdx); this.shape = dimensions.shape().asArray(); this.coords = new long[dimensionIdx + 1]; - //this.coordsIncrementor = new CoordinatesIncrementor(dimensions.shape().asArray(), dimensionIdx); } private final long[] shape; - private final long[] coords; + private long[] coords; } diff --git a/ndarray/src/main/java/org/tensorflow/ndarray/impl/sequence/NdPositionIterator.java b/ndarray/src/main/java/org/tensorflow/ndarray/impl/sequence/NdPositionIterator.java index 789474c..dbbd59b 100644 --- a/ndarray/src/main/java/org/tensorflow/ndarray/impl/sequence/NdPositionIterator.java +++ b/ndarray/src/main/java/org/tensorflow/ndarray/impl/sequence/NdPositionIterator.java @@ -65,6 +65,11 @@ static boolean increment(long[] coords, DimensionalSpace dimensions) { this.coords = new long[dimensionIdx + 1]; } + NdPositionIterator(DimensionalSpace dimensions, long[] coords) { + this.dimensions = dimensions; + this.coords = coords; + } + private final DimensionalSpace dimensions; private long[] coords; } diff --git a/ndarray/src/main/java/org/tensorflow/ndarray/impl/sequence/PositionIterator.java b/ndarray/src/main/java/org/tensorflow/ndarray/impl/sequence/PositionIterator.java index 83ed940..3a7bd77 100644 --- a/ndarray/src/main/java/org/tensorflow/ndarray/impl/sequence/PositionIterator.java +++ b/ndarray/src/main/java/org/tensorflow/ndarray/impl/sequence/PositionIterator.java @@ -17,6 +17,7 @@ package org.tensorflow.ndarray.impl.sequence; +import java.util.Arrays; import java.util.PrimitiveIterator; import org.tensorflow.ndarray.impl.dimension.DimensionalSpace; @@ -29,6 +30,16 @@ static PositionIterator create(DimensionalSpace dimensions, int dimensionIdx) { return new SequentialPositionIterator(dimensions, dimensionIdx); } + static PositionIterator create(DimensionalSpace dimensions, long... startCoords) { + if (startCoords == null) { + throw new IllegalArgumentException(); + } + if (dimensions.isSegmented()) { + return new NdPositionIterator(dimensions, startCoords); + } + return new SequentialPositionIterator(dimensions, startCoords); + } + static IndexedPositionIterator createIndexed(DimensionalSpace dimensions, int dimensionIdx) { if (dimensions.isSegmented()) { return new NdPositionIterator(dimensions, dimensionIdx); diff --git a/ndarray/src/main/java/org/tensorflow/ndarray/impl/sequence/SequentialPositionIterator.java b/ndarray/src/main/java/org/tensorflow/ndarray/impl/sequence/SequentialPositionIterator.java index 65c6fc9..71332ea 100644 --- a/ndarray/src/main/java/org/tensorflow/ndarray/impl/sequence/SequentialPositionIterator.java +++ b/ndarray/src/main/java/org/tensorflow/ndarray/impl/sequence/SequentialPositionIterator.java @@ -17,6 +17,7 @@ package org.tensorflow.ndarray.impl.sequence; +import java.util.Arrays; import java.util.NoSuchElementException; import org.tensorflow.ndarray.impl.dimension.DimensionalSpace; @@ -24,7 +25,7 @@ class SequentialPositionIterator implements PositionIterator { @Override public boolean hasNext() { - return index < end; + return pos < end; } @Override @@ -32,7 +33,7 @@ public long nextLong() { if (!hasNext()) { throw new NoSuchElementException(); } - return stride * index++; + return stride * pos++; } SequentialPositionIterator(DimensionalSpace dimensions, int dimensionIdx) { @@ -42,6 +43,12 @@ public long nextLong() { } this.stride = dimensions.get(dimensionIdx).elementSize(); this.end = size; + this.pos = 0L; + } + + SequentialPositionIterator(DimensionalSpace dimensions, long[] coords) { + this(dimensions, coords.length - 1); + this.pos = dimensions.positionOf(coords) / stride; } SequentialPositionIterator(long stride, long end) { @@ -51,5 +58,5 @@ public long nextLong() { private final long stride; private final long end; - private long index; + private long pos; } diff --git a/ndarray/src/main/java/org/tensorflow/ndarray/impl/sparse/AbstractSparseNdArray.java b/ndarray/src/main/java/org/tensorflow/ndarray/impl/sparse/AbstractSparseNdArray.java index 8e3892d..edd087f 100644 --- a/ndarray/src/main/java/org/tensorflow/ndarray/impl/sparse/AbstractSparseNdArray.java +++ b/ndarray/src/main/java/org/tensorflow/ndarray/impl/sparse/AbstractSparseNdArray.java @@ -14,6 +14,11 @@ =======================================================================*/ package org.tensorflow.ndarray.impl.sparse; +import java.nio.ReadOnlyBufferException; +import java.util.ArrayList; +import java.util.Arrays; +import java.util.List; +import java.util.stream.LongStream; import org.tensorflow.ndarray.IllegalRankException; import org.tensorflow.ndarray.LongNdArray; import org.tensorflow.ndarray.NdArray; @@ -30,12 +35,6 @@ import org.tensorflow.ndarray.impl.sequence.SlicingElementSequence; import org.tensorflow.ndarray.index.Index; -import java.nio.ReadOnlyBufferException; -import java.util.ArrayList; -import java.util.Arrays; -import java.util.List; -import java.util.stream.LongStream; - /** * Abstract base class for sparse array. * diff --git a/ndarray/src/main/java/org/tensorflow/ndarray/impl/sparse/DoubleSparseNdArray.java b/ndarray/src/main/java/org/tensorflow/ndarray/impl/sparse/DoubleSparseNdArray.java index 07a6d2a..4dba5ac 100644 --- a/ndarray/src/main/java/org/tensorflow/ndarray/impl/sparse/DoubleSparseNdArray.java +++ b/ndarray/src/main/java/org/tensorflow/ndarray/impl/sparse/DoubleSparseNdArray.java @@ -73,6 +73,24 @@ public class DoubleSparseNdArray extends AbstractSparseNdArray implements DoubleNdArrayHydrator { + + public DoubleSparseNdArrayHydrator(DoubleSparseNdArray array) { + super(array); + } + + @Override + public DoubleNdArrayHydrator.Scalars byScalars(long... coordinates) { + return new ScalarsImpl(coordinates); + } + + @Override + public DoubleNdArrayHydrator.Vectors byVectors(long... coordinates) { + return new VectorsImpl(coordinates); + } + + private class ScalarsImpl extends SparseNdArrayHydrator.ScalarsImpl implements DoubleNdArrayHydrator.Scalars { + + @Override + public DoubleNdArrayHydrator.Scalars at(long... coordinates) { + return super.at(coordinates); + } + + @Override + public DoubleNdArrayHydrator.Scalars put(double scalar) { + sparseArray().getValues().setDouble(scalar, index); + sparseArray().getIndices().set(NdArrays.vectorOf(coordinates.coords), index++); + coordinates.increment(); + return this; + } + + private ScalarsImpl(long[] coords) { + super(coords); + } + } + + private class VectorsImpl extends SparseNdArrayHydrator.VectorsImpl implements DoubleNdArrayHydrator.Vectors { + + @Override + public DoubleNdArrayHydrator.Vectors at(long... coordinates) { + return super.at(coordinates); + } + + @Override + public DoubleNdArrayHydrator.Vectors put(double... vector) { + if (vector == null || vector.length > sparseArray().shape().get(-1)) { + throw new IllegalArgumentException("Vector should not be null nor exceed " + sparseArray().shape().get(-1) + " elements"); + } + double defaultValue = sparseArray().getDefaultValue(); + for (double value : vector) { + if (value != defaultValue) { + sparseArray().getValues().setDouble(value, index); + sparseArray().getIndices().set(NdArrays.vectorOf(coordinates.coords), index++); + } + coordinates.increment(); + } + return this; + } + + private VectorsImpl(long[] coords) { + super(coords); + } + } + + @Override + protected DoubleSparseNdArray sparseArray() { + return super.sparseArray(); + } +} diff --git a/ndarray/src/main/java/org/tensorflow/ndarray/impl/sparse/hydrator/SparseNdArrayHydrator.java b/ndarray/src/main/java/org/tensorflow/ndarray/impl/sparse/hydrator/SparseNdArrayHydrator.java new file mode 100644 index 0000000..aaf25df --- /dev/null +++ b/ndarray/src/main/java/org/tensorflow/ndarray/impl/sparse/hydrator/SparseNdArrayHydrator.java @@ -0,0 +1,146 @@ +package org.tensorflow.ndarray.impl.sparse.hydrator; + +import java.util.Arrays; +import org.tensorflow.ndarray.NdArray; +import org.tensorflow.ndarray.NdArrays; +import org.tensorflow.ndarray.hydrator.NdArrayHydrator; +import org.tensorflow.ndarray.impl.sequence.CoordinatesIncrementor; +import org.tensorflow.ndarray.impl.sparse.AbstractSparseNdArray; + +class SparseNdArrayHydrator implements NdArrayHydrator { + + public SparseNdArrayHydrator(AbstractSparseNdArray array) { + this.sparseArray = array; + } + + @Override + public Scalars byScalars(long... coordinates) { + return new ScalarsImpl(coordinates); + } + + @Override + public Vectors byVectors(long... coordinates) { + return new VectorsImpl(coordinates); + } + + @Override + public Elements byElements(long... coordinates) { + return new ElementsImpl(coordinates); + } + + protected class ScalarsImpl implements Scalars { + + @Override + public > U at(long... coordinates) { + if (coordinates == null || coordinates.length != sparseArray.shape().numDimensions()) { + throw new IllegalArgumentException(Arrays.toString(coordinates) + " are not valid scalar coordinates for an array of shape " + sparseArray + .shape()); + } + this.coordinates = new CoordinatesIncrementor(sparseArray.shape().asArray(), coordinates); + return (U) this; + } + + @Override + public > U putObject(T scalar) { + sparseArray.getValues().setObject(scalar, index); + sparseArray.getIndices().set(NdArrays.vectorOf(coordinates.coords), index++); + coordinates.increment(); + return (U) this; + } + + protected ScalarsImpl(long[] coords) { + if (coords == null || coords.length == 0) { + coordinates = new CoordinatesIncrementor(sparseArray.shape().asArray(), sparseArray.shape().numDimensions() - 1); + } else { + at(coords); + } + } + + protected CoordinatesIncrementor coordinates; + } + + protected class VectorsImpl implements Vectors { + + @Override + public > U at(long... coordinates) { + if (coordinates == null || coordinates.length != sparseArray.shape().numDimensions() - 1) { + throw new IllegalArgumentException(Arrays.toString(coordinates) + " are not valid vector coordinates for an array of shape " + sparseArray + .shape()); + } + this.coordinates = new CoordinatesIncrementor(sparseArray.shape().asArray(), Arrays.copyOf(coordinates, sparseArray.shape().numDimensions())); + return (U) this; + } + + @Override + public > U putObjects(T... vector) { + if (vector == null || vector.length > sparseArray.shape().get(-1)) { + throw new IllegalArgumentException("Vector should not be null nor exceed " + sparseArray.shape().get(-1) + " elements"); + } + for (T value : vector) { + if (value != sparseArray.getDefaultValue()) { + sparseArray.getValues().setObject(value, index); + sparseArray.getIndices().set(NdArrays.vectorOf(coordinates.coords), index++); + } + coordinates.increment(); + } + return (U) this; + } + + protected VectorsImpl(long[] coords) { + if (sparseArray.shape().numDimensions() < 1) { + throw new IllegalArgumentException("Cannot hydrate a scalar with vectors"); + } + if (coords == null || coords.length == 0) { + coordinates = new CoordinatesIncrementor(sparseArray.shape().asArray(), sparseArray.shape().numDimensions() - 1); + } else { + at(coords); + } + } + + protected CoordinatesIncrementor coordinates; + } + + protected class ElementsImpl implements Elements { + + @Override + public > U at(long... coordinates) { + if (coordinates == null || coordinates.length == 0 || coordinates.length > sparseArray.shape().numDimensions()) { + throw new IllegalArgumentException(Arrays.toString(coordinates) + " are not valid coordinates for an array of shape " + sparseArray + .shape()); + } + this.coordinates = new CoordinatesIncrementor(sparseArray.shape().asArray(), Arrays.copyOf(coordinates, sparseArray.shape().numDimensions())); + return (U) this; + } + + @Override + public > U put(NdArray array) { + array.scalars().forEach(s -> { + T value = s.getObject(); + if (value != sparseArray.getDefaultValue()) { + sparseArray.getValues().setObject(value, index); + sparseArray.getIndices().set(NdArrays.vectorOf(coordinates.coords), index++); + } + coordinates.increment(); + }); + return (U) this; + } + + protected ElementsImpl(long[] coords) { + if (coords == null || coords.length == 0) { + this.coordinates = new CoordinatesIncrementor(sparseArray.shape().asArray(), sparseArray.shape().numDimensions() - 1); + } else { + at(coords); + } + } + + protected CoordinatesIncrementor coordinates; + } + + protected long index = 0; + + protected > U sparseArray() { + return (U) sparseArray; + } + + private final AbstractSparseNdArray sparseArray; +} diff --git a/ndarray/src/test/java/org/tensorflow/ndarray/impl/dense/hydrator/DenseNdArrayHydratorTest.java b/ndarray/src/test/java/org/tensorflow/ndarray/impl/dense/hydrator/DenseNdArrayHydratorTest.java new file mode 100644 index 0000000..ee4df4a --- /dev/null +++ b/ndarray/src/test/java/org/tensorflow/ndarray/impl/dense/hydrator/DenseNdArrayHydratorTest.java @@ -0,0 +1,87 @@ +package org.tensorflow.ndarray.impl.dense.hydrator; + +import static org.junit.jupiter.api.Assertions.assertEquals; + +import org.junit.jupiter.api.Test; +import org.tensorflow.ndarray.DoubleNdArray; +import org.tensorflow.ndarray.NdArrays; +import org.tensorflow.ndarray.Shape; +import org.tensorflow.ndarray.StdArrays; +import org.tensorflow.ndarray.hydrator.DoubleNdArrayHydrator; +import org.tensorflow.ndarray.impl.dense.DoubleDenseNdArray; + +public class DenseNdArrayHydratorTest { + + @Test + public void hydrateNdArrayByScalars() { + DoubleNdArray array = NdArrays.ofDoubles(Shape.of(3, 2, 3), hydrator -> { + hydrator + .byScalars() + .put(0.0) + .put(0.1) + .put(0.2) + .put(0.3) + .put(0.4) + .put(0.5) + .put(1.0) + .put(1.1) + .put(1.2) + .at(2, 0, 0) + .put(2.0) + .put(2.1) + .put(2.2) + .put(2.3) + .put(2.4) + .put(2.5); + }); + + assertEquals(StdArrays.ndCopyOf(new double[][][] { + {{ 0.0, 0.1, 0.2 }, { 0.3, 0.4, 0.5 }}, + {{ 1.0, 1.1, 1.2 }, { 0.0, 0.0, 0.0 }}, + {{ 2.0, 2.1, 2.2 }, { 2.3, 2.4, 2.5 }} + }), array); + } + + @Test + public void hydrateNdArrayByVectors() { + DoubleNdArray array = NdArrays.ofDoubles(Shape.of(3, 2, 3), hydrator -> { + hydrator.byVectors() + .put(0.0, 0.1, 0.2) + .put(0.3, 0.4, 0.5) + .put(1.0, 1.1, 1.2) + .at(2, 0) + .put(2.0, 2.1, 2.2) + .put(2.3, 2.4, 2.5); + }); + + assertEquals(StdArrays.ndCopyOf(new double[][][] { + {{ 0.0, 0.1, 0.2 }, { 0.3, 0.4, 0.5 }}, + {{ 1.0, 1.1, 1.2 }, { 0.0, 0.0, 0.0 }}, + {{ 2.0, 2.1, 2.2 }, { 2.3, 2.4, 2.5 }} + }), array); + } + + @Test + public void hydrateNdArrayByElements() { + DoubleNdArray array = NdArrays.ofDoubles(Shape.of(3, 2, 3), hydrator -> { + hydrator.byElements() + .put(StdArrays.ndCopyOf(new double[][] { + { 0.0, 0.1, 0.2 }, + { 0.3, 0.4, 0.5 } + })) + .at(1, 0) + .put(NdArrays.vectorOf(1.0, 1.1, 1.2)) + .at(2) + .put(StdArrays.ndCopyOf(new double[][] { + { 2.0, 2.1, 2.2 }, + { 2.3, 2.4, 2.5 } + })); + }); + + assertEquals(StdArrays.ndCopyOf(new double[][][] { + {{ 0.0, 0.1, 0.2 }, { 0.3, 0.4, 0.5 }}, + {{ 1.0, 1.1, 1.2 }, { 0.0, 0.0, 0.0 }}, + {{ 2.0, 2.1, 2.2 }, { 2.3, 2.4, 2.5 }} + }), array); + } +} diff --git a/ndarray/src/test/java/org/tensorflow/ndarray/impl/sparse/hydrator/SparseNdArrayHydratorTest.java b/ndarray/src/test/java/org/tensorflow/ndarray/impl/sparse/hydrator/SparseNdArrayHydratorTest.java new file mode 100644 index 0000000..8ae36f5 --- /dev/null +++ b/ndarray/src/test/java/org/tensorflow/ndarray/impl/sparse/hydrator/SparseNdArrayHydratorTest.java @@ -0,0 +1,85 @@ +package org.tensorflow.ndarray.impl.sparse.hydrator; + +import static org.junit.jupiter.api.Assertions.assertEquals; + +import org.junit.jupiter.api.Test; +import org.tensorflow.ndarray.DoubleNdArray; +import org.tensorflow.ndarray.NdArrays; +import org.tensorflow.ndarray.Shape; +import org.tensorflow.ndarray.StdArrays; + +public class SparseNdArrayHydratorTest { + + @Test + public void hydrateNdArrayByScalars() { + DoubleNdArray array = NdArrays.sparseOfDoubles(15, Shape.of(3, 2, 3), hydrator -> { + hydrator + .byScalars() + .put(0.0) + .put(0.1) + .put(0.2) + .put(0.3) + .put(0.4) + .put(0.5) + .put(1.0) + .put(1.1) + .put(1.2) + .at(2, 0, 0) + .put(2.0) + .put(2.1) + .put(2.2) + .put(2.3) + .put(2.4) + .put(2.5); + }); + + assertEquals(StdArrays.ndCopyOf(new double[][][] { + {{ 0.0, 0.1, 0.2 }, { 0.3, 0.4, 0.5 }}, + {{ 1.0, 1.1, 1.2 }, { 0.0, 0.0, 0.0 }}, + {{ 2.0, 2.1, 2.2 }, { 2.3, 2.4, 2.5 }} + }), array); + } + + @Test + public void hydrateNdArrayByVectors() { + DoubleNdArray array = NdArrays.sparseOfDoubles(15, Shape.of(3, 2, 3), hydrator -> { + hydrator.byVectors() + .put(0.0, 0.1, 0.2) + .put(0.3, 0.4, 0.5) + .put(1.0, 1.1, 1.2) + .at(2, 0) + .put(2.0, 2.1, 2.2) + .put(2.3, 2.4, 2.5); + }); + + assertEquals(StdArrays.ndCopyOf(new double[][][] { + {{ 0.0, 0.1, 0.2 }, { 0.3, 0.4, 0.5 }}, + {{ 1.0, 1.1, 1.2 }, { 0.0, 0.0, 0.0 }}, + {{ 2.0, 2.1, 2.2 }, { 2.3, 2.4, 2.5 }} + }), array); + } + + @Test + public void hydrateNdArrayByElements() { + DoubleNdArray array = NdArrays.sparseOfDoubles(15, Shape.of(3, 2, 3), hydrator -> { + hydrator.byElements() + .put(StdArrays.ndCopyOf(new double[][] { + { 0.0, 0.1, 0.2 }, + { 0.3, 0.4, 0.5 } + })) + .at(1, 0) + .put(NdArrays.vectorOf(1.0, 1.1, 1.2)) + .at(2) + .put(StdArrays.ndCopyOf(new double[][] { + { 2.0, 2.1, 2.2 }, + { 2.3, 2.4, 2.5 } + })); + }); + + assertEquals(StdArrays.ndCopyOf(new double[][][] { + {{ 0.0, 0.1, 0.2 }, { 0.3, 0.4, 0.5 }}, + {{ 1.0, 1.1, 1.2 }, { 0.0, 0.0, 0.0 }}, + {{ 2.0, 2.1, 2.2 }, { 2.3, 2.4, 2.5 }} + }), array); + } +} From 0e73be00358d0b3260382606aba0e5419d094587 Mon Sep 17 00:00:00 2001 From: karllessard Date: Sun, 4 Dec 2022 22:50:19 -0500 Subject: [PATCH 2/3] Temp work for hydrators --- .../java/org/tensorflow/ndarray/NdArrays.java | 51 +++++- .../hydrator/DoubleNdArrayHydrator.java | 160 ++++++++++++++++- .../ndarray/hydrator/NdArrayHydrator.java | 164 ++++++++++++++++-- .../impl/dense/AbstractDenseNdArray.java | 22 ++- .../dense/hydrator/DenseNdArrayHydrator.java | 115 +++++------- .../hydrator/DoubleDenseNdArrayHydrator.java | 84 ++++++--- .../ndarray/impl/dense/hydrator/Helpers.java | 52 ++++++ .../impl/sequence/FastElementSequence.java | 14 +- .../IndexedSequentialPositionIterator.java | 33 ++-- .../impl/sequence/NdPositionIterator.java | 23 +-- .../impl/sequence/PositionIterator.java | 10 +- .../impl/sequence/SlicingElementSequence.java | 18 +- .../hydrator/DoubleSparseNdArrayHydrator.java | 113 ++++++++---- .../ndarray/impl/sparse/hydrator/Helpers.java | 49 ++++++ .../hydrator/SparseNdArrayHydrator.java | 132 ++++++-------- .../DoubleNdArrayHydratorTestBase.java | 141 +++++++++++++++ .../hydrator/DenseNdArrayHydratorTest.java | 87 ---------- .../DoubleDenseNdArrayHydratorTest.java | 24 +++ .../impl/sequence/ElementSequenceTest.java | 2 +- .../DoubleSparseNdArrayHydratorTest.java | 23 +++ 20 files changed, 949 insertions(+), 368 deletions(-) create mode 100644 ndarray/src/main/java/org/tensorflow/ndarray/impl/dense/hydrator/Helpers.java create mode 100644 ndarray/src/main/java/org/tensorflow/ndarray/impl/sparse/hydrator/Helpers.java create mode 100644 ndarray/src/test/java/org/tensorflow/ndarray/hydrator/DoubleNdArrayHydratorTestBase.java delete mode 100644 ndarray/src/test/java/org/tensorflow/ndarray/impl/dense/hydrator/DenseNdArrayHydratorTest.java create mode 100644 ndarray/src/test/java/org/tensorflow/ndarray/impl/dense/hydrator/DoubleDenseNdArrayHydratorTest.java create mode 100644 ndarray/src/test/java/org/tensorflow/ndarray/impl/sparse/hydrator/DoubleSparseNdArrayHydratorTest.java diff --git a/ndarray/src/main/java/org/tensorflow/ndarray/NdArrays.java b/ndarray/src/main/java/org/tensorflow/ndarray/NdArrays.java index bf75e8f..bfbe939 100644 --- a/ndarray/src/main/java/org/tensorflow/ndarray/NdArrays.java +++ b/ndarray/src/main/java/org/tensorflow/ndarray/NdArrays.java @@ -27,6 +27,8 @@ import org.tensorflow.ndarray.buffer.LongDataBuffer; import org.tensorflow.ndarray.buffer.ShortDataBuffer; import org.tensorflow.ndarray.hydrator.DoubleNdArrayHydrator; +import org.tensorflow.ndarray.hydrator.NdArrayHydrator; +import org.tensorflow.ndarray.impl.dense.AbstractDenseNdArray; import org.tensorflow.ndarray.impl.dense.BooleanDenseNdArray; import org.tensorflow.ndarray.impl.dense.ByteDenseNdArray; import org.tensorflow.ndarray.impl.dense.DenseNdArray; @@ -35,8 +37,10 @@ import org.tensorflow.ndarray.impl.dense.IntDenseNdArray; import org.tensorflow.ndarray.impl.dense.LongDenseNdArray; import org.tensorflow.ndarray.impl.dense.ShortDenseNdArray; +import org.tensorflow.ndarray.impl.dense.hydrator.DenseNdArrayHydrator; import org.tensorflow.ndarray.impl.dense.hydrator.DoubleDenseNdArrayHydrator; import org.tensorflow.ndarray.impl.dimension.DimensionalSpace; +import org.tensorflow.ndarray.impl.sparse.AbstractSparseNdArray; import org.tensorflow.ndarray.impl.sparse.BooleanSparseNdArray; import org.tensorflow.ndarray.impl.sparse.ByteSparseNdArray; import org.tensorflow.ndarray.impl.sparse.DoubleSparseNdArray; @@ -45,6 +49,7 @@ import org.tensorflow.ndarray.impl.sparse.LongSparseNdArray; import org.tensorflow.ndarray.impl.sparse.ShortSparseNdArray; import org.tensorflow.ndarray.impl.sparse.hydrator.DoubleSparseNdArrayHydrator; +import org.tensorflow.ndarray.impl.sparse.hydrator.SparseNdArrayHydrator; /** Utility class for instantiating {@link NdArray} objects. */ public final class NdArrays { @@ -560,7 +565,7 @@ public static DoubleNdArray ofDoubles(Shape shape) { } /** - * Creates an N-dimensional array of doubles of the given shape, with data hydration + * Creates an N-dimensional array of doubles of the given shape, hydrating it with data after its allocation * * @param shape shape of the array * @param hydrate initialize the data of the created array, using a hydrator @@ -586,7 +591,16 @@ public static DoubleNdArray wrap(Shape shape, DoubleDataBuffer buffer) { return DoubleDenseNdArray.create(buffer, shape); } - public static DoubleSparseNdArray sparseOfDoubles(long numValues, Shape shape, Consumer hydrate) { + /** + * Creates an Sparse array of doubles of the given shape, hydrating it with data after its allocation + * + * @param shape shape of the array + * @param numValues number of double value actually set in the array, others defaulting to the zero value + * @param hydrate initialize the data of the created array, using a hydrator + * @return new double N-dimensional array + * @throws IllegalArgumentException if shape is null or has unknown dimensions + */ + public static DoubleSparseNdArray sparseOfDoubles(Shape shape, long numValues, Consumer hydrate) { LongNdArray indices = ofLongs(Shape.of(numValues, shape.numDimensions())); DoubleNdArray values = ofDoubles(Shape.of(numValues)); DoubleSparseNdArray array = DoubleSparseNdArray.create(indices, values, DimensionalSpace.create(shape)); @@ -782,6 +796,21 @@ public static NdArray ofObjects(Class clazz, Shape shape) { return wrap(shape, DataBuffers.ofObjects(clazz, shape.size())); } + /** + * Creates an N-dimensional array of objects of the given shape, hydrating it with data after its allocation + * + * @param clazz class of the data to be stored in this array + * @param shape shape of the array + * @param hydrate initialize the data of the created array, using a hydrator + * @return new N-dimensional array + * @throws IllegalArgumentException if shape is null or has unknown dimensions + */ + public static NdArray ofObjects(Class clazz, Shape shape, Consumer> hydrate) { + AbstractDenseNdArray array = (AbstractDenseNdArray)ofObjects(clazz, shape); + hydrate.accept(new DenseNdArrayHydrator(array)); + return array; + } + /** * Wraps a buffer in an N-dimensional array of a given shape. * @@ -796,6 +825,24 @@ public static NdArray wrap(Shape shape, DataBuffer buffer) { return DenseNdArray.wrap(buffer, shape); } + /** + * Creates an Sparse array of objects of the given shape, hydrating it with data after its allocation + * + * @param type the class type represented by this sparse array. + * @param shape shape of the array + * @param numValues number of values actually set in the array, others defaulting to the zero value + * @param hydrate initialize the data of the created array, using a hydrator + * @return new N-dimensional array + * @throws IllegalArgumentException if shape is null or has unknown dimensions + */ + public static NdArray sparseOfObjects(Class type, Shape shape, long numValues, Consumer> hydrate) { + LongNdArray indices = ofLongs(Shape.of(numValues, shape.numDimensions())); + NdArray values = ofObjects(type, Shape.of(numValues)); + AbstractSparseNdArray array = (AbstractSparseNdArray)sparseOfObjects(type, indices, values, shape); + hydrate.accept(new SparseNdArrayHydrator(array)); + return array; + } + /** * Creates a Sparse array of values with a null default value * diff --git a/ndarray/src/main/java/org/tensorflow/ndarray/hydrator/DoubleNdArrayHydrator.java b/ndarray/src/main/java/org/tensorflow/ndarray/hydrator/DoubleNdArrayHydrator.java index cac8472..9440e42 100644 --- a/ndarray/src/main/java/org/tensorflow/ndarray/hydrator/DoubleNdArrayHydrator.java +++ b/ndarray/src/main/java/org/tensorflow/ndarray/hydrator/DoubleNdArrayHydrator.java @@ -1,30 +1,172 @@ package org.tensorflow.ndarray.hydrator; import org.tensorflow.ndarray.DoubleNdArray; -import org.tensorflow.ndarray.NdArray; -import org.tensorflow.ndarray.Shape; -public interface DoubleNdArrayHydrator extends NdArrayHydrator { +/** + * Specialization of the {@link NdArrayHydrator} API for hydrating arrays of doubles. + * + * @see NdArrayHydrator + */ +public interface DoubleNdArrayHydrator { - interface Scalars extends NdArrayHydrator.Scalars { + /** + * An API for hydrate an {@link DoubleNdArray} using scalar values + */ + interface Scalars { - @Override + /** + * Position the hydrator to the given {@code coordinates} to write the next scalars. + * + * @param coordinates position in the hydrated array + * @return this API + * @throws IllegalArgumentException if {@code coordinates} are empty or are not one of a scalar + */ Scalars at(long... coordinates); + /** + * Set a double value as the next scalar value in the hydrated array. + * + * @param scalar next scalar value + * @return this API + * @throws IllegalArgumentException if {@code scalar} is null + */ Scalars put(double scalar); } - interface Vectors extends NdArrayHydrator.Vectors { + /** + * An API for hydrate an {@link DoubleNdArray} using vectors, i.e. a list of scalars + */ + interface Vectors { - @Override + /** + * Position the hydrator to the given {@code coordinates} to write the next vectors. + * + * @param coordinates position in the hydrated array + * @return this API + * @throws IllegalArgumentException if {@code coordinates} are empty or are not one of a vector + */ Vectors at(long... coordinates); + /** + * Set a list of doubles as the next vector in the hydrated array. + * + * @param vector next vector values + * @return this API + * @throws IllegalArgumentException if {@code vector} is empty or its length is greater than the size of the dimension + * {@code n-1}, given {@code n} the rank of the hydrated array + */ Vectors put(double... vector); } - @Override + /** + * An API for hydrate an {@link DoubleNdArray} using n-dimensional elements (sub-arrays). + */ + interface Elements { + + /** + * Position the hydrator to the given {@code coordinates} to write the next elements. + * + * @param coordinates position in the hydrated array + * @return this API + * @throws IllegalArgumentException if {@code coordinates} are empty or are not one of an element of the hydrated array + */ + Elements at(long... coordinates); + + /** + * Set a n-dimensional array of doubles as the next element in the hydrated array. + * + * @param element array containing the next element values + * @return this API + * @throws IllegalArgumentException if {@code element} is null or its shape is incompatible with the current hydrator position + */ + Elements put(DoubleNdArray element); + } + + /** + * Start to hydrate the targeted array with scalars. + * + * If no {@code coordinates} are provided, the start position is the current one relatively to any previous hydration that occured or if none, + * defaults to the first scalar of this array. + * + * Example of usage: + *
{@code
+   *    DoubleNdArray array = NdArrays.ofDoubles(Shape.of(3, 2), hydrator -> {
+   *        hydrator.byScalars()
+   *          .put(10.0)
+   *          .put(20.0)
+   *          .put(30.0)
+   *          .at(2, 1)
+   *          .put(40.0);
+   *    });
+   *    // -> [[10.0, 20.0], [30.0, 0.0], [0.0, 40.0]]
+   * }
+ * + * @param coordinates position in the hydrated array to start from + * @return a {@link Scalars} instance + * @throws IllegalArgumentException if {@code coordinates} are set but are not one of a scalar + */ Scalars byScalars(long... coordinates); - @Override + /** + * Start to hydrate the targeted array with vectors. + * + * If no {@code coordinates} are provided, the start position is the current one relatively to any previous hydration that occured or if none, + * defaults to the first scalar of the first vector of this array. + * + * Example of usage: + *
{@code
+   *    DoubleNdArray array = NdArrays.ofDoubles(Shape.of(3, 2), hydrator -> {
+   *        hydrator.byVectors()
+   *          .put(10.0, 20.0)
+   *          .put(30.0)
+   *          .at(2)
+   *          .put(40.0, 50.0);
+   *    });
+   *    // -> [[10.0, 20.0], [30.0, null], [40.0, 50.0]]
+   * }
+ * + * @param coordinates position in the hydrated array to start from + * @return a {@link Vectors} instance + * @throws IllegalArgumentException if hydrated array is of rank-0 or if {@code coordinates} are set but are not one of a vector + */ Vectors byVectors(long... coordinates); + + /** + * Start to hydrate the targeted array with n-dimensional elements. + * + * If no {@code coordinates} are provided, the start position is the current one relatively to any previous hydration that occured or if none, + * defaults to the first element in the first (0) dimension of the hydrated array. + * + * Example of usage: + *
{@code
+   *    DoubleNdArray vector = NdArrays.vectorOf(10.0, 20.0);
+   *    DoubleNdArray scalar = NdArrays.scalarOf(30.0);
+   *
+   *    DoubleNdArray array = NdArrays.ofDoubles(Shape.of(4, 2), hydrator -> {
+   *        hydrator.byElements()
+   *          .put(vector)
+   *          .put(vector)
+   *          .at(2, 1)
+   *          .put(scalar)
+   *          .at(3)
+   *          .put(vector);
+   *    });
+   *    // -> [[10.0, 20.0], [10.0, 20.0], [0.0, 30.0], [10.0, 20.0]]
+   * }
+ * + * @param coordinates position in the hydrated array to start from + * @return a {@link Elements} instance + * @throws IllegalArgumentException if {@code coordinates} are set but are not one of an element of the hydrated array + */ + Elements byElements(long... coordinates); + + /** + * Creates an API to hydrate the targeted array with {@code Double} boxed type. + * + * Note that sticking to primitive types improve I/O performances overall, so only rely boxed types if the data is already + * available in that format. + * + * @return a hydrator supporting {@code Double} boxed type + */ + NdArrayHydrator boxed(); } diff --git a/ndarray/src/main/java/org/tensorflow/ndarray/hydrator/NdArrayHydrator.java b/ndarray/src/main/java/org/tensorflow/ndarray/hydrator/NdArrayHydrator.java index 6ecb398..fc64f89 100644 --- a/ndarray/src/main/java/org/tensorflow/ndarray/hydrator/NdArrayHydrator.java +++ b/ndarray/src/main/java/org/tensorflow/ndarray/hydrator/NdArrayHydrator.java @@ -1,35 +1,177 @@ package org.tensorflow.ndarray.hydrator; import org.tensorflow.ndarray.NdArray; -import org.tensorflow.ndarray.Shape; -import org.tensorflow.ndarray.buffer.ByteDataBuffer; +import org.tensorflow.ndarray.buffer.DataBuffer; +/** + * Interface for initializing the data of a {@link NdArray} that has just been allocated. + * + * While it is always possible to set the data of a read-write NdArray using standard output methods, + * like {@link NdArray#write(DataBuffer)} or {@link NdArray#copyTo(NdArray)}, the hydrator API focuses on + * sequential per-element initialization, similar to standard Java arrays. + * + * Since the hydrator API is only accessible right after the array have been allocated, it can be used to + * initialize data-sensitive arrays, like {@link org.tensorflow.ndarray.SparseNdArray}, which can be only + * written once and stay read-only thereafter. + * + * @param the type of data of the {@link NdArray} to initialize + */ public interface NdArrayHydrator { + /** + * An API for hydrate an {@link NdArray} using scalar values + * + * @param the type of data of the {@link NdArray} to initialize + */ interface Scalars { - > U at(long... coordinates); + /** + * Position the hydrator to the given {@code coordinates} to write the next scalars. + * + * @param coordinates position in the hydrated array + * @return this API + * @throws IllegalArgumentException if {@code coordinates} are empty or are not one of a scalar + */ + Scalars at(long... coordinates); - > U putObject(T scalar); + /** + * Set an object as the next scalar value in the hydrated array. + * + * @param scalar next scalar value + * @return this API + * @throws IllegalArgumentException if {@code scalar} is null + */ + Scalars put(T scalar); } + /** + * An API for hydrate an {@link NdArray} using vectors, i.e. a list of scalars + * + * @param the type of data of the {@link NdArray} to initialize + */ interface Vectors { - > U at(long... coordinates); + /** + * Position the hydrator to the given {@code coordinates} to write the next vectors. + * + * @param coordinates position in the hydrated array + * @return this API + * @throws IllegalArgumentException if {@code coordinates} are empty or are not one of a vector + */ + Vectors at(long... coordinates); - > U putObjects(T... vector); + /** + * Set a list of objects as the next vector in the hydrated array. + * + * @param vector next vector values + * @return this API + * @throws IllegalArgumentException if {@code vector} is empty or its length is greater than the size of the dimension + * {@code n-1}, given {@code n} the rank of the hydrated array + */ + Vectors put(T... vector); } + /** + * An API for hydrate an {@link NdArray} using n-dimensional elements (sub-arrays). + * + * @param the type of data of the {@link NdArray} to initialize + */ interface Elements { - > U at(long... coordinates); + /** + * Position the hydrator to the given {@code coordinates} to write the next elements. + * + * @param coordinates position in the hydrated array + * @return this API + * @throws IllegalArgumentException if {@code coordinates} are empty or are not one of an element of the hydrated array + */ + Elements at(long... coordinates); - > U put(NdArray vector); + /** + * Set a n-dimensional array of objects as the next element in the hydrated array. + * + * @param element array containing the next element values + * @return this API + * @throws IllegalArgumentException if {@code element} is null or its shape is incompatible with the current hydrator position + */ + Elements put(NdArray element); } - > U byScalars(long... coordinates); + /** + * Start to hydrate the targeted array with scalars. + * + * If no {@code coordinates} are provided, the start position is the current one relatively to any previous hydration that occured or if none, + * defaults to the first scalar of this array. + * + * Example of usage: + *
{@code
+   *    NdArray array = NdArrays.ofObjects(String.class, Shape.of(3, 2), hydrator -> {
+   *        hydrator.byScalars()
+   *          .put("Cat")
+   *          .put("Dog")
+   *          .put("House")
+   *          .at(2, 1)
+   *          .put("Apple");
+   *    });
+   *    // -> [["Cat", "Dog"], ["House", null], [null, "Apple"]]
+   * }
+ * + * @param coordinates position in the hydrated array to start from + * @return a {@link Scalars} instance + * @throws IllegalArgumentException if {@code coordinates} are set but are not one of a scalar + */ + Scalars byScalars(long... coordinates); - > U byVectors(long... coordinates); + /** + * Start to hydrate the targeted array with vectors. + * + * If no {@code coordinates} are provided, the start position is the current one relatively to any previous hydration that occured or if none, + * defaults to the first scalar of the first vector of this array. + * + * Example of usage: + *
{@code
+   *    NdArray array = NdArrays.ofObjects(String.class, Shape.of(3, 2), hydrator -> {
+   *        hydrator.byVectors()
+   *          .put("Cat", "Dog")
+   *          .put("House")
+   *          .at(2)
+   *          .put("Orange", "Apple");
+   *    });
+   *    // -> [["Cat", "Dog"], ["House", null], ["Orange", "Apple"]]
+   * }
+ * + * @param coordinates position in the hydrated array to start from + * @return a {@link Vectors} instance + * @throws IllegalArgumentException if hydrated array is of rank-0 or if {@code coordinates} are set but are not one of a vector + */ + Vectors byVectors(long... coordinates); - > U byElements(long... coordinates); + /** + * Start to hydrate the targeted array with n-dimensional elements. + * + * If no {@code coordinates} are provided, the start position is the current one relatively to any previous hydration that occured or if none, + * defaults to the first element in the first (0) dimension of the hydrated array. + * + * Example of usage: + *
{@code
+   *    NdArray vector = NdArrays.vectorOfObjects("Cat", "Dog");
+   *    NdArray scalar = NdArrays.scalarOfObject("Apple");
+   *
+   *    NdArray array = NdArrays.ofObjects(String.class, Shape.of(4, 2), hydrator -> {
+   *        hydrator.byElements()
+   *          .put(vector)
+   *          .put(vector)
+   *          .at(2, 1)
+   *          .put(scalar)
+   *          .at(3)
+   *          .put(vector);
+   *    });
+   *    // -> [["Cat", "Dog"], ["Cat", "Dog"], [null, "Apple"], ["Cat", "Dog"]]
+   * }
+ * + * @param coordinates position in the hydrated array to start from + * @return a {@link Elements} instance + * @throws IllegalArgumentException if {@code coordinates} are set but are not one of an element of the hydrated array + */ + Elements byElements(long... coordinates); } diff --git a/ndarray/src/main/java/org/tensorflow/ndarray/impl/dense/AbstractDenseNdArray.java b/ndarray/src/main/java/org/tensorflow/ndarray/impl/dense/AbstractDenseNdArray.java index 9fc353c..9ef3749 100644 --- a/ndarray/src/main/java/org/tensorflow/ndarray/impl/dense/AbstractDenseNdArray.java +++ b/ndarray/src/main/java/org/tensorflow/ndarray/impl/dense/AbstractDenseNdArray.java @@ -33,6 +33,18 @@ public abstract class AbstractDenseNdArray> extends Abst abstract public DataBuffer buffer(); + public NdArraySequence elementsAt(long[] startCoords) { + DimensionalSpace elemDims = dimensions().from(startCoords.length); + try { + DataBufferWindow> elemWindow = buffer().window(elemDims.physicalSize()); + U element = instantiate(elemWindow.buffer(), elemDims); + return new FastElementSequence(this, startCoords, element, elemWindow); + } catch (UnsupportedOperationException e) { + // If buffer windows are not supported, fallback to slicing (and slower) sequence + return new SlicingElementSequence(this, startCoords, elemDims); + } + } + @Override public NdArraySequence elements(int dimensionIdx) { if (dimensionIdx >= shape().numDimensions()) { @@ -42,15 +54,7 @@ public NdArraySequence elements(int dimensionIdx) { if (rank() == 0 && dimensionIdx < 0) { return new SingleElementSequence<>(this); } - DimensionalSpace elemDims = dimensions().from(dimensionIdx + 1); - try { - DataBufferWindow> elemWindow = buffer().window(elemDims.physicalSize()); - U element = instantiate(elemWindow.buffer(), elemDims); - return new FastElementSequence(this, dimensionIdx, element, elemWindow); - } catch (UnsupportedOperationException e) { - // If buffer windows are not supported, fallback to slicing (and slower) sequence - return new SlicingElementSequence<>(this, dimensionIdx, elemDims); - } + return elementsAt(new long[dimensionIdx + 1]); } @Override diff --git a/ndarray/src/main/java/org/tensorflow/ndarray/impl/dense/hydrator/DenseNdArrayHydrator.java b/ndarray/src/main/java/org/tensorflow/ndarray/impl/dense/hydrator/DenseNdArrayHydrator.java index f764b91..d7bced0 100644 --- a/ndarray/src/main/java/org/tensorflow/ndarray/impl/dense/hydrator/DenseNdArrayHydrator.java +++ b/ndarray/src/main/java/org/tensorflow/ndarray/impl/dense/hydrator/DenseNdArrayHydrator.java @@ -1,17 +1,16 @@ package org.tensorflow.ndarray.impl.dense.hydrator; -import java.util.Arrays; +import java.util.Iterator; + import org.tensorflow.ndarray.NdArray; -import org.tensorflow.ndarray.buffer.DataBuffer; import org.tensorflow.ndarray.hydrator.NdArrayHydrator; import org.tensorflow.ndarray.impl.dense.AbstractDenseNdArray; -import org.tensorflow.ndarray.impl.sequence.CoordinatesIncrementor; import org.tensorflow.ndarray.impl.sequence.PositionIterator; -class DenseNdArrayHydrator implements NdArrayHydrator { +public class DenseNdArrayHydrator implements NdArrayHydrator { - public DenseNdArrayHydrator(AbstractDenseNdArray array) { - this.denseArray = array; + public DenseNdArrayHydrator(AbstractDenseNdArray> array) { + this.array = array; } @Override @@ -29,102 +28,74 @@ public Elements byElements(long... coordinates) { return new ElementsImpl(coordinates); } - protected class ScalarsImpl implements Scalars { + class ScalarsImpl implements Scalars { - @Override - public > U at(long... coordinates) { - if (coordinates == null || coordinates.length != denseArray.shape().numDimensions()) { - throw new IllegalArgumentException(Arrays.toString(coordinates) + " are not valid scalar coordinates for an array of shape " + denseArray - .shape()); - } - positionIterator = PositionIterator.create(denseArray.dimensions(), coordinates); - return (U) this; + public Scalars at(long... coordinates) { + positionIterator = Helpers.iterateByPosition(array, 0, coordinates); + return this; } @Override - public > U putObject(T scalar) { - buffer().setObject(scalar, positionIterator.next()); - return (U) this; + public Scalars put(T scalar) { + if (scalar == null) { + throw new IllegalArgumentException("Scalar value cannot be null"); + } + array.buffer().setObject(scalar, positionIterator.nextLong()); + return this; } - protected ScalarsImpl(long[] coords) { - if (coords == null || coords.length == 0) { - positionIterator = PositionIterator.create(denseArray.dimensions(), denseArray.shape().numDimensions() - 1); - } else { - at(coords); - } + ScalarsImpl(long[] coordinates) { + positionIterator = Helpers.iterateByPosition(array, 0, coordinates); } - protected PositionIterator positionIterator; + private PositionIterator positionIterator; } - protected class VectorsImpl implements Vectors { + class VectorsImpl implements Vectors { @Override - public > U at(long... coordinates) { - if (coordinates == null || coordinates.length != denseArray.shape().numDimensions() - 1) { - throw new IllegalArgumentException(Arrays.toString(coordinates) + " are not valid vector coordinates for an array of shape " + denseArray - .shape()); - } - positionIterator = PositionIterator.create(denseArray.dimensions(), coordinates); - return (U) this; + public Vectors at(long... coordinates) { + positionIterator = Helpers.iterateByPosition(array, 1, coordinates); + return this; } @Override - public > U putObjects(T... vector) { - if (vector == null || vector.length > denseArray.shape().get(-1)) { - throw new IllegalArgumentException("Vector should not be null nor exceed " + denseArray.shape().get(-1) + " elements"); - } - buffer().offset(positionIterator.next()).write(vector); - return (U) this; + public Vectors put(T... vector) { + Helpers.validateVectorLength(vector.length, array.shape()); + array.buffer().offset(positionIterator.nextLong()).write(vector); + return this; } - protected VectorsImpl(long[] coords) { - if (denseArray.shape().numDimensions() < 1) { - throw new IllegalArgumentException("Cannot hydrate a scalar with vectors"); - } - if (coords == null || coords.length == 0) { - positionIterator = PositionIterator.create(denseArray.dimensions(), denseArray.shape().numDimensions() - 2); - } else { - at(coords); - } + VectorsImpl(long[] coordinates) { + positionIterator = Helpers.iterateByPosition(array, 1, coordinates); } - protected PositionIterator positionIterator; + private PositionIterator positionIterator; } - protected class ElementsImpl implements Elements { + class ElementsImpl implements Elements { @Override - public > U at(long... coordinates) { - if (coordinates == null || coordinates.length == 0 || coordinates.length > denseArray.shape().numDimensions()) { - throw new IllegalArgumentException(Arrays.toString(coordinates) + " are not valid coordinates for an array of shape " + denseArray - .shape()); - } - this.coordinates = new CoordinatesIncrementor(denseArray.shape().asArray(), coordinates); - return (U) this; + public Elements at(long... coordinates) { + this.elementIterator = Helpers.iterateByElement(array, coordinates); + return this; } @Override - public > U put(NdArray array) { - array.copyTo(denseArray.get(coordinates.coords)); // FIXME use sequence instead? - return (U) this; + public Elements put(NdArray element) { + if (element == null) { + throw new IllegalArgumentException("Element cannot be null"); + } + element.copyTo(elementIterator.next()); + return this; } - protected ElementsImpl(long[] coords) { - if (coords == null || coords.length == 0) { - this.coordinates = new CoordinatesIncrementor(denseArray.shape().asArray(), 0); - } else { - at(coords); - } + ElementsImpl(long[] coordinates) { + this.elementIterator = Helpers.iterateByElement(array, coordinates); } - protected CoordinatesIncrementor coordinates; + private Iterator> elementIterator; } - protected final AbstractDenseNdArray denseArray; - - protected > U buffer() { - return (U) denseArray.buffer(); - } + private final AbstractDenseNdArray> array; } diff --git a/ndarray/src/main/java/org/tensorflow/ndarray/impl/dense/hydrator/DoubleDenseNdArrayHydrator.java b/ndarray/src/main/java/org/tensorflow/ndarray/impl/dense/hydrator/DoubleDenseNdArrayHydrator.java index 881ea64..087bb6f 100644 --- a/ndarray/src/main/java/org/tensorflow/ndarray/impl/dense/hydrator/DoubleDenseNdArrayHydrator.java +++ b/ndarray/src/main/java/org/tensorflow/ndarray/impl/dense/hydrator/DoubleDenseNdArrayHydrator.java @@ -1,66 +1,104 @@ package org.tensorflow.ndarray.impl.dense.hydrator; -import org.tensorflow.ndarray.buffer.DoubleDataBuffer; +import java.util.Iterator; + +import org.tensorflow.ndarray.DoubleNdArray; import org.tensorflow.ndarray.hydrator.DoubleNdArrayHydrator; +import org.tensorflow.ndarray.hydrator.NdArrayHydrator; import org.tensorflow.ndarray.impl.dense.DoubleDenseNdArray; +import org.tensorflow.ndarray.impl.sequence.PositionIterator; -public class DoubleDenseNdArrayHydrator extends DenseNdArrayHydrator implements DoubleNdArrayHydrator { +public class DoubleDenseNdArrayHydrator implements DoubleNdArrayHydrator { public DoubleDenseNdArrayHydrator(DoubleDenseNdArray array) { - super(array); + this.array = array; } @Override - public DoubleNdArrayHydrator.Scalars byScalars(long... coordinates) { + public Scalars byScalars(long... coordinates) { return new ScalarsImpl(coordinates); } @Override - public DoubleNdArrayHydrator.Vectors byVectors(long... coordinates) { + public Vectors byVectors(long... coordinates) { return new VectorsImpl(coordinates); } @Override - protected DoubleDataBuffer buffer() { - return super.buffer(); + public Elements byElements(long... coordinates) { + return new ElementsImpl(coordinates); + } + + @Override + public NdArrayHydrator boxed() { + return new DenseNdArrayHydrator(array); } - private class ScalarsImpl extends DenseNdArrayHydrator.ScalarsImpl implements DoubleNdArrayHydrator.Scalars { + class ScalarsImpl implements Scalars { + + public Scalars at(long... coordinates) { + positionIterator = Helpers.iterateByPosition(array, 0, coordinates); + return this; + } @Override - public DoubleNdArrayHydrator.Scalars at(long... coordinates) { - return super.at(coordinates); + public Scalars put(double scalar) { + array.buffer().setObject(scalar, positionIterator.nextLong()); + return this; + } + + ScalarsImpl(long[] coordinates) { + positionIterator = Helpers.iterateByPosition(array, 0, coordinates); + } + + private PositionIterator positionIterator; + } + + class VectorsImpl implements Vectors { + + @Override + public Vectors at(long... coordinates) { + positionIterator = Helpers.iterateByPosition(array, 1, coordinates); + return this; } @Override - public DoubleNdArrayHydrator.Scalars put(double scalar) { - buffer().setDouble(scalar, positionIterator.next()); + public Vectors put(double... vector) { + Helpers.validateVectorLength(vector.length, array.shape()); + array.buffer().offset(positionIterator.nextLong()).write(vector); return this; } - private ScalarsImpl(long[] coords) { - super(coords); + VectorsImpl(long[] coordinates) { + positionIterator = Helpers.iterateByPosition(array, 1, coordinates); } + + private PositionIterator positionIterator; } - private class VectorsImpl extends DenseNdArrayHydrator.VectorsImpl implements DoubleNdArrayHydrator.Vectors { + class ElementsImpl implements Elements { @Override - public DoubleNdArrayHydrator.Vectors at(long... coordinates) { - return super.at(coordinates); + public Elements at(long... coordinates) { + this.elementIterator = Helpers.iterateByElement(array, coordinates); + return this; } @Override - public DoubleNdArrayHydrator.Vectors put(double... vector) { - if (vector == null || vector.length > denseArray.shape().get(-1)) { - throw new IllegalArgumentException("Vector should not be null nor exceed " + denseArray.shape().get(-1) + " elements"); + public Elements put(DoubleNdArray element) { + if (element == null) { + throw new IllegalArgumentException("Element cannot be null"); } - buffer().offset(positionIterator.next()).write(vector); + element.copyTo(elementIterator.next()); return this; } - private VectorsImpl(long[] coords) { - super(coords); + ElementsImpl(long[] coordinates) { + this.elementIterator = Helpers.iterateByElement(array, coordinates); } + + private Iterator elementIterator; } + + private final DoubleDenseNdArray array; } diff --git a/ndarray/src/main/java/org/tensorflow/ndarray/impl/dense/hydrator/Helpers.java b/ndarray/src/main/java/org/tensorflow/ndarray/impl/dense/hydrator/Helpers.java new file mode 100644 index 0000000..5c503c9 --- /dev/null +++ b/ndarray/src/main/java/org/tensorflow/ndarray/impl/dense/hydrator/Helpers.java @@ -0,0 +1,52 @@ +package org.tensorflow.ndarray.impl.dense.hydrator; + +import java.util.Arrays; +import java.util.Iterator; + +import org.tensorflow.ndarray.NdArray; +import org.tensorflow.ndarray.NdArraySequence; +import org.tensorflow.ndarray.Shape; +import org.tensorflow.ndarray.impl.dense.AbstractDenseNdArray; +import org.tensorflow.ndarray.impl.dimension.DimensionalSpace; +import org.tensorflow.ndarray.impl.sequence.IndexedPositionIterator; +import org.tensorflow.ndarray.impl.sequence.PositionIterator; + +final class Helpers { + + static PositionIterator iterateByPosition(AbstractDenseNdArray array, int elementRank, long[] coords) { + DimensionalSpace dimensions = array.dimensions(); + int dimensionIdx = dimensions.numDimensions() - elementRank - 1; + if (dimensionIdx < 0) { + throw new IllegalArgumentException("Cannot hydrate array of shape " + array.shape() + " with elements of rank " + elementRank); + } + if (coords == null || coords.length == 0) { + return PositionIterator.create(dimensions, dimensionIdx); + } + if ((coords.length - 1) != dimensionIdx) { + throw new IllegalArgumentException(Arrays.toString(coords) + " are not valid coordinates for dimension " + + dimensionIdx + " in an array of shape " + dimensions.shape()); + } + return PositionIterator.create(dimensions, coords); + } + + static > Iterator iterateByElement(AbstractDenseNdArray array, long[] coords) { + DimensionalSpace dimensions = array.dimensions(); + int dimensionIdx; + if (coords == null || coords.length == 0) { + return array.elements(0).iterator(); + } + if (coords.length > dimensions.numDimensions()) { + throw new IllegalArgumentException(Arrays.toString(coords) + " are not valid coordinates for an array of shape " + dimensions.shape()); + } + return array.elementsAt(coords).iterator(); + } + + static void validateVectorLength(int length, Shape shape) { + if (length == 0) { + throw new IllegalArgumentException("Vector cannot be empty"); + } + if (length > shape.get(-1)) { + throw new IllegalArgumentException("Vector cannot exceed " + shape.get(-1) + " elements"); + } + } +} diff --git a/ndarray/src/main/java/org/tensorflow/ndarray/impl/sequence/FastElementSequence.java b/ndarray/src/main/java/org/tensorflow/ndarray/impl/sequence/FastElementSequence.java index 92cebeb..2430030 100644 --- a/ndarray/src/main/java/org/tensorflow/ndarray/impl/sequence/FastElementSequence.java +++ b/ndarray/src/main/java/org/tensorflow/ndarray/impl/sequence/FastElementSequence.java @@ -34,8 +34,12 @@ public final class FastElementSequence> implements NdArraySequence { public FastElementSequence(AbstractNdArray ndArray, int dimensionIdx, U element, DataBufferWindow elementWindow) { + this(ndArray, new long[dimensionIdx + 1], element, elementWindow); + } + + public FastElementSequence(AbstractNdArray ndArray, long[] startCoords, U element, DataBufferWindow elementWindow) { this.ndArray = ndArray; - this.dimensionIdx = dimensionIdx; + this.startCoords = startCoords; this.element = element; this.elementWindow = elementWindow; } @@ -47,7 +51,7 @@ public Iterator iterator() { @Override public void forEachIndexed(BiConsumer consumer) { - PositionIterator.createIndexed(ndArray.dimensions(), dimensionIdx).forEachIndexed((long[] coords, long position) -> { + PositionIterator.createIndexed(ndArray.dimensions(), startCoords).forEachIndexed((long[] coords, long position) -> { elementWindow.slideTo(position); consumer.accept(coords, element); }); @@ -55,7 +59,7 @@ public void forEachIndexed(BiConsumer consumer) { @Override public NdArraySequence asSlices() { - return new SlicingElementSequence(ndArray, dimensionIdx); + return new SlicingElementSequence(ndArray, startCoords); } private class SequenceIterator implements Iterator { @@ -71,11 +75,11 @@ public U next() { return element; } - private final PositionIterator positionIterator = PositionIterator.create(ndArray.dimensions(), dimensionIdx); + private final PositionIterator positionIterator = PositionIterator.create(ndArray.dimensions(), startCoords); } private final AbstractNdArray ndArray; - private final int dimensionIdx; + private final long[] startCoords; private final U element; private final DataBufferWindow elementWindow; } diff --git a/ndarray/src/main/java/org/tensorflow/ndarray/impl/sequence/IndexedSequentialPositionIterator.java b/ndarray/src/main/java/org/tensorflow/ndarray/impl/sequence/IndexedSequentialPositionIterator.java index 4bb84ed..c7a3dcf 100644 --- a/ndarray/src/main/java/org/tensorflow/ndarray/impl/sequence/IndexedSequentialPositionIterator.java +++ b/ndarray/src/main/java/org/tensorflow/ndarray/impl/sequence/IndexedSequentialPositionIterator.java @@ -17,6 +17,8 @@ package org.tensorflow.ndarray.impl.sequence; +import java.util.Arrays; + import org.tensorflow.ndarray.impl.dimension.DimensionalSpace; class IndexedSequentialPositionIterator extends SequentialPositionIterator implements IndexedPositionIterator { @@ -24,27 +26,28 @@ class IndexedSequentialPositionIterator extends SequentialPositionIterator imple @Override public void forEachIndexed(CoordsLongConsumer consumer) { while (hasNext()) { - consumer.consume(coords, nextLong()); - incrementCoords(); + consumer.consume(coords, super.nextLong()); + dimensions.incrementCoordinates(coords); } } - private void incrementCoords() { - for (int i = coords.length - 1; i >= 0; --i) { - if (coords[i] < shape[i] - 1) { - coords[i] += 1L; - return; - } - coords[i] = 0L; - } + @Override + public long nextLong() { + long tmp = super.nextLong(); + dimensions.incrementCoordinates(coords); + return tmp; } IndexedSequentialPositionIterator(DimensionalSpace dimensions, int dimensionIdx) { - super(dimensions, dimensionIdx); - this.shape = dimensions.shape().asArray(); - this.coords = new long[dimensionIdx + 1]; + this(dimensions, new long[dimensionIdx + 1]); + } + + IndexedSequentialPositionIterator(DimensionalSpace dimensions, long[] coords) { + super(dimensions, coords); + this.dimensions = dimensions; + this.coords = Arrays.copyOf(coords, coords.length); } - private final long[] shape; - private long[] coords; + private final DimensionalSpace dimensions; + private final long[] coords; } diff --git a/ndarray/src/main/java/org/tensorflow/ndarray/impl/sequence/NdPositionIterator.java b/ndarray/src/main/java/org/tensorflow/ndarray/impl/sequence/NdPositionIterator.java index dbbd59b..7dc2f96 100644 --- a/ndarray/src/main/java/org/tensorflow/ndarray/impl/sequence/NdPositionIterator.java +++ b/ndarray/src/main/java/org/tensorflow/ndarray/impl/sequence/NdPositionIterator.java @@ -17,6 +17,7 @@ package org.tensorflow.ndarray.impl.sequence; +import java.util.Arrays; import java.util.NoSuchElementException; import org.tensorflow.ndarray.impl.dimension.DimensionalSpace; @@ -33,7 +34,7 @@ public long nextLong() { throw new NoSuchElementException(); } long position = dimensions.positionOf(coords); - increment(); + incrementCoords(); return position; } @@ -41,33 +42,23 @@ public long nextLong() { public void forEachIndexed(CoordsLongConsumer consumer) { while (hasNext()) { consumer.consume(coords, dimensions.positionOf(coords)); - increment(); + incrementCoords(); } } - private void increment() { - if (!increment(coords, dimensions)) { + private void incrementCoords() { + if (!dimensions.incrementCoordinates(coords)) { coords = null; } } - static boolean increment(long[] coords, DimensionalSpace dimensions) { - for (int i = coords.length - 1; i >= 0; --i) { - if ((coords[i] = (coords[i] + 1) % dimensions.get(i).numElements()) > 0) { - return true; - } - } - return false; - } - NdPositionIterator(DimensionalSpace dimensions, int dimensionIdx) { - this.dimensions = dimensions; - this.coords = new long[dimensionIdx + 1]; + this(dimensions, new long[dimensionIdx + 1]); } NdPositionIterator(DimensionalSpace dimensions, long[] coords) { this.dimensions = dimensions; - this.coords = coords; + this.coords = Arrays.copyOf(coords, coords.length); } private final DimensionalSpace dimensions; diff --git a/ndarray/src/main/java/org/tensorflow/ndarray/impl/sequence/PositionIterator.java b/ndarray/src/main/java/org/tensorflow/ndarray/impl/sequence/PositionIterator.java index 3a7bd77..a30c31c 100644 --- a/ndarray/src/main/java/org/tensorflow/ndarray/impl/sequence/PositionIterator.java +++ b/ndarray/src/main/java/org/tensorflow/ndarray/impl/sequence/PositionIterator.java @@ -31,9 +31,6 @@ static PositionIterator create(DimensionalSpace dimensions, int dimensionIdx) { } static PositionIterator create(DimensionalSpace dimensions, long... startCoords) { - if (startCoords == null) { - throw new IllegalArgumentException(); - } if (dimensions.isSegmented()) { return new NdPositionIterator(dimensions, startCoords); } @@ -47,6 +44,13 @@ static IndexedPositionIterator createIndexed(DimensionalSpace dimensions, int di return new IndexedSequentialPositionIterator(dimensions, dimensionIdx); } + static IndexedPositionIterator createIndexed(DimensionalSpace dimensions, long... startCoords) { + if (dimensions.isSegmented()) { + return new NdPositionIterator(dimensions, startCoords); + } + return new IndexedSequentialPositionIterator(dimensions, startCoords); + } + static PositionIterator sequence(long stride, long end) { return new SequentialPositionIterator(stride, end); } diff --git a/ndarray/src/main/java/org/tensorflow/ndarray/impl/sequence/SlicingElementSequence.java b/ndarray/src/main/java/org/tensorflow/ndarray/impl/sequence/SlicingElementSequence.java index 6fe8398..6d11968 100644 --- a/ndarray/src/main/java/org/tensorflow/ndarray/impl/sequence/SlicingElementSequence.java +++ b/ndarray/src/main/java/org/tensorflow/ndarray/impl/sequence/SlicingElementSequence.java @@ -33,18 +33,26 @@ public final class SlicingElementSequence> implements NdArraySequence { public SlicingElementSequence(AbstractNdArray ndArray, int dimensionIdx) { - this(ndArray, dimensionIdx, ndArray.dimensions().from(dimensionIdx + 1)); + this(ndArray, new long[dimensionIdx + 1]); + } + + public SlicingElementSequence(AbstractNdArray ndArray, long[] startCoords) { + this(ndArray, startCoords, ndArray.dimensions().from(startCoords.length)); } public SlicingElementSequence(AbstractNdArray ndArray, int dimensionIdx, DimensionalSpace elementDimensions) { + this(ndArray, new long[dimensionIdx + 1], elementDimensions); + } + + public SlicingElementSequence(AbstractNdArray ndArray, long[] startCoords, DimensionalSpace elementDimensions) { this.ndArray = ndArray; - this.dimensionIdx = dimensionIdx; + this.startCoords = startCoords; this.elementDimensions = elementDimensions; } @Override public Iterator iterator() { - PositionIterator positionIterator = PositionIterator.create(ndArray.dimensions(), dimensionIdx); + PositionIterator positionIterator = PositionIterator.create(ndArray.dimensions(), startCoords); return new Iterator() { @Override @@ -61,7 +69,7 @@ public U next() { @Override public void forEachIndexed(BiConsumer consumer) { - PositionIterator.createIndexed(ndArray.dimensions(), dimensionIdx).forEachIndexed((long[] coords, long position) -> + PositionIterator.createIndexed(ndArray.dimensions(), startCoords).forEachIndexed((long[] coords, long position) -> consumer.accept(coords, ndArray.slice(position, elementDimensions)) ); } @@ -72,6 +80,6 @@ public NdArraySequence asSlices() { } private final AbstractNdArray ndArray; - private final int dimensionIdx; + private final long[] startCoords; private final DimensionalSpace elementDimensions; } diff --git a/ndarray/src/main/java/org/tensorflow/ndarray/impl/sparse/hydrator/DoubleSparseNdArrayHydrator.java b/ndarray/src/main/java/org/tensorflow/ndarray/impl/sparse/hydrator/DoubleSparseNdArrayHydrator.java index fb56fc4..f99631c 100644 --- a/ndarray/src/main/java/org/tensorflow/ndarray/impl/sparse/hydrator/DoubleSparseNdArrayHydrator.java +++ b/ndarray/src/main/java/org/tensorflow/ndarray/impl/sparse/hydrator/DoubleSparseNdArrayHydrator.java @@ -1,75 +1,126 @@ package org.tensorflow.ndarray.impl.sparse.hydrator; +import org.tensorflow.ndarray.DoubleNdArray; import org.tensorflow.ndarray.NdArrays; +import org.tensorflow.ndarray.Shape; import org.tensorflow.ndarray.hydrator.DoubleNdArrayHydrator; +import org.tensorflow.ndarray.hydrator.NdArrayHydrator; import org.tensorflow.ndarray.impl.sparse.DoubleSparseNdArray; -public class DoubleSparseNdArrayHydrator extends SparseNdArrayHydrator implements DoubleNdArrayHydrator { +public class DoubleSparseNdArrayHydrator implements DoubleNdArrayHydrator { public DoubleSparseNdArrayHydrator(DoubleSparseNdArray array) { - super(array); + this.array = array; } @Override - public DoubleNdArrayHydrator.Scalars byScalars(long... coordinates) { + public Scalars byScalars(long... coordinates) { return new ScalarsImpl(coordinates); } @Override - public DoubleNdArrayHydrator.Vectors byVectors(long... coordinates) { + public Vectors byVectors(long... coordinates) { return new VectorsImpl(coordinates); } - private class ScalarsImpl extends SparseNdArrayHydrator.ScalarsImpl implements DoubleNdArrayHydrator.Scalars { + @Override + public Elements byElements(long... coordinates) { + return new ElementsImpl(coordinates); + } + + @Override + public NdArrayHydrator boxed() { + return new SparseNdArrayHydrator(array); + } + + private class ScalarsImpl implements Scalars { @Override - public DoubleNdArrayHydrator.Scalars at(long... coordinates) { - return super.at(coordinates); + public Scalars at(long... coordinates) { + this.coordinates = Helpers.validateCoordinates(array, coordinates, 0); + return this; } @Override - public DoubleNdArrayHydrator.Scalars put(double scalar) { - sparseArray().getValues().setDouble(scalar, index); - sparseArray().getIndices().set(NdArrays.vectorOf(coordinates.coords), index++); - coordinates.increment(); + public Scalars put(double scalar) { + addValue(scalar, coordinates); + array.dimensions().incrementCoordinates(coordinates); return this; } - private ScalarsImpl(long[] coords) { - super(coords); + private long[] coordinates; + + private ScalarsImpl(long[] coordinates) { + this.coordinates = Helpers.validateCoordinates(array, coordinates, 0); } } - private class VectorsImpl extends SparseNdArrayHydrator.VectorsImpl implements DoubleNdArrayHydrator.Vectors { + private class VectorsImpl implements Vectors { @Override - public DoubleNdArrayHydrator.Vectors at(long... coordinates) { - return super.at(coordinates); + public Vectors at(long... coordinates) { + this.coordinates = Helpers.validateCoordinates(array, coordinates, 1); + return this; } @Override - public DoubleNdArrayHydrator.Vectors put(double... vector) { - if (vector == null || vector.length > sparseArray().shape().get(-1)) { - throw new IllegalArgumentException("Vector should not be null nor exceed " + sparseArray().shape().get(-1) + " elements"); + public Vectors put(double... vector) { + if (vector.length == 0 || vector.length > array.shape().get(-1)) { + throw new IllegalArgumentException("Vector cannot be null nor exceed " + array.shape().get(-1) + " elements"); } - double defaultValue = sparseArray().getDefaultValue(); - for (double value : vector) { - if (value != defaultValue) { - sparseArray().getValues().setDouble(value, index); - sparseArray().getIndices().set(NdArrays.vectorOf(coordinates.coords), index++); - } - coordinates.increment(); + for (int i = 0; i < vector.length; ++i) { + addValue(vector[i], coordinates, i); } + array.dimensions().incrementCoordinates(coordinates); return this; } - private VectorsImpl(long[] coords) { - super(coords); + private long[] coordinates; + + private VectorsImpl(long[] coordinates) { + this.coordinates = Helpers.validateCoordinates(array, coordinates, 1); } } - @Override - protected DoubleSparseNdArray sparseArray() { - return super.sparseArray(); + private class ElementsImpl implements Elements { + + @Override + public Elements at(long... coordinates) { + this.coordinates = Helpers.validateCoordinates(array, coordinates); + return this; + } + + @Override + public Elements put(DoubleNdArray element) { + if (element == null) { + throw new IllegalArgumentException("Element cannot be null"); + } + if (element.shape().isScalar()) { + addValue(element.getDouble(), coordinates); + } else { + element.scalars().forEachIndexed((scalarCoords, scalar) -> { + addValue(scalar.getDouble(), coordinates, scalarCoords); + }); + } + array.dimensions().incrementCoordinates(coordinates); + return this; + } + + private long[] coordinates; + + private ElementsImpl(long[] coordinates) { + this.coordinates = Helpers.validateCoordinates(array, coordinates); + } + } + + private final DoubleSparseNdArray array; + private long valueCount = 0; + + private void addValue(double value, long[] origin, long... coords) { + if (value != array.getDefaultValue()) { + array.getValues().setDouble(value, valueCount); + Helpers.writeValueCoords(array, valueCount, origin, coords); + ++valueCount; + } } } diff --git a/ndarray/src/main/java/org/tensorflow/ndarray/impl/sparse/hydrator/Helpers.java b/ndarray/src/main/java/org/tensorflow/ndarray/impl/sparse/hydrator/Helpers.java new file mode 100644 index 0000000..440855f --- /dev/null +++ b/ndarray/src/main/java/org/tensorflow/ndarray/impl/sparse/hydrator/Helpers.java @@ -0,0 +1,49 @@ +package org.tensorflow.ndarray.impl.sparse.hydrator; + +import java.util.Arrays; + +import org.tensorflow.ndarray.impl.dimension.DimensionalSpace; +import org.tensorflow.ndarray.impl.sparse.AbstractSparseNdArray; + +final class Helpers { + + static long[] validateCoordinates(AbstractSparseNdArray array, long[] coords, int elementRank) { + DimensionalSpace dimensions = array.dimensions(); + int dimensionIdx = 0; + if (elementRank >= 0) { + dimensionIdx = dimensions.numDimensions() - elementRank - 1; + if (dimensionIdx < 0) { + throw new IllegalArgumentException("Cannot hydrate array of shape " + array.shape() + " with elements of rank " + elementRank); + } + } + if (coords == null || coords.length == 0) { + return new long[dimensionIdx + 1]; + } + if ((coords.length - 1) != dimensionIdx) { + throw new IllegalArgumentException(Arrays.toString(coords) + " are not valid coordinates for dimension " + + dimensionIdx + " in an array of shape " + dimensions.shape()); + } + return Arrays.copyOf(coords, coords.length); + } + + static long[] validateCoordinates(AbstractSparseNdArray array, long[] coords) { + if (coords == null || coords.length == 0) { + return new long[1]; + } + int dimensionIdx = array.shape().numDimensions() - coords.length; + if (dimensionIdx < 0) { + throw new IllegalArgumentException("Cannot hydrate array of shape " + array.shape() + " with elements of rank " + (coords.length - 1)); + } + return Arrays.copyOf(coords, coords.length); + } + + static void writeValueCoords(AbstractSparseNdArray array, long valueIndex, long[] origin, long[] coords) { + int coordsIndex = 0; + for (long c: origin) { + array.getIndices().setLong(c, valueIndex, coordsIndex++); + } + for (long c: coords) { + array.getIndices().setLong(c, valueIndex, coordsIndex++); + } + } +} diff --git a/ndarray/src/main/java/org/tensorflow/ndarray/impl/sparse/hydrator/SparseNdArrayHydrator.java b/ndarray/src/main/java/org/tensorflow/ndarray/impl/sparse/hydrator/SparseNdArrayHydrator.java index aaf25df..5d183f5 100644 --- a/ndarray/src/main/java/org/tensorflow/ndarray/impl/sparse/hydrator/SparseNdArrayHydrator.java +++ b/ndarray/src/main/java/org/tensorflow/ndarray/impl/sparse/hydrator/SparseNdArrayHydrator.java @@ -1,16 +1,14 @@ package org.tensorflow.ndarray.impl.sparse.hydrator; -import java.util.Arrays; import org.tensorflow.ndarray.NdArray; import org.tensorflow.ndarray.NdArrays; import org.tensorflow.ndarray.hydrator.NdArrayHydrator; -import org.tensorflow.ndarray.impl.sequence.CoordinatesIncrementor; import org.tensorflow.ndarray.impl.sparse.AbstractSparseNdArray; -class SparseNdArrayHydrator implements NdArrayHydrator { +public class SparseNdArrayHydrator implements NdArrayHydrator { public SparseNdArrayHydrator(AbstractSparseNdArray array) { - this.sparseArray = array; + this.array = array; } @Override @@ -28,119 +26,95 @@ public Elements byElements(long... coordinates) { return new ElementsImpl(coordinates); } - protected class ScalarsImpl implements Scalars { + private class ScalarsImpl implements Scalars { @Override - public > U at(long... coordinates) { - if (coordinates == null || coordinates.length != sparseArray.shape().numDimensions()) { - throw new IllegalArgumentException(Arrays.toString(coordinates) + " are not valid scalar coordinates for an array of shape " + sparseArray - .shape()); - } - this.coordinates = new CoordinatesIncrementor(sparseArray.shape().asArray(), coordinates); - return (U) this; + public Scalars at(long... coordinates) { + this.coordinates = Helpers.validateCoordinates(array, coordinates, 0); + return this; } @Override - public > U putObject(T scalar) { - sparseArray.getValues().setObject(scalar, index); - sparseArray.getIndices().set(NdArrays.vectorOf(coordinates.coords), index++); - coordinates.increment(); - return (U) this; + public Scalars put(T scalar) { + if (scalar == null) { + throw new IllegalArgumentException("Scalar cannot be null"); + } + if (scalar != array.getDefaultValue()) { + array.getValues().setObject(scalar, index); + array.getIndices().set(NdArrays.vectorOf(coordinates), index++); + } + array.dimensions().incrementCoordinates(coordinates); + return this; } - protected ScalarsImpl(long[] coords) { - if (coords == null || coords.length == 0) { - coordinates = new CoordinatesIncrementor(sparseArray.shape().asArray(), sparseArray.shape().numDimensions() - 1); - } else { - at(coords); - } + protected ScalarsImpl(long[] coordinates) { + this.coordinates = Helpers.validateCoordinates(array, coordinates, 0); } - protected CoordinatesIncrementor coordinates; + protected long[] coordinates; } - protected class VectorsImpl implements Vectors { + private class VectorsImpl implements Vectors { @Override - public > U at(long... coordinates) { - if (coordinates == null || coordinates.length != sparseArray.shape().numDimensions() - 1) { - throw new IllegalArgumentException(Arrays.toString(coordinates) + " are not valid vector coordinates for an array of shape " + sparseArray - .shape()); - } - this.coordinates = new CoordinatesIncrementor(sparseArray.shape().asArray(), Arrays.copyOf(coordinates, sparseArray.shape().numDimensions())); - return (U) this; + public Vectors at(long... coordinates) { + this.coordinates = Helpers.validateCoordinates(array, coordinates, 1); + return this; } @Override - public > U putObjects(T... vector) { - if (vector == null || vector.length > sparseArray.shape().get(-1)) { - throw new IllegalArgumentException("Vector should not be null nor exceed " + sparseArray.shape().get(-1) + " elements"); + public Vectors put(T... vector) { + if (vector.length == 0 || vector.length > array.shape().get(-1)) { + throw new IllegalArgumentException("Vector cannot be null nor exceed " + array.shape().get(-1) + " elements"); } for (T value : vector) { - if (value != sparseArray.getDefaultValue()) { - sparseArray.getValues().setObject(value, index); - sparseArray.getIndices().set(NdArrays.vectorOf(coordinates.coords), index++); + if (value != array.getDefaultValue()) { + array.getValues().setObject(value, index); + array.getIndices().set(NdArrays.vectorOf(coordinates), index++); } - coordinates.increment(); + array.dimensions().incrementCoordinates(coordinates); } - return (U) this; + return this; } - protected VectorsImpl(long[] coords) { - if (sparseArray.shape().numDimensions() < 1) { - throw new IllegalArgumentException("Cannot hydrate a scalar with vectors"); - } - if (coords == null || coords.length == 0) { - coordinates = new CoordinatesIncrementor(sparseArray.shape().asArray(), sparseArray.shape().numDimensions() - 1); - } else { - at(coords); - } + protected VectorsImpl(long[] coordinates) { + this.coordinates = Helpers.validateCoordinates(array, coordinates, 0); } - protected CoordinatesIncrementor coordinates; + protected long[] coordinates; } - protected class ElementsImpl implements Elements { + private class ElementsImpl implements Elements { @Override - public > U at(long... coordinates) { - if (coordinates == null || coordinates.length == 0 || coordinates.length > sparseArray.shape().numDimensions()) { - throw new IllegalArgumentException(Arrays.toString(coordinates) + " are not valid coordinates for an array of shape " + sparseArray - .shape()); - } - this.coordinates = new CoordinatesIncrementor(sparseArray.shape().asArray(), Arrays.copyOf(coordinates, sparseArray.shape().numDimensions())); - return (U) this; + public Elements at(long... coordinates) { + this.coordinates = Helpers.validateCoordinates(array, coordinates, coordinates.length - 1); + return this; } @Override - public > U put(NdArray array) { - array.scalars().forEach(s -> { + public Elements put(NdArray element) { + if (element == null) { + throw new IllegalArgumentException("Array cannot be null"); + } + element.scalars().forEach(s -> { T value = s.getObject(); - if (value != sparseArray.getDefaultValue()) { - sparseArray.getValues().setObject(value, index); - sparseArray.getIndices().set(NdArrays.vectorOf(coordinates.coords), index++); + if (value != array.getDefaultValue()) { + array.getValues().setObject(value, index); + array.getIndices().set(NdArrays.vectorOf(coordinates), index++); } - coordinates.increment(); + array.dimensions().incrementCoordinates(coordinates); }); - return (U) this; + return this; } - protected ElementsImpl(long[] coords) { - if (coords == null || coords.length == 0) { - this.coordinates = new CoordinatesIncrementor(sparseArray.shape().asArray(), sparseArray.shape().numDimensions() - 1); - } else { - at(coords); - } + protected ElementsImpl(long[] coordinates) { + this.coordinates = Helpers.validateCoordinates(array, coordinates, coordinates.length - 1); } - protected CoordinatesIncrementor coordinates; - } - - protected long index = 0; - - protected > U sparseArray() { - return (U) sparseArray; + protected long[] coordinates; } - private final AbstractSparseNdArray sparseArray; + private final AbstractSparseNdArray array; + private long index = 0; } diff --git a/ndarray/src/test/java/org/tensorflow/ndarray/hydrator/DoubleNdArrayHydratorTestBase.java b/ndarray/src/test/java/org/tensorflow/ndarray/hydrator/DoubleNdArrayHydratorTestBase.java new file mode 100644 index 0000000..69b38f9 --- /dev/null +++ b/ndarray/src/test/java/org/tensorflow/ndarray/hydrator/DoubleNdArrayHydratorTestBase.java @@ -0,0 +1,141 @@ +package org.tensorflow.ndarray.hydrator; + +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.fail; + +import java.util.function.Consumer; + +import org.junit.jupiter.api.Test; +import org.tensorflow.ndarray.DoubleNdArray; +import org.tensorflow.ndarray.NdArrays; +import org.tensorflow.ndarray.Shape; +import org.tensorflow.ndarray.StdArrays; + +public abstract class DoubleNdArrayHydratorTestBase { + + protected abstract DoubleNdArray newArray(Shape shape, long numValues, Consumer hydrate); + + @Test + public void hydrateNdArrayByScalars() { + DoubleNdArray array = newArray(Shape.of(3, 2, 3), 14, hydrator -> { + hydrator + .byScalars() + .put(0.0) + .put(0.1) + .put(0.2) + .put(0.3) + .put(0.4) + .put(0.5) + .put(1.0) + .put(1.1) + .put(1.2) + .at(2, 0, 0) + .put(2.0) + .put(2.1) + .put(2.2) + .put(2.3) + .put(2.4) + .put(2.5); + }); + + assertEquals(StdArrays.ndCopyOf(new double[][][]{ + {{0.0, 0.1, 0.2}, {0.3, 0.4, 0.5}}, + {{1.0, 1.1, 1.2}, {0.0, 0.0, 0.0}}, + {{2.0, 2.1, 2.2}, {2.3, 2.4, 2.5}} + }), array); + + array = newArray(Shape.of(3, 2), 4, hydrator -> { + hydrator + .byScalars() + .put(10.0) + .put(20.0) + .put(30.0) + .at(2, 1) + .put(40.0); + }); + + assertEquals(StdArrays.ndCopyOf(new double[][]{{10.0, 20.0}, {30.0, 0.0}, {0.0, 40.0}}), array); + } + + @Test + public void hydrateNdArrayByVectors() { + DoubleNdArray array = newArray(Shape.of(3, 2, 3), 14, hydrator -> { + hydrator + .byVectors() + .put(0.0, 0.1, 0.2) + .put(0.3, 0.4, 0.5) + .put(1.0, 1.1, 1.2) + .at(2, 0) + .put(2.0, 2.1, 2.2) + .put(2.3, 2.4, 2.5); + }); + + assertEquals(StdArrays.ndCopyOf(new double[][][]{ + {{0.0, 0.1, 0.2}, {0.3, 0.4, 0.5}}, + {{1.0, 1.1, 1.2}, {0.0, 0.0, 0.0}}, + {{2.0, 2.1, 2.2}, {2.3, 2.4, 2.5}} + }), array); + + array = newArray(Shape.of(3, 2), 5, hydrator -> { + hydrator + .byVectors() + .put(10.0, 20.0) + .put(30.0) + .at(2) + .put(40.0, 50.0); + }); + + assertEquals(StdArrays.ndCopyOf(new double[][]{{10.0, 20.0}, {30.0, 0.0}, {40.0, 50.0}}), array); + } + + @Test + public void vectorCannotBeEmpty() { + try { + newArray(Shape.of(3, 2), 1, hydrator -> hydrator.byVectors().put()); + fail(); + } catch (IllegalArgumentException e) { + // ok + } + } + + @Test + public void hydrateNdArrayByElements() { + DoubleNdArray array = newArray(Shape.of(3, 2, 3), 14, hydrator -> { + hydrator + .byElements() + .put(StdArrays.ndCopyOf(new double[][]{ + {0.0, 0.1, 0.2}, + {0.3, 0.4, 0.5} + })) + .at(1, 0) + .put(NdArrays.vectorOf(1.0, 1.1, 1.2)) + .at(2) + .put(StdArrays.ndCopyOf(new double[][]{ + {2.0, 2.1, 2.2}, + {2.3, 2.4, 2.5} + })); + }); + + assertEquals(StdArrays.ndCopyOf(new double[][][]{ + {{0.0, 0.1, 0.2}, {0.3, 0.4, 0.5}}, + {{1.0, 1.1, 1.2}, {0.0, 0.0, 0.0}}, + {{2.0, 2.1, 2.2}, {2.3, 2.4, 2.5}} + }), array); + + DoubleNdArray vector = NdArrays.vectorOf(10.0, 20.0); + DoubleNdArray scalar = NdArrays.scalarOf(30.0); + + array = newArray(Shape.of(4, 2), 7, hydrator -> { + hydrator + .byElements() + .put(vector) + .put(vector) + .at(2, 1) + .put(scalar) + .at(3) + .put(vector); + }); + + assertEquals(StdArrays.ndCopyOf(new double[][]{{10.0, 20.0}, {10.0, 20.0}, {0.0, 30.0}, {10.0, 20.0}}), array); + } +} diff --git a/ndarray/src/test/java/org/tensorflow/ndarray/impl/dense/hydrator/DenseNdArrayHydratorTest.java b/ndarray/src/test/java/org/tensorflow/ndarray/impl/dense/hydrator/DenseNdArrayHydratorTest.java deleted file mode 100644 index ee4df4a..0000000 --- a/ndarray/src/test/java/org/tensorflow/ndarray/impl/dense/hydrator/DenseNdArrayHydratorTest.java +++ /dev/null @@ -1,87 +0,0 @@ -package org.tensorflow.ndarray.impl.dense.hydrator; - -import static org.junit.jupiter.api.Assertions.assertEquals; - -import org.junit.jupiter.api.Test; -import org.tensorflow.ndarray.DoubleNdArray; -import org.tensorflow.ndarray.NdArrays; -import org.tensorflow.ndarray.Shape; -import org.tensorflow.ndarray.StdArrays; -import org.tensorflow.ndarray.hydrator.DoubleNdArrayHydrator; -import org.tensorflow.ndarray.impl.dense.DoubleDenseNdArray; - -public class DenseNdArrayHydratorTest { - - @Test - public void hydrateNdArrayByScalars() { - DoubleNdArray array = NdArrays.ofDoubles(Shape.of(3, 2, 3), hydrator -> { - hydrator - .byScalars() - .put(0.0) - .put(0.1) - .put(0.2) - .put(0.3) - .put(0.4) - .put(0.5) - .put(1.0) - .put(1.1) - .put(1.2) - .at(2, 0, 0) - .put(2.0) - .put(2.1) - .put(2.2) - .put(2.3) - .put(2.4) - .put(2.5); - }); - - assertEquals(StdArrays.ndCopyOf(new double[][][] { - {{ 0.0, 0.1, 0.2 }, { 0.3, 0.4, 0.5 }}, - {{ 1.0, 1.1, 1.2 }, { 0.0, 0.0, 0.0 }}, - {{ 2.0, 2.1, 2.2 }, { 2.3, 2.4, 2.5 }} - }), array); - } - - @Test - public void hydrateNdArrayByVectors() { - DoubleNdArray array = NdArrays.ofDoubles(Shape.of(3, 2, 3), hydrator -> { - hydrator.byVectors() - .put(0.0, 0.1, 0.2) - .put(0.3, 0.4, 0.5) - .put(1.0, 1.1, 1.2) - .at(2, 0) - .put(2.0, 2.1, 2.2) - .put(2.3, 2.4, 2.5); - }); - - assertEquals(StdArrays.ndCopyOf(new double[][][] { - {{ 0.0, 0.1, 0.2 }, { 0.3, 0.4, 0.5 }}, - {{ 1.0, 1.1, 1.2 }, { 0.0, 0.0, 0.0 }}, - {{ 2.0, 2.1, 2.2 }, { 2.3, 2.4, 2.5 }} - }), array); - } - - @Test - public void hydrateNdArrayByElements() { - DoubleNdArray array = NdArrays.ofDoubles(Shape.of(3, 2, 3), hydrator -> { - hydrator.byElements() - .put(StdArrays.ndCopyOf(new double[][] { - { 0.0, 0.1, 0.2 }, - { 0.3, 0.4, 0.5 } - })) - .at(1, 0) - .put(NdArrays.vectorOf(1.0, 1.1, 1.2)) - .at(2) - .put(StdArrays.ndCopyOf(new double[][] { - { 2.0, 2.1, 2.2 }, - { 2.3, 2.4, 2.5 } - })); - }); - - assertEquals(StdArrays.ndCopyOf(new double[][][] { - {{ 0.0, 0.1, 0.2 }, { 0.3, 0.4, 0.5 }}, - {{ 1.0, 1.1, 1.2 }, { 0.0, 0.0, 0.0 }}, - {{ 2.0, 2.1, 2.2 }, { 2.3, 2.4, 2.5 }} - }), array); - } -} diff --git a/ndarray/src/test/java/org/tensorflow/ndarray/impl/dense/hydrator/DoubleDenseNdArrayHydratorTest.java b/ndarray/src/test/java/org/tensorflow/ndarray/impl/dense/hydrator/DoubleDenseNdArrayHydratorTest.java new file mode 100644 index 0000000..63badf4 --- /dev/null +++ b/ndarray/src/test/java/org/tensorflow/ndarray/impl/dense/hydrator/DoubleDenseNdArrayHydratorTest.java @@ -0,0 +1,24 @@ +package org.tensorflow.ndarray.impl.dense.hydrator; + +import static org.junit.jupiter.api.Assertions.assertEquals; + +import java.util.function.Consumer; + +import org.junit.jupiter.api.Test; +import org.tensorflow.ndarray.DoubleNdArray; +import org.tensorflow.ndarray.NdArray; +import org.tensorflow.ndarray.NdArrays; +import org.tensorflow.ndarray.Shape; +import org.tensorflow.ndarray.StdArrays; +import org.tensorflow.ndarray.hydrator.DoubleNdArrayHydrator; +import org.tensorflow.ndarray.hydrator.DoubleNdArrayHydratorTestBase; +import org.tensorflow.ndarray.hydrator.NdArrayHydrator; +import org.tensorflow.ndarray.impl.dense.DoubleDenseNdArray; + +public class DoubleDenseNdArrayHydratorTest extends DoubleNdArrayHydratorTestBase { + + @Override + protected DoubleNdArray newArray(Shape shape, long numValues, Consumer hydrate) { + return NdArrays.ofDoubles(shape, hydrate); + } +} diff --git a/ndarray/src/test/java/org/tensorflow/ndarray/impl/sequence/ElementSequenceTest.java b/ndarray/src/test/java/org/tensorflow/ndarray/impl/sequence/ElementSequenceTest.java index bad7840..3ad462d 100644 --- a/ndarray/src/test/java/org/tensorflow/ndarray/impl/sequence/ElementSequenceTest.java +++ b/ndarray/src/test/java/org/tensorflow/ndarray/impl/sequence/ElementSequenceTest.java @@ -98,7 +98,7 @@ public void slicingElementSequenceReturnsUniqueInstances() { public void fastElementSequenceReturnsSameInstance() { IntNdArray array = NdArrays.ofInts(Shape.of(2, 3, 2)); IntNdArray element = array.get(0); - NdArraySequence sequence = new FastElementSequence( + NdArraySequence sequence = new FastElementSequence( (AbstractNdArray) array, 1, element, mockDataBufferWindow(2)); sequence.forEach(e -> { if (e != element) { diff --git a/ndarray/src/test/java/org/tensorflow/ndarray/impl/sparse/hydrator/DoubleSparseNdArrayHydratorTest.java b/ndarray/src/test/java/org/tensorflow/ndarray/impl/sparse/hydrator/DoubleSparseNdArrayHydratorTest.java new file mode 100644 index 0000000..d76922d --- /dev/null +++ b/ndarray/src/test/java/org/tensorflow/ndarray/impl/sparse/hydrator/DoubleSparseNdArrayHydratorTest.java @@ -0,0 +1,23 @@ +package org.tensorflow.ndarray.impl.sparse.hydrator; + +import static org.junit.jupiter.api.Assertions.assertEquals; + +import java.util.function.Consumer; + +import org.junit.jupiter.api.Test; +import org.tensorflow.ndarray.DoubleNdArray; +import org.tensorflow.ndarray.NdArray; +import org.tensorflow.ndarray.NdArrays; +import org.tensorflow.ndarray.Shape; +import org.tensorflow.ndarray.StdArrays; +import org.tensorflow.ndarray.hydrator.DoubleNdArrayHydrator; +import org.tensorflow.ndarray.hydrator.DoubleNdArrayHydratorTestBase; +import org.tensorflow.ndarray.hydrator.NdArrayHydrator; + +public class DoubleSparseNdArrayHydratorTest extends DoubleNdArrayHydratorTestBase { + + @Override + protected DoubleNdArray newArray(Shape shape, long numValues, Consumer hydrate) { + return NdArrays.sparseOfDoubles(shape, numValues, hydrate); + } +} From 5bbe812081cbf3bbd73d90e3d34fcc46d3479b6f Mon Sep 17 00:00:00 2001 From: karllessard Date: Sun, 4 Dec 2022 22:50:30 -0500 Subject: [PATCH 3/3] Specialized initializers inherit from generic ones --- ndarray/src/main/java/module-info.java | 1 + .../tensorflow/ndarray/NdArraySequence.java | 6 +- .../java/org/tensorflow/ndarray/NdArrays.java | 44 ++-- .../hydrator/DoubleNdArrayHydrator.java | 172 -------------- .../ndarray/hydrator/NdArrayHydrator.java | 177 -------------- .../tensorflow/ndarray/impl/Validator.java | 14 +- .../dense/hydrator/DenseNdArrayHydrator.java | 101 -------- .../hydrator/DoubleDenseNdArrayHydrator.java | 104 -------- .../ndarray/impl/dense/hydrator/Helpers.java | 52 ---- .../BaseDenseNdArrayInitializer.java | 134 +++++++++++ .../initializer/DenseNdArrayInitializer.java | 63 +++++ .../DoubleDenseNdArrayInitializer.java | 95 ++++++++ .../impl/dimension/DimensionalSpace.java | 9 + .../AbstractNdArrayInitializer.java | 101 ++++++++ .../impl/sequence/CoordinatesIncrementor.java | 48 ---- .../IndexedSequentialPositionIterator.java | 4 +- .../impl/sequence/NdPositionIterator.java | 2 +- .../impl/sequence/PositionIterator.java | 4 +- .../sequence/SequentialPositionIterator.java | 4 +- .../impl/sparse/AbstractSparseNdArray.java | 11 +- .../hydrator/DoubleSparseNdArrayHydrator.java | 126 ---------- .../ndarray/impl/sparse/hydrator/Helpers.java | 49 ---- .../hydrator/SparseNdArrayHydrator.java | 120 ---------- .../BaseSparseNdArrayInitializer.java | 133 +++++++++++ .../DoubleSparseNdArrayInitializer.java | 100 ++++++++ .../initializer/SparseNdArrayInitializer.java | 62 +++++ .../initializer/BaseNdArrayInitializer.java | 223 ++++++++++++++++++ .../initializer/DoubleNdArrayInitializer.java | 76 ++++++ .../initializer/NdArrayInitializer.java | 55 +++++ .../DoubleDenseNdArrayHydratorTest.java | 24 -- .../DoubleDenseNdArrayInitializerTest.java | 33 +++ .../StringDenseNdArrayInitializerTest.java | 33 +++ .../AbstractNdArrayInitializerTest.java | 68 ++++++ .../DoubleNdArrayInitializerTestBase.java} | 99 ++++---- .../StringNdArrayInitializerTestBase.java | 146 ++++++++++++ .../DoubleSparseNdArrayHydratorTest.java | 23 -- .../hydrator/SparseNdArrayHydratorTest.java | 85 ------- .../DoubleSparseNdArrayInitializerTest.java | 33 +++ .../StringSparseNdArrayInitializerTest.java | 33 +++ 39 files changed, 1494 insertions(+), 1173 deletions(-) delete mode 100644 ndarray/src/main/java/org/tensorflow/ndarray/hydrator/DoubleNdArrayHydrator.java delete mode 100644 ndarray/src/main/java/org/tensorflow/ndarray/hydrator/NdArrayHydrator.java delete mode 100644 ndarray/src/main/java/org/tensorflow/ndarray/impl/dense/hydrator/DenseNdArrayHydrator.java delete mode 100644 ndarray/src/main/java/org/tensorflow/ndarray/impl/dense/hydrator/DoubleDenseNdArrayHydrator.java delete mode 100644 ndarray/src/main/java/org/tensorflow/ndarray/impl/dense/hydrator/Helpers.java create mode 100644 ndarray/src/main/java/org/tensorflow/ndarray/impl/dense/initializer/BaseDenseNdArrayInitializer.java create mode 100644 ndarray/src/main/java/org/tensorflow/ndarray/impl/dense/initializer/DenseNdArrayInitializer.java create mode 100644 ndarray/src/main/java/org/tensorflow/ndarray/impl/dense/initializer/DoubleDenseNdArrayInitializer.java create mode 100644 ndarray/src/main/java/org/tensorflow/ndarray/impl/initializer/AbstractNdArrayInitializer.java delete mode 100644 ndarray/src/main/java/org/tensorflow/ndarray/impl/sequence/CoordinatesIncrementor.java delete mode 100644 ndarray/src/main/java/org/tensorflow/ndarray/impl/sparse/hydrator/DoubleSparseNdArrayHydrator.java delete mode 100644 ndarray/src/main/java/org/tensorflow/ndarray/impl/sparse/hydrator/Helpers.java delete mode 100644 ndarray/src/main/java/org/tensorflow/ndarray/impl/sparse/hydrator/SparseNdArrayHydrator.java create mode 100644 ndarray/src/main/java/org/tensorflow/ndarray/impl/sparse/initializer/BaseSparseNdArrayInitializer.java create mode 100644 ndarray/src/main/java/org/tensorflow/ndarray/impl/sparse/initializer/DoubleSparseNdArrayInitializer.java create mode 100644 ndarray/src/main/java/org/tensorflow/ndarray/impl/sparse/initializer/SparseNdArrayInitializer.java create mode 100644 ndarray/src/main/java/org/tensorflow/ndarray/initializer/BaseNdArrayInitializer.java create mode 100644 ndarray/src/main/java/org/tensorflow/ndarray/initializer/DoubleNdArrayInitializer.java create mode 100644 ndarray/src/main/java/org/tensorflow/ndarray/initializer/NdArrayInitializer.java delete mode 100644 ndarray/src/test/java/org/tensorflow/ndarray/impl/dense/hydrator/DoubleDenseNdArrayHydratorTest.java create mode 100644 ndarray/src/test/java/org/tensorflow/ndarray/impl/dense/initializer/DoubleDenseNdArrayInitializerTest.java create mode 100644 ndarray/src/test/java/org/tensorflow/ndarray/impl/dense/initializer/StringDenseNdArrayInitializerTest.java create mode 100644 ndarray/src/test/java/org/tensorflow/ndarray/impl/initializer/AbstractNdArrayInitializerTest.java rename ndarray/src/test/java/org/tensorflow/ndarray/{hydrator/DoubleNdArrayHydratorTestBase.java => impl/initializer/DoubleNdArrayInitializerTestBase.java} (55%) create mode 100644 ndarray/src/test/java/org/tensorflow/ndarray/impl/initializer/StringNdArrayInitializerTestBase.java delete mode 100644 ndarray/src/test/java/org/tensorflow/ndarray/impl/sparse/hydrator/DoubleSparseNdArrayHydratorTest.java delete mode 100644 ndarray/src/test/java/org/tensorflow/ndarray/impl/sparse/hydrator/SparseNdArrayHydratorTest.java create mode 100644 ndarray/src/test/java/org/tensorflow/ndarray/impl/sparse/initializer/DoubleSparseNdArrayInitializerTest.java create mode 100644 ndarray/src/test/java/org/tensorflow/ndarray/impl/sparse/initializer/StringSparseNdArrayInitializerTest.java diff --git a/ndarray/src/main/java/module-info.java b/ndarray/src/main/java/module-info.java index 6b33ed6..10d0b84 100644 --- a/ndarray/src/main/java/module-info.java +++ b/ndarray/src/main/java/module-info.java @@ -21,6 +21,7 @@ exports org.tensorflow.ndarray.buffer; exports org.tensorflow.ndarray.buffer.layout; exports org.tensorflow.ndarray.index; + exports org.tensorflow.ndarray.initializer; // Expose all implementions of our interfaces, so consumers can write custom // implementations easily by extending from them diff --git a/ndarray/src/main/java/org/tensorflow/ndarray/NdArraySequence.java b/ndarray/src/main/java/org/tensorflow/ndarray/NdArraySequence.java index afb930e..b97639a 100644 --- a/ndarray/src/main/java/org/tensorflow/ndarray/NdArraySequence.java +++ b/ndarray/src/main/java/org/tensorflow/ndarray/NdArraySequence.java @@ -33,12 +33,12 @@ public interface NdArraySequence> extends Iterable { /** - * Visit each elements of this iteration and their respective coordinates. + * Visit each element of this iteration and their respective coordinates. * *

Important: the consumer method should not keep a reference to the coordinates * as they might be mutable and reused during the iteration to improve performance. * - * @param consumer method to invoke for each elements + * @param consumer method to invoke for each element */ void forEachIndexed(BiConsumer consumer); @@ -60,7 +60,7 @@ public interface NdArraySequence> extends Iterable { * ndArray.elements(0).asSlices().forEach(e -> vectors::add); // Safe, each `e` is a distinct NdArray instance * } * - * @return a sequence that returns each elements iterated as a new slice + * @return a sequence that returns each element iterated as a new slice * @see DataBufferWindow */ NdArraySequence asSlices(); diff --git a/ndarray/src/main/java/org/tensorflow/ndarray/NdArrays.java b/ndarray/src/main/java/org/tensorflow/ndarray/NdArrays.java index bfbe939..aa4e736 100644 --- a/ndarray/src/main/java/org/tensorflow/ndarray/NdArrays.java +++ b/ndarray/src/main/java/org/tensorflow/ndarray/NdArrays.java @@ -16,7 +16,6 @@ */ package org.tensorflow.ndarray; -import java.util.function.Consumer; import org.tensorflow.ndarray.buffer.BooleanDataBuffer; import org.tensorflow.ndarray.buffer.ByteDataBuffer; import org.tensorflow.ndarray.buffer.DataBuffer; @@ -26,9 +25,6 @@ import org.tensorflow.ndarray.buffer.IntDataBuffer; import org.tensorflow.ndarray.buffer.LongDataBuffer; import org.tensorflow.ndarray.buffer.ShortDataBuffer; -import org.tensorflow.ndarray.hydrator.DoubleNdArrayHydrator; -import org.tensorflow.ndarray.hydrator.NdArrayHydrator; -import org.tensorflow.ndarray.impl.dense.AbstractDenseNdArray; import org.tensorflow.ndarray.impl.dense.BooleanDenseNdArray; import org.tensorflow.ndarray.impl.dense.ByteDenseNdArray; import org.tensorflow.ndarray.impl.dense.DenseNdArray; @@ -37,8 +33,8 @@ import org.tensorflow.ndarray.impl.dense.IntDenseNdArray; import org.tensorflow.ndarray.impl.dense.LongDenseNdArray; import org.tensorflow.ndarray.impl.dense.ShortDenseNdArray; -import org.tensorflow.ndarray.impl.dense.hydrator.DenseNdArrayHydrator; -import org.tensorflow.ndarray.impl.dense.hydrator.DoubleDenseNdArrayHydrator; +import org.tensorflow.ndarray.impl.dense.initializer.DenseNdArrayInitializer; +import org.tensorflow.ndarray.impl.dense.initializer.DoubleDenseNdArrayInitializer; import org.tensorflow.ndarray.impl.dimension.DimensionalSpace; import org.tensorflow.ndarray.impl.sparse.AbstractSparseNdArray; import org.tensorflow.ndarray.impl.sparse.BooleanSparseNdArray; @@ -48,8 +44,12 @@ import org.tensorflow.ndarray.impl.sparse.IntSparseNdArray; import org.tensorflow.ndarray.impl.sparse.LongSparseNdArray; import org.tensorflow.ndarray.impl.sparse.ShortSparseNdArray; -import org.tensorflow.ndarray.impl.sparse.hydrator.DoubleSparseNdArrayHydrator; -import org.tensorflow.ndarray.impl.sparse.hydrator.SparseNdArrayHydrator; +import org.tensorflow.ndarray.impl.sparse.initializer.DoubleSparseNdArrayInitializer; +import org.tensorflow.ndarray.impl.sparse.initializer.SparseNdArrayInitializer; +import org.tensorflow.ndarray.initializer.DoubleNdArrayInitializer; +import org.tensorflow.ndarray.initializer.NdArrayInitializer; + +import java.util.function.Consumer; /** Utility class for instantiating {@link NdArray} objects. */ public final class NdArrays { @@ -565,16 +565,16 @@ public static DoubleNdArray ofDoubles(Shape shape) { } /** - * Creates an N-dimensional array of doubles of the given shape, hydrating it with data after its allocation + * Creates an N-dimensional array of doubles of the given shape, initializing its data after allocation. * * @param shape shape of the array - * @param hydrate initialize the data of the created array, using a hydrator + * @param init invoked to initialize the data of the allocated array * @return new double N-dimensional array * @throws IllegalArgumentException if shape is null or has unknown dimensions */ - public static DoubleNdArray ofDoubles(Shape shape, Consumer hydrate) { + public static DoubleNdArray ofDoubles(Shape shape, Consumer init) { DoubleDenseNdArray array = (DoubleDenseNdArray)ofDoubles(shape); - hydrate.accept(new DoubleDenseNdArrayHydrator(array)); + init.accept(new DoubleDenseNdArrayInitializer(array)); return array; } @@ -600,11 +600,11 @@ public static DoubleNdArray wrap(Shape shape, DoubleDataBuffer buffer) { * @return new double N-dimensional array * @throws IllegalArgumentException if shape is null or has unknown dimensions */ - public static DoubleSparseNdArray sparseOfDoubles(Shape shape, long numValues, Consumer hydrate) { + public static DoubleSparseNdArray sparseOfDoubles(Shape shape, long numValues, Consumer hydrate) { LongNdArray indices = ofLongs(Shape.of(numValues, shape.numDimensions())); DoubleNdArray values = ofDoubles(Shape.of(numValues)); DoubleSparseNdArray array = DoubleSparseNdArray.create(indices, values, DimensionalSpace.create(shape)); - hydrate.accept(new DoubleSparseNdArrayHydrator(array)); + hydrate.accept(new DoubleSparseNdArrayInitializer(array)); return array; } @@ -802,12 +802,13 @@ public static NdArray ofObjects(Class clazz, Shape shape) { * @param clazz class of the data to be stored in this array * @param shape shape of the array * @param hydrate initialize the data of the created array, using a hydrator + * @param type of object to store in this array * @return new N-dimensional array * @throws IllegalArgumentException if shape is null or has unknown dimensions */ - public static NdArray ofObjects(Class clazz, Shape shape, Consumer> hydrate) { - AbstractDenseNdArray array = (AbstractDenseNdArray)ofObjects(clazz, shape); - hydrate.accept(new DenseNdArrayHydrator(array)); + public static NdArray ofObjects(Class clazz, Shape shape, Consumer> hydrate) { + var array = (DenseNdArray)ofObjects(clazz, shape); + hydrate.accept(new DenseNdArrayInitializer<>(array)); return array; } @@ -826,20 +827,21 @@ public static NdArray wrap(Shape shape, DataBuffer buffer) { } /** - * Creates an Sparse array of objects of the given shape, hydrating it with data after its allocation + * Creates a Sparse array of objects of the given shape, hydrating it with data after its allocation * * @param type the class type represented by this sparse array. * @param shape shape of the array * @param numValues number of values actually set in the array, others defaulting to the zero value * @param hydrate initialize the data of the created array, using a hydrator + * @param type of object to store in this array * @return new N-dimensional array * @throws IllegalArgumentException if shape is null or has unknown dimensions */ - public static NdArray sparseOfObjects(Class type, Shape shape, long numValues, Consumer> hydrate) { + public static NdArray sparseOfObjects(Class type, Shape shape, long numValues, Consumer> hydrate) { LongNdArray indices = ofLongs(Shape.of(numValues, shape.numDimensions())); NdArray values = ofObjects(type, Shape.of(numValues)); AbstractSparseNdArray array = (AbstractSparseNdArray)sparseOfObjects(type, indices, values, shape); - hydrate.accept(new SparseNdArrayHydrator(array)); + hydrate.accept(new SparseNdArrayInitializer<>(array)); return array; } @@ -856,6 +858,7 @@ public static NdArray sparseOfObjects(Class type, Shape shape, long nu * values=["one", "two"]} specifies that element {@code [1,3,1]} of the sparse NdArray has a * value of "one", and element {@code [2,4,0]} of the NdArray has a value of "two"". All other * values are null. + * @param type of object to store in this array * @param shape the shape of the dense array represented by this sparse array. * @return the float sparse array. */ @@ -880,6 +883,7 @@ public static NdArray sparseOfObjects( * values are null. * @param defaultValue Scalar value to set for indices not specified in 'indices' * @param shape the shape of the dense array represented by this sparse array. + * @param type of object to store in this array * @return the float sparse array. */ public static NdArray sparseOfObjects( diff --git a/ndarray/src/main/java/org/tensorflow/ndarray/hydrator/DoubleNdArrayHydrator.java b/ndarray/src/main/java/org/tensorflow/ndarray/hydrator/DoubleNdArrayHydrator.java deleted file mode 100644 index 9440e42..0000000 --- a/ndarray/src/main/java/org/tensorflow/ndarray/hydrator/DoubleNdArrayHydrator.java +++ /dev/null @@ -1,172 +0,0 @@ -package org.tensorflow.ndarray.hydrator; - -import org.tensorflow.ndarray.DoubleNdArray; - -/** - * Specialization of the {@link NdArrayHydrator} API for hydrating arrays of doubles. - * - * @see NdArrayHydrator - */ -public interface DoubleNdArrayHydrator { - - /** - * An API for hydrate an {@link DoubleNdArray} using scalar values - */ - interface Scalars { - - /** - * Position the hydrator to the given {@code coordinates} to write the next scalars. - * - * @param coordinates position in the hydrated array - * @return this API - * @throws IllegalArgumentException if {@code coordinates} are empty or are not one of a scalar - */ - Scalars at(long... coordinates); - - /** - * Set a double value as the next scalar value in the hydrated array. - * - * @param scalar next scalar value - * @return this API - * @throws IllegalArgumentException if {@code scalar} is null - */ - Scalars put(double scalar); - } - - /** - * An API for hydrate an {@link DoubleNdArray} using vectors, i.e. a list of scalars - */ - interface Vectors { - - /** - * Position the hydrator to the given {@code coordinates} to write the next vectors. - * - * @param coordinates position in the hydrated array - * @return this API - * @throws IllegalArgumentException if {@code coordinates} are empty or are not one of a vector - */ - Vectors at(long... coordinates); - - /** - * Set a list of doubles as the next vector in the hydrated array. - * - * @param vector next vector values - * @return this API - * @throws IllegalArgumentException if {@code vector} is empty or its length is greater than the size of the dimension - * {@code n-1}, given {@code n} the rank of the hydrated array - */ - Vectors put(double... vector); - } - - /** - * An API for hydrate an {@link DoubleNdArray} using n-dimensional elements (sub-arrays). - */ - interface Elements { - - /** - * Position the hydrator to the given {@code coordinates} to write the next elements. - * - * @param coordinates position in the hydrated array - * @return this API - * @throws IllegalArgumentException if {@code coordinates} are empty or are not one of an element of the hydrated array - */ - Elements at(long... coordinates); - - /** - * Set a n-dimensional array of doubles as the next element in the hydrated array. - * - * @param element array containing the next element values - * @return this API - * @throws IllegalArgumentException if {@code element} is null or its shape is incompatible with the current hydrator position - */ - Elements put(DoubleNdArray element); - } - - /** - * Start to hydrate the targeted array with scalars. - * - * If no {@code coordinates} are provided, the start position is the current one relatively to any previous hydration that occured or if none, - * defaults to the first scalar of this array. - * - * Example of usage: - *

{@code
-   *    DoubleNdArray array = NdArrays.ofDoubles(Shape.of(3, 2), hydrator -> {
-   *        hydrator.byScalars()
-   *          .put(10.0)
-   *          .put(20.0)
-   *          .put(30.0)
-   *          .at(2, 1)
-   *          .put(40.0);
-   *    });
-   *    // -> [[10.0, 20.0], [30.0, 0.0], [0.0, 40.0]]
-   * }
- * - * @param coordinates position in the hydrated array to start from - * @return a {@link Scalars} instance - * @throws IllegalArgumentException if {@code coordinates} are set but are not one of a scalar - */ - Scalars byScalars(long... coordinates); - - /** - * Start to hydrate the targeted array with vectors. - * - * If no {@code coordinates} are provided, the start position is the current one relatively to any previous hydration that occured or if none, - * defaults to the first scalar of the first vector of this array. - * - * Example of usage: - *
{@code
-   *    DoubleNdArray array = NdArrays.ofDoubles(Shape.of(3, 2), hydrator -> {
-   *        hydrator.byVectors()
-   *          .put(10.0, 20.0)
-   *          .put(30.0)
-   *          .at(2)
-   *          .put(40.0, 50.0);
-   *    });
-   *    // -> [[10.0, 20.0], [30.0, null], [40.0, 50.0]]
-   * }
- * - * @param coordinates position in the hydrated array to start from - * @return a {@link Vectors} instance - * @throws IllegalArgumentException if hydrated array is of rank-0 or if {@code coordinates} are set but are not one of a vector - */ - Vectors byVectors(long... coordinates); - - /** - * Start to hydrate the targeted array with n-dimensional elements. - * - * If no {@code coordinates} are provided, the start position is the current one relatively to any previous hydration that occured or if none, - * defaults to the first element in the first (0) dimension of the hydrated array. - * - * Example of usage: - *
{@code
-   *    DoubleNdArray vector = NdArrays.vectorOf(10.0, 20.0);
-   *    DoubleNdArray scalar = NdArrays.scalarOf(30.0);
-   *
-   *    DoubleNdArray array = NdArrays.ofDoubles(Shape.of(4, 2), hydrator -> {
-   *        hydrator.byElements()
-   *          .put(vector)
-   *          .put(vector)
-   *          .at(2, 1)
-   *          .put(scalar)
-   *          .at(3)
-   *          .put(vector);
-   *    });
-   *    // -> [[10.0, 20.0], [10.0, 20.0], [0.0, 30.0], [10.0, 20.0]]
-   * }
- * - * @param coordinates position in the hydrated array to start from - * @return a {@link Elements} instance - * @throws IllegalArgumentException if {@code coordinates} are set but are not one of an element of the hydrated array - */ - Elements byElements(long... coordinates); - - /** - * Creates an API to hydrate the targeted array with {@code Double} boxed type. - * - * Note that sticking to primitive types improve I/O performances overall, so only rely boxed types if the data is already - * available in that format. - * - * @return a hydrator supporting {@code Double} boxed type - */ - NdArrayHydrator boxed(); -} diff --git a/ndarray/src/main/java/org/tensorflow/ndarray/hydrator/NdArrayHydrator.java b/ndarray/src/main/java/org/tensorflow/ndarray/hydrator/NdArrayHydrator.java deleted file mode 100644 index fc64f89..0000000 --- a/ndarray/src/main/java/org/tensorflow/ndarray/hydrator/NdArrayHydrator.java +++ /dev/null @@ -1,177 +0,0 @@ -package org.tensorflow.ndarray.hydrator; - -import org.tensorflow.ndarray.NdArray; -import org.tensorflow.ndarray.buffer.DataBuffer; - -/** - * Interface for initializing the data of a {@link NdArray} that has just been allocated. - * - * While it is always possible to set the data of a read-write NdArray using standard output methods, - * like {@link NdArray#write(DataBuffer)} or {@link NdArray#copyTo(NdArray)}, the hydrator API focuses on - * sequential per-element initialization, similar to standard Java arrays. - * - * Since the hydrator API is only accessible right after the array have been allocated, it can be used to - * initialize data-sensitive arrays, like {@link org.tensorflow.ndarray.SparseNdArray}, which can be only - * written once and stay read-only thereafter. - * - * @param the type of data of the {@link NdArray} to initialize - */ -public interface NdArrayHydrator { - - /** - * An API for hydrate an {@link NdArray} using scalar values - * - * @param the type of data of the {@link NdArray} to initialize - */ - interface Scalars { - - /** - * Position the hydrator to the given {@code coordinates} to write the next scalars. - * - * @param coordinates position in the hydrated array - * @return this API - * @throws IllegalArgumentException if {@code coordinates} are empty or are not one of a scalar - */ - Scalars at(long... coordinates); - - /** - * Set an object as the next scalar value in the hydrated array. - * - * @param scalar next scalar value - * @return this API - * @throws IllegalArgumentException if {@code scalar} is null - */ - Scalars put(T scalar); - } - - /** - * An API for hydrate an {@link NdArray} using vectors, i.e. a list of scalars - * - * @param the type of data of the {@link NdArray} to initialize - */ - interface Vectors { - - /** - * Position the hydrator to the given {@code coordinates} to write the next vectors. - * - * @param coordinates position in the hydrated array - * @return this API - * @throws IllegalArgumentException if {@code coordinates} are empty or are not one of a vector - */ - Vectors at(long... coordinates); - - /** - * Set a list of objects as the next vector in the hydrated array. - * - * @param vector next vector values - * @return this API - * @throws IllegalArgumentException if {@code vector} is empty or its length is greater than the size of the dimension - * {@code n-1}, given {@code n} the rank of the hydrated array - */ - Vectors put(T... vector); - } - - /** - * An API for hydrate an {@link NdArray} using n-dimensional elements (sub-arrays). - * - * @param the type of data of the {@link NdArray} to initialize - */ - interface Elements { - - /** - * Position the hydrator to the given {@code coordinates} to write the next elements. - * - * @param coordinates position in the hydrated array - * @return this API - * @throws IllegalArgumentException if {@code coordinates} are empty or are not one of an element of the hydrated array - */ - Elements at(long... coordinates); - - /** - * Set a n-dimensional array of objects as the next element in the hydrated array. - * - * @param element array containing the next element values - * @return this API - * @throws IllegalArgumentException if {@code element} is null or its shape is incompatible with the current hydrator position - */ - Elements put(NdArray element); - } - - /** - * Start to hydrate the targeted array with scalars. - * - * If no {@code coordinates} are provided, the start position is the current one relatively to any previous hydration that occured or if none, - * defaults to the first scalar of this array. - * - * Example of usage: - *
{@code
-   *    NdArray array = NdArrays.ofObjects(String.class, Shape.of(3, 2), hydrator -> {
-   *        hydrator.byScalars()
-   *          .put("Cat")
-   *          .put("Dog")
-   *          .put("House")
-   *          .at(2, 1)
-   *          .put("Apple");
-   *    });
-   *    // -> [["Cat", "Dog"], ["House", null], [null, "Apple"]]
-   * }
- * - * @param coordinates position in the hydrated array to start from - * @return a {@link Scalars} instance - * @throws IllegalArgumentException if {@code coordinates} are set but are not one of a scalar - */ - Scalars byScalars(long... coordinates); - - /** - * Start to hydrate the targeted array with vectors. - * - * If no {@code coordinates} are provided, the start position is the current one relatively to any previous hydration that occured or if none, - * defaults to the first scalar of the first vector of this array. - * - * Example of usage: - *
{@code
-   *    NdArray array = NdArrays.ofObjects(String.class, Shape.of(3, 2), hydrator -> {
-   *        hydrator.byVectors()
-   *          .put("Cat", "Dog")
-   *          .put("House")
-   *          .at(2)
-   *          .put("Orange", "Apple");
-   *    });
-   *    // -> [["Cat", "Dog"], ["House", null], ["Orange", "Apple"]]
-   * }
- * - * @param coordinates position in the hydrated array to start from - * @return a {@link Vectors} instance - * @throws IllegalArgumentException if hydrated array is of rank-0 or if {@code coordinates} are set but are not one of a vector - */ - Vectors byVectors(long... coordinates); - - /** - * Start to hydrate the targeted array with n-dimensional elements. - * - * If no {@code coordinates} are provided, the start position is the current one relatively to any previous hydration that occured or if none, - * defaults to the first element in the first (0) dimension of the hydrated array. - * - * Example of usage: - *
{@code
-   *    NdArray vector = NdArrays.vectorOfObjects("Cat", "Dog");
-   *    NdArray scalar = NdArrays.scalarOfObject("Apple");
-   *
-   *    NdArray array = NdArrays.ofObjects(String.class, Shape.of(4, 2), hydrator -> {
-   *        hydrator.byElements()
-   *          .put(vector)
-   *          .put(vector)
-   *          .at(2, 1)
-   *          .put(scalar)
-   *          .at(3)
-   *          .put(vector);
-   *    });
-   *    // -> [["Cat", "Dog"], ["Cat", "Dog"], [null, "Apple"], ["Cat", "Dog"]]
-   * }
- * - * @param coordinates position in the hydrated array to start from - * @return a {@link Elements} instance - * @throws IllegalArgumentException if {@code coordinates} are set but are not one of an element of the hydrated array - */ - Elements byElements(long... coordinates); -} diff --git a/ndarray/src/main/java/org/tensorflow/ndarray/impl/Validator.java b/ndarray/src/main/java/org/tensorflow/ndarray/impl/Validator.java index 285d099..11c8453 100644 --- a/ndarray/src/main/java/org/tensorflow/ndarray/impl/Validator.java +++ b/ndarray/src/main/java/org/tensorflow/ndarray/impl/Validator.java @@ -16,11 +16,12 @@ */ package org.tensorflow.ndarray.impl; -import java.nio.BufferOverflowException; -import java.nio.BufferUnderflowException; import org.tensorflow.ndarray.NdArray; import org.tensorflow.ndarray.buffer.DataBuffer; +import java.nio.BufferOverflowException; +import java.nio.BufferUnderflowException; + public class Validator { public static void copyToNdArrayArgs(NdArray ndArray, NdArray otherNdArray) { @@ -42,14 +43,5 @@ public static void writeFromBufferArgs(NdArray ndArray, DataBuffer src) { } } - private static void copyArrayArgs(int arrayLength, int arrayOffset) { - if (arrayOffset < 0) { - throw new IndexOutOfBoundsException("Offset must be non-negative"); - } - if (arrayOffset > arrayLength) { - throw new IndexOutOfBoundsException("Offset must be no larger than array length"); - } - } - protected Validator() {} } diff --git a/ndarray/src/main/java/org/tensorflow/ndarray/impl/dense/hydrator/DenseNdArrayHydrator.java b/ndarray/src/main/java/org/tensorflow/ndarray/impl/dense/hydrator/DenseNdArrayHydrator.java deleted file mode 100644 index d7bced0..0000000 --- a/ndarray/src/main/java/org/tensorflow/ndarray/impl/dense/hydrator/DenseNdArrayHydrator.java +++ /dev/null @@ -1,101 +0,0 @@ -package org.tensorflow.ndarray.impl.dense.hydrator; - -import java.util.Iterator; - -import org.tensorflow.ndarray.NdArray; -import org.tensorflow.ndarray.hydrator.NdArrayHydrator; -import org.tensorflow.ndarray.impl.dense.AbstractDenseNdArray; -import org.tensorflow.ndarray.impl.sequence.PositionIterator; - -public class DenseNdArrayHydrator implements NdArrayHydrator { - - public DenseNdArrayHydrator(AbstractDenseNdArray> array) { - this.array = array; - } - - @Override - public Scalars byScalars(long... coordinates) { - return new ScalarsImpl(coordinates); - } - - @Override - public Vectors byVectors(long... coordinates) { - return new VectorsImpl(coordinates); - } - - @Override - public Elements byElements(long... coordinates) { - return new ElementsImpl(coordinates); - } - - class ScalarsImpl implements Scalars { - - public Scalars at(long... coordinates) { - positionIterator = Helpers.iterateByPosition(array, 0, coordinates); - return this; - } - - @Override - public Scalars put(T scalar) { - if (scalar == null) { - throw new IllegalArgumentException("Scalar value cannot be null"); - } - array.buffer().setObject(scalar, positionIterator.nextLong()); - return this; - } - - ScalarsImpl(long[] coordinates) { - positionIterator = Helpers.iterateByPosition(array, 0, coordinates); - } - - private PositionIterator positionIterator; - } - - class VectorsImpl implements Vectors { - - @Override - public Vectors at(long... coordinates) { - positionIterator = Helpers.iterateByPosition(array, 1, coordinates); - return this; - } - - @Override - public Vectors put(T... vector) { - Helpers.validateVectorLength(vector.length, array.shape()); - array.buffer().offset(positionIterator.nextLong()).write(vector); - return this; - } - - VectorsImpl(long[] coordinates) { - positionIterator = Helpers.iterateByPosition(array, 1, coordinates); - } - - private PositionIterator positionIterator; - } - - class ElementsImpl implements Elements { - - @Override - public Elements at(long... coordinates) { - this.elementIterator = Helpers.iterateByElement(array, coordinates); - return this; - } - - @Override - public Elements put(NdArray element) { - if (element == null) { - throw new IllegalArgumentException("Element cannot be null"); - } - element.copyTo(elementIterator.next()); - return this; - } - - ElementsImpl(long[] coordinates) { - this.elementIterator = Helpers.iterateByElement(array, coordinates); - } - - private Iterator> elementIterator; - } - - private final AbstractDenseNdArray> array; -} diff --git a/ndarray/src/main/java/org/tensorflow/ndarray/impl/dense/hydrator/DoubleDenseNdArrayHydrator.java b/ndarray/src/main/java/org/tensorflow/ndarray/impl/dense/hydrator/DoubleDenseNdArrayHydrator.java deleted file mode 100644 index 087bb6f..0000000 --- a/ndarray/src/main/java/org/tensorflow/ndarray/impl/dense/hydrator/DoubleDenseNdArrayHydrator.java +++ /dev/null @@ -1,104 +0,0 @@ -package org.tensorflow.ndarray.impl.dense.hydrator; - -import java.util.Iterator; - -import org.tensorflow.ndarray.DoubleNdArray; -import org.tensorflow.ndarray.hydrator.DoubleNdArrayHydrator; -import org.tensorflow.ndarray.hydrator.NdArrayHydrator; -import org.tensorflow.ndarray.impl.dense.DoubleDenseNdArray; -import org.tensorflow.ndarray.impl.sequence.PositionIterator; - -public class DoubleDenseNdArrayHydrator implements DoubleNdArrayHydrator { - - public DoubleDenseNdArrayHydrator(DoubleDenseNdArray array) { - this.array = array; - } - - @Override - public Scalars byScalars(long... coordinates) { - return new ScalarsImpl(coordinates); - } - - @Override - public Vectors byVectors(long... coordinates) { - return new VectorsImpl(coordinates); - } - - @Override - public Elements byElements(long... coordinates) { - return new ElementsImpl(coordinates); - } - - @Override - public NdArrayHydrator boxed() { - return new DenseNdArrayHydrator(array); - } - - class ScalarsImpl implements Scalars { - - public Scalars at(long... coordinates) { - positionIterator = Helpers.iterateByPosition(array, 0, coordinates); - return this; - } - - @Override - public Scalars put(double scalar) { - array.buffer().setObject(scalar, positionIterator.nextLong()); - return this; - } - - ScalarsImpl(long[] coordinates) { - positionIterator = Helpers.iterateByPosition(array, 0, coordinates); - } - - private PositionIterator positionIterator; - } - - class VectorsImpl implements Vectors { - - @Override - public Vectors at(long... coordinates) { - positionIterator = Helpers.iterateByPosition(array, 1, coordinates); - return this; - } - - @Override - public Vectors put(double... vector) { - Helpers.validateVectorLength(vector.length, array.shape()); - array.buffer().offset(positionIterator.nextLong()).write(vector); - return this; - } - - VectorsImpl(long[] coordinates) { - positionIterator = Helpers.iterateByPosition(array, 1, coordinates); - } - - private PositionIterator positionIterator; - } - - class ElementsImpl implements Elements { - - @Override - public Elements at(long... coordinates) { - this.elementIterator = Helpers.iterateByElement(array, coordinates); - return this; - } - - @Override - public Elements put(DoubleNdArray element) { - if (element == null) { - throw new IllegalArgumentException("Element cannot be null"); - } - element.copyTo(elementIterator.next()); - return this; - } - - ElementsImpl(long[] coordinates) { - this.elementIterator = Helpers.iterateByElement(array, coordinates); - } - - private Iterator elementIterator; - } - - private final DoubleDenseNdArray array; -} diff --git a/ndarray/src/main/java/org/tensorflow/ndarray/impl/dense/hydrator/Helpers.java b/ndarray/src/main/java/org/tensorflow/ndarray/impl/dense/hydrator/Helpers.java deleted file mode 100644 index 5c503c9..0000000 --- a/ndarray/src/main/java/org/tensorflow/ndarray/impl/dense/hydrator/Helpers.java +++ /dev/null @@ -1,52 +0,0 @@ -package org.tensorflow.ndarray.impl.dense.hydrator; - -import java.util.Arrays; -import java.util.Iterator; - -import org.tensorflow.ndarray.NdArray; -import org.tensorflow.ndarray.NdArraySequence; -import org.tensorflow.ndarray.Shape; -import org.tensorflow.ndarray.impl.dense.AbstractDenseNdArray; -import org.tensorflow.ndarray.impl.dimension.DimensionalSpace; -import org.tensorflow.ndarray.impl.sequence.IndexedPositionIterator; -import org.tensorflow.ndarray.impl.sequence.PositionIterator; - -final class Helpers { - - static PositionIterator iterateByPosition(AbstractDenseNdArray array, int elementRank, long[] coords) { - DimensionalSpace dimensions = array.dimensions(); - int dimensionIdx = dimensions.numDimensions() - elementRank - 1; - if (dimensionIdx < 0) { - throw new IllegalArgumentException("Cannot hydrate array of shape " + array.shape() + " with elements of rank " + elementRank); - } - if (coords == null || coords.length == 0) { - return PositionIterator.create(dimensions, dimensionIdx); - } - if ((coords.length - 1) != dimensionIdx) { - throw new IllegalArgumentException(Arrays.toString(coords) + " are not valid coordinates for dimension " - + dimensionIdx + " in an array of shape " + dimensions.shape()); - } - return PositionIterator.create(dimensions, coords); - } - - static > Iterator iterateByElement(AbstractDenseNdArray array, long[] coords) { - DimensionalSpace dimensions = array.dimensions(); - int dimensionIdx; - if (coords == null || coords.length == 0) { - return array.elements(0).iterator(); - } - if (coords.length > dimensions.numDimensions()) { - throw new IllegalArgumentException(Arrays.toString(coords) + " are not valid coordinates for an array of shape " + dimensions.shape()); - } - return array.elementsAt(coords).iterator(); - } - - static void validateVectorLength(int length, Shape shape) { - if (length == 0) { - throw new IllegalArgumentException("Vector cannot be empty"); - } - if (length > shape.get(-1)) { - throw new IllegalArgumentException("Vector cannot exceed " + shape.get(-1) + " elements"); - } - } -} diff --git a/ndarray/src/main/java/org/tensorflow/ndarray/impl/dense/initializer/BaseDenseNdArrayInitializer.java b/ndarray/src/main/java/org/tensorflow/ndarray/impl/dense/initializer/BaseDenseNdArrayInitializer.java new file mode 100644 index 0000000..3cde108 --- /dev/null +++ b/ndarray/src/main/java/org/tensorflow/ndarray/impl/dense/initializer/BaseDenseNdArrayInitializer.java @@ -0,0 +1,134 @@ +/* + Copyright 2022 The TensorFlow Authors. All Rights Reserved. + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. + ======================================================================= + */ +package org.tensorflow.ndarray.impl.dense.initializer; + +import org.tensorflow.ndarray.NdArray; +import org.tensorflow.ndarray.impl.dense.AbstractDenseNdArray; +import org.tensorflow.ndarray.impl.initializer.AbstractNdArrayInitializer; +import org.tensorflow.ndarray.impl.sequence.PositionIterator; +import org.tensorflow.ndarray.initializer.BaseNdArrayInitializer; + +import java.util.Collection; +import java.util.Iterator; + +abstract class BaseDenseNdArrayInitializer, V extends AbstractDenseNdArray> extends AbstractNdArrayInitializer implements BaseNdArrayInitializer { + + @Override + public Scalars byScalars(long... coordinates) { + return new ScalarsImpl(coordinates); + } + + @Override + public Elements byElements(int dimensionIdx, long... coordinates) { + return new ElementsImpl(dimensionIdx, coordinates); + } + + /** + * Per-scalar initializer for dense arrays + */ + class ScalarsImpl implements Scalars { + + public Scalars to(long... coordinates) { + jumpTo(coordinates); + positionIterator = PositionIterator.create(array.dimensions(), coords); + return this; + } + + @Override + public Scalars put(T value) { + array.buffer().setObject(value, positionIterator.nextLong()); + next(); + return this; + } + + ScalarsImpl(long[] coordinates) { + resetTo(validateRankCoords(0, coordinates)); + positionIterator = PositionIterator.create(array.dimensions(), coords); + } + + protected PositionIterator positionIterator; + } + + /** + * Per-vector initializer for dense arrays + */ + class VectorsImpl implements Vectors { + + @Override + public Vectors to(long... coordinates) { + jumpTo(coordinates); + positionIterator = PositionIterator.create(array.dimensions(), coords); + return this; + } + + @Override + public Vectors put(Collection values) { + validateVectorLength(values.size()); + for (T v : values) { + array.buffer().setObject(v, positionIterator.nextLong()); + } + next(values.size()); + return this; + } + + protected void next(int numValues) { + BaseDenseNdArrayInitializer.this.next(); + // If the number of values is less that the size of a vector, we need to reposition our iterator + if (numValues < array.shape().get(-1)) { + positionIterator = PositionIterator.create(array.dimensions(), coords); + } + } + + VectorsImpl(long[] coordinates) { + resetTo(validateRankCoords(1, coordinates)); + positionIterator = PositionIterator.create(array.dimensions(), coords); + } + + protected PositionIterator positionIterator; + } + + /** + * Per-element initializer for dense arrays. + */ + class ElementsImpl implements Elements { + + @Override + public Elements to(long... coordinates) { + jumpTo(coordinates); + elementIterator = array.elementsAt(coords).iterator(); + return this; + } + + @Override + public Elements put(NdArray values) { + values.copyTo(elementIterator.next()); + next(); + return this; + } + + ElementsImpl(int dimensionIdx, long[] coordinates) { + resetTo(validateDimensionCoords(dimensionIdx, coordinates)); + elementIterator = array.elementsAt(coords).iterator(); + } + + protected Iterator elementIterator; + } + + protected BaseDenseNdArrayInitializer(V array) { + super(array); + } +} diff --git a/ndarray/src/main/java/org/tensorflow/ndarray/impl/dense/initializer/DenseNdArrayInitializer.java b/ndarray/src/main/java/org/tensorflow/ndarray/impl/dense/initializer/DenseNdArrayInitializer.java new file mode 100644 index 0000000..fb00569 --- /dev/null +++ b/ndarray/src/main/java/org/tensorflow/ndarray/impl/dense/initializer/DenseNdArrayInitializer.java @@ -0,0 +1,63 @@ +/* + Copyright 2022 The TensorFlow Authors. All Rights Reserved. + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. + ======================================================================= + */ +package org.tensorflow.ndarray.impl.dense.initializer; + +import org.tensorflow.ndarray.NdArray; +import org.tensorflow.ndarray.impl.dense.DenseNdArray; +import org.tensorflow.ndarray.initializer.NdArrayInitializer; + +import java.util.Collection; + +public class DenseNdArrayInitializer extends BaseDenseNdArrayInitializer, DenseNdArray> implements NdArrayInitializer { + + public DenseNdArrayInitializer(DenseNdArray array) { + super(array); + } + + @Override + public NdArrayInitializer.Vectors byVectors(long... coordinates) { + return new ObjectVectorsImpl(coordinates); + } + + /** + * Per-vector initializer for dense arrays. + */ + class ObjectVectorsImpl extends VectorsImpl implements NdArrayInitializer.Vectors { + + @Override + public NdArrayInitializer.Vectors to(long... coordinates) { + return (ObjectVectorsImpl) super.to(coordinates); + } + + @Override + public NdArrayInitializer.Vectors put(Collection values) { + return (ObjectVectorsImpl) super.put(values); + } + + @Override + public NdArrayInitializer.Vectors put(T... values) { + validateVectorLength(values.length); + array.buffer().offset(positionIterator.nextLong()).write(values); + next(values.length); + return this; + } + + ObjectVectorsImpl(long[] coordinates) { + super(coordinates); + } + } +} diff --git a/ndarray/src/main/java/org/tensorflow/ndarray/impl/dense/initializer/DoubleDenseNdArrayInitializer.java b/ndarray/src/main/java/org/tensorflow/ndarray/impl/dense/initializer/DoubleDenseNdArrayInitializer.java new file mode 100644 index 0000000..440668f --- /dev/null +++ b/ndarray/src/main/java/org/tensorflow/ndarray/impl/dense/initializer/DoubleDenseNdArrayInitializer.java @@ -0,0 +1,95 @@ +/* + Copyright 2022 The TensorFlow Authors. All Rights Reserved. + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. + ======================================================================= + */ +package org.tensorflow.ndarray.impl.dense.initializer; + +import org.tensorflow.ndarray.DoubleNdArray; +import org.tensorflow.ndarray.impl.dense.DoubleDenseNdArray; +import org.tensorflow.ndarray.initializer.DoubleNdArrayInitializer; + +import java.util.Collection; + +public class DoubleDenseNdArrayInitializer extends BaseDenseNdArrayInitializer implements DoubleNdArrayInitializer { + + public DoubleDenseNdArrayInitializer(DoubleDenseNdArray array) { + super(array); + } + + @Override + public DoubleNdArrayInitializer.Scalars byScalars(long... coordinates) { + return new DoubleScalarsImpl(coordinates); + } + + @Override + public DoubleNdArrayInitializer.Vectors byVectors(long... coordinates) { + return new DoubleVectorsImpl(coordinates); + } + + /** + * Per-scalar initializer for dense double arrays. + */ + class DoubleScalarsImpl extends ScalarsImpl implements DoubleNdArrayInitializer.Scalars { + + @Override + public DoubleNdArrayInitializer.Scalars to(long... coordinates) { + return (DoubleScalarsImpl) super.to(coordinates); + } + + @Override + public DoubleNdArrayInitializer.Scalars put(Double value) { + return (DoubleScalarsImpl) super.put(value); + } + + @Override + public DoubleNdArrayInitializer.Scalars put(double value) { + array.buffer().setDouble(value, positionIterator.nextLong()); + next(); + return this; + } + + DoubleScalarsImpl(long[] coordinates) { + super(coordinates); + } + } + + /** + * Per-vector initializer for dense double arrays. + */ + class DoubleVectorsImpl extends VectorsImpl implements DoubleNdArrayInitializer.Vectors { + + @Override + public DoubleNdArrayInitializer.Vectors to(long... coordinates) { + return (DoubleVectorsImpl) super.to(coordinates); + } + + @Override + public DoubleNdArrayInitializer.Vectors put(Collection values) { + return (DoubleVectorsImpl) super.put(values); + } + + @Override + public DoubleNdArrayInitializer.Vectors put(double... values) { + validateVectorLength(values.length); + array.buffer().offset(positionIterator.nextLong()).write(values); + next(values.length); + return this; + } + + DoubleVectorsImpl(long[] coordinates) { + super(coordinates); + } + } +} diff --git a/ndarray/src/main/java/org/tensorflow/ndarray/impl/dimension/DimensionalSpace.java b/ndarray/src/main/java/org/tensorflow/ndarray/impl/dimension/DimensionalSpace.java index 71d1677..7a6a6f6 100644 --- a/ndarray/src/main/java/org/tensorflow/ndarray/impl/dimension/DimensionalSpace.java +++ b/ndarray/src/main/java/org/tensorflow/ndarray/impl/dimension/DimensionalSpace.java @@ -189,6 +189,15 @@ public long positionOf(long[] coords) { return position; } + public boolean increment(long[] coords) { + for (int i = coords.length - 1; i >= 0; --i) { + if ((coords[i] = (coords[i] + 1) % shape.get(i)) > 0) { + return true; + } + } + return false; + } + /** * Succinct description of the shape meant for debugging. */ diff --git a/ndarray/src/main/java/org/tensorflow/ndarray/impl/initializer/AbstractNdArrayInitializer.java b/ndarray/src/main/java/org/tensorflow/ndarray/impl/initializer/AbstractNdArrayInitializer.java new file mode 100644 index 0000000..f4efaab --- /dev/null +++ b/ndarray/src/main/java/org/tensorflow/ndarray/impl/initializer/AbstractNdArrayInitializer.java @@ -0,0 +1,101 @@ +/* + Copyright 2022 The TensorFlow Authors. All Rights Reserved. + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. + ======================================================================= + */ +package org.tensorflow.ndarray.impl.initializer; + +import org.tensorflow.ndarray.impl.AbstractNdArray; +import org.tensorflow.ndarray.impl.dimension.DimensionalSpace; + +import java.util.Arrays; + +public abstract class AbstractNdArrayInitializer> { + + protected static long[] validateNewCoords(long[] actualCoords, long[] newCoords) { + if (actualCoords != null) { + // Make sure we always move forward + boolean smaller = false; + + int i = actualCoords.length; + while (i > newCoords.length) { + // If current coords is of a higher rank, any non-zero value for a dimension missing in the new coordinates + // requires other dimensions to be higher + if (actualCoords[--i] > 0) { + smaller = true; + } + } + while (--i >= 0) { + if (newCoords[i] < actualCoords[i]) { + smaller = true; + } else if (smaller && newCoords[i] > actualCoords[i]) { + smaller = false; + } + } + if (smaller) { + throw new IllegalArgumentException("Cannot move backward during array initialization"); + } + } + return newCoords; + } + + protected long[] validateDimensionCoords(int dimensionIdx, long[] coords) { + if (coords == null || coords.length == 0) { + return new long[dimensionIdx + 1]; + } + if ((coords.length - 1) != dimensionIdx) { + throw new IllegalArgumentException(Arrays.toString(coords) + " are not valid coordinates for dimension " + + dimensionIdx + " in an array of shape " + array.shape()); + } + return Arrays.copyOf(coords, coords.length); + } + + protected long[] validateRankCoords(int elementRank, long[] coords) { + DimensionalSpace dimensions = array.dimensions(); + int dimensionIdx = dimensions.numDimensions() - elementRank - 1; + if (dimensionIdx < 0) { + throw new IllegalArgumentException("Cannot initialize array of shape " + array.shape() + " with elements of rank " + elementRank); + } + return validateDimensionCoords(dimensionIdx, coords); + } + + protected void validateVectorLength(int numValues) { + if (numValues > array.shape().get(-1)) { + throw new IllegalArgumentException("Vector values exceeds limit of " + array.shape().get(-1) + " elements"); + } + } + + protected void next() { + array.dimensions().increment(coords); + } + + protected void jumpTo(long[] newCoordinates) { + if (coords.length != newCoordinates.length) { + throw new IllegalArgumentException("New coordinates are not for the initialized dimension"); + } + resetTo(Arrays.copyOf(newCoordinates, newCoordinates.length)); + } + + protected void resetTo(long[] newCoords) { + coords = validateNewCoords(coords, newCoords); + } + + protected final V array; + + protected long[] coords = null; // must be explicitly reset + + protected AbstractNdArrayInitializer(V array) { + this.array = array; + } +} diff --git a/ndarray/src/main/java/org/tensorflow/ndarray/impl/sequence/CoordinatesIncrementor.java b/ndarray/src/main/java/org/tensorflow/ndarray/impl/sequence/CoordinatesIncrementor.java deleted file mode 100644 index a020526..0000000 --- a/ndarray/src/main/java/org/tensorflow/ndarray/impl/sequence/CoordinatesIncrementor.java +++ /dev/null @@ -1,48 +0,0 @@ -/* - * Copyright 2020 The TensorFlow Authors. All Rights Reserved. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - * ======================================================================= - */ - -package org.tensorflow.ndarray.impl.sequence; - -import java.util.Arrays; - -public final class CoordinatesIncrementor { - - public boolean increment() { - for (int i = coords.length - 1; i >= 0; --i) { - if ((coords[i] = (coords[i] + 1) % shape[i]) > 0) { - return true; - } - } - return false; - } - - public CoordinatesIncrementor(long[] shape, int dimensionIdx) { - this.shape = shape; - this.coords = new long[dimensionIdx + 1]; - } - - public CoordinatesIncrementor(long[] shape, long[] coords) { - if (coords.length == 0 || coords.length > shape.length) { - throw new IllegalArgumentException(); - } - this.shape = shape; - this.coords = Arrays.copyOf(coords, coords.length); - } - - public final long[] shape; - public final long[] coords; -} diff --git a/ndarray/src/main/java/org/tensorflow/ndarray/impl/sequence/IndexedSequentialPositionIterator.java b/ndarray/src/main/java/org/tensorflow/ndarray/impl/sequence/IndexedSequentialPositionIterator.java index c7a3dcf..39d2f67 100644 --- a/ndarray/src/main/java/org/tensorflow/ndarray/impl/sequence/IndexedSequentialPositionIterator.java +++ b/ndarray/src/main/java/org/tensorflow/ndarray/impl/sequence/IndexedSequentialPositionIterator.java @@ -27,14 +27,14 @@ class IndexedSequentialPositionIterator extends SequentialPositionIterator imple public void forEachIndexed(CoordsLongConsumer consumer) { while (hasNext()) { consumer.consume(coords, super.nextLong()); - dimensions.incrementCoordinates(coords); + dimensions.increment(coords); } } @Override public long nextLong() { long tmp = super.nextLong(); - dimensions.incrementCoordinates(coords); + dimensions.increment(coords); return tmp; } diff --git a/ndarray/src/main/java/org/tensorflow/ndarray/impl/sequence/NdPositionIterator.java b/ndarray/src/main/java/org/tensorflow/ndarray/impl/sequence/NdPositionIterator.java index 7dc2f96..8f01d09 100644 --- a/ndarray/src/main/java/org/tensorflow/ndarray/impl/sequence/NdPositionIterator.java +++ b/ndarray/src/main/java/org/tensorflow/ndarray/impl/sequence/NdPositionIterator.java @@ -47,7 +47,7 @@ public void forEachIndexed(CoordsLongConsumer consumer) { } private void incrementCoords() { - if (!dimensions.incrementCoordinates(coords)) { + if (!dimensions.increment(coords)) { coords = null; } } diff --git a/ndarray/src/main/java/org/tensorflow/ndarray/impl/sequence/PositionIterator.java b/ndarray/src/main/java/org/tensorflow/ndarray/impl/sequence/PositionIterator.java index a30c31c..8505112 100644 --- a/ndarray/src/main/java/org/tensorflow/ndarray/impl/sequence/PositionIterator.java +++ b/ndarray/src/main/java/org/tensorflow/ndarray/impl/sequence/PositionIterator.java @@ -17,10 +17,10 @@ package org.tensorflow.ndarray.impl.sequence; -import java.util.Arrays; -import java.util.PrimitiveIterator; import org.tensorflow.ndarray.impl.dimension.DimensionalSpace; +import java.util.PrimitiveIterator; + public interface PositionIterator extends PrimitiveIterator.OfLong { static PositionIterator create(DimensionalSpace dimensions, int dimensionIdx) { diff --git a/ndarray/src/main/java/org/tensorflow/ndarray/impl/sequence/SequentialPositionIterator.java b/ndarray/src/main/java/org/tensorflow/ndarray/impl/sequence/SequentialPositionIterator.java index 71332ea..69caf17 100644 --- a/ndarray/src/main/java/org/tensorflow/ndarray/impl/sequence/SequentialPositionIterator.java +++ b/ndarray/src/main/java/org/tensorflow/ndarray/impl/sequence/SequentialPositionIterator.java @@ -17,10 +17,10 @@ package org.tensorflow.ndarray.impl.sequence; -import java.util.Arrays; -import java.util.NoSuchElementException; import org.tensorflow.ndarray.impl.dimension.DimensionalSpace; +import java.util.NoSuchElementException; + class SequentialPositionIterator implements PositionIterator { @Override diff --git a/ndarray/src/main/java/org/tensorflow/ndarray/impl/sparse/AbstractSparseNdArray.java b/ndarray/src/main/java/org/tensorflow/ndarray/impl/sparse/AbstractSparseNdArray.java index edd087f..8e3892d 100644 --- a/ndarray/src/main/java/org/tensorflow/ndarray/impl/sparse/AbstractSparseNdArray.java +++ b/ndarray/src/main/java/org/tensorflow/ndarray/impl/sparse/AbstractSparseNdArray.java @@ -14,11 +14,6 @@ =======================================================================*/ package org.tensorflow.ndarray.impl.sparse; -import java.nio.ReadOnlyBufferException; -import java.util.ArrayList; -import java.util.Arrays; -import java.util.List; -import java.util.stream.LongStream; import org.tensorflow.ndarray.IllegalRankException; import org.tensorflow.ndarray.LongNdArray; import org.tensorflow.ndarray.NdArray; @@ -35,6 +30,12 @@ import org.tensorflow.ndarray.impl.sequence.SlicingElementSequence; import org.tensorflow.ndarray.index.Index; +import java.nio.ReadOnlyBufferException; +import java.util.ArrayList; +import java.util.Arrays; +import java.util.List; +import java.util.stream.LongStream; + /** * Abstract base class for sparse array. * diff --git a/ndarray/src/main/java/org/tensorflow/ndarray/impl/sparse/hydrator/DoubleSparseNdArrayHydrator.java b/ndarray/src/main/java/org/tensorflow/ndarray/impl/sparse/hydrator/DoubleSparseNdArrayHydrator.java deleted file mode 100644 index f99631c..0000000 --- a/ndarray/src/main/java/org/tensorflow/ndarray/impl/sparse/hydrator/DoubleSparseNdArrayHydrator.java +++ /dev/null @@ -1,126 +0,0 @@ -package org.tensorflow.ndarray.impl.sparse.hydrator; - -import org.tensorflow.ndarray.DoubleNdArray; -import org.tensorflow.ndarray.NdArrays; -import org.tensorflow.ndarray.Shape; -import org.tensorflow.ndarray.hydrator.DoubleNdArrayHydrator; -import org.tensorflow.ndarray.hydrator.NdArrayHydrator; -import org.tensorflow.ndarray.impl.sparse.DoubleSparseNdArray; - -public class DoubleSparseNdArrayHydrator implements DoubleNdArrayHydrator { - - public DoubleSparseNdArrayHydrator(DoubleSparseNdArray array) { - this.array = array; - } - - @Override - public Scalars byScalars(long... coordinates) { - return new ScalarsImpl(coordinates); - } - - @Override - public Vectors byVectors(long... coordinates) { - return new VectorsImpl(coordinates); - } - - @Override - public Elements byElements(long... coordinates) { - return new ElementsImpl(coordinates); - } - - @Override - public NdArrayHydrator boxed() { - return new SparseNdArrayHydrator(array); - } - - private class ScalarsImpl implements Scalars { - - @Override - public Scalars at(long... coordinates) { - this.coordinates = Helpers.validateCoordinates(array, coordinates, 0); - return this; - } - - @Override - public Scalars put(double scalar) { - addValue(scalar, coordinates); - array.dimensions().incrementCoordinates(coordinates); - return this; - } - - private long[] coordinates; - - private ScalarsImpl(long[] coordinates) { - this.coordinates = Helpers.validateCoordinates(array, coordinates, 0); - } - } - - private class VectorsImpl implements Vectors { - - @Override - public Vectors at(long... coordinates) { - this.coordinates = Helpers.validateCoordinates(array, coordinates, 1); - return this; - } - - @Override - public Vectors put(double... vector) { - if (vector.length == 0 || vector.length > array.shape().get(-1)) { - throw new IllegalArgumentException("Vector cannot be null nor exceed " + array.shape().get(-1) + " elements"); - } - for (int i = 0; i < vector.length; ++i) { - addValue(vector[i], coordinates, i); - } - array.dimensions().incrementCoordinates(coordinates); - return this; - } - - private long[] coordinates; - - private VectorsImpl(long[] coordinates) { - this.coordinates = Helpers.validateCoordinates(array, coordinates, 1); - } - } - - private class ElementsImpl implements Elements { - - @Override - public Elements at(long... coordinates) { - this.coordinates = Helpers.validateCoordinates(array, coordinates); - return this; - } - - @Override - public Elements put(DoubleNdArray element) { - if (element == null) { - throw new IllegalArgumentException("Element cannot be null"); - } - if (element.shape().isScalar()) { - addValue(element.getDouble(), coordinates); - } else { - element.scalars().forEachIndexed((scalarCoords, scalar) -> { - addValue(scalar.getDouble(), coordinates, scalarCoords); - }); - } - array.dimensions().incrementCoordinates(coordinates); - return this; - } - - private long[] coordinates; - - private ElementsImpl(long[] coordinates) { - this.coordinates = Helpers.validateCoordinates(array, coordinates); - } - } - - private final DoubleSparseNdArray array; - private long valueCount = 0; - - private void addValue(double value, long[] origin, long... coords) { - if (value != array.getDefaultValue()) { - array.getValues().setDouble(value, valueCount); - Helpers.writeValueCoords(array, valueCount, origin, coords); - ++valueCount; - } - } -} diff --git a/ndarray/src/main/java/org/tensorflow/ndarray/impl/sparse/hydrator/Helpers.java b/ndarray/src/main/java/org/tensorflow/ndarray/impl/sparse/hydrator/Helpers.java deleted file mode 100644 index 440855f..0000000 --- a/ndarray/src/main/java/org/tensorflow/ndarray/impl/sparse/hydrator/Helpers.java +++ /dev/null @@ -1,49 +0,0 @@ -package org.tensorflow.ndarray.impl.sparse.hydrator; - -import java.util.Arrays; - -import org.tensorflow.ndarray.impl.dimension.DimensionalSpace; -import org.tensorflow.ndarray.impl.sparse.AbstractSparseNdArray; - -final class Helpers { - - static long[] validateCoordinates(AbstractSparseNdArray array, long[] coords, int elementRank) { - DimensionalSpace dimensions = array.dimensions(); - int dimensionIdx = 0; - if (elementRank >= 0) { - dimensionIdx = dimensions.numDimensions() - elementRank - 1; - if (dimensionIdx < 0) { - throw new IllegalArgumentException("Cannot hydrate array of shape " + array.shape() + " with elements of rank " + elementRank); - } - } - if (coords == null || coords.length == 0) { - return new long[dimensionIdx + 1]; - } - if ((coords.length - 1) != dimensionIdx) { - throw new IllegalArgumentException(Arrays.toString(coords) + " are not valid coordinates for dimension " - + dimensionIdx + " in an array of shape " + dimensions.shape()); - } - return Arrays.copyOf(coords, coords.length); - } - - static long[] validateCoordinates(AbstractSparseNdArray array, long[] coords) { - if (coords == null || coords.length == 0) { - return new long[1]; - } - int dimensionIdx = array.shape().numDimensions() - coords.length; - if (dimensionIdx < 0) { - throw new IllegalArgumentException("Cannot hydrate array of shape " + array.shape() + " with elements of rank " + (coords.length - 1)); - } - return Arrays.copyOf(coords, coords.length); - } - - static void writeValueCoords(AbstractSparseNdArray array, long valueIndex, long[] origin, long[] coords) { - int coordsIndex = 0; - for (long c: origin) { - array.getIndices().setLong(c, valueIndex, coordsIndex++); - } - for (long c: coords) { - array.getIndices().setLong(c, valueIndex, coordsIndex++); - } - } -} diff --git a/ndarray/src/main/java/org/tensorflow/ndarray/impl/sparse/hydrator/SparseNdArrayHydrator.java b/ndarray/src/main/java/org/tensorflow/ndarray/impl/sparse/hydrator/SparseNdArrayHydrator.java deleted file mode 100644 index 5d183f5..0000000 --- a/ndarray/src/main/java/org/tensorflow/ndarray/impl/sparse/hydrator/SparseNdArrayHydrator.java +++ /dev/null @@ -1,120 +0,0 @@ -package org.tensorflow.ndarray.impl.sparse.hydrator; - -import org.tensorflow.ndarray.NdArray; -import org.tensorflow.ndarray.NdArrays; -import org.tensorflow.ndarray.hydrator.NdArrayHydrator; -import org.tensorflow.ndarray.impl.sparse.AbstractSparseNdArray; - -public class SparseNdArrayHydrator implements NdArrayHydrator { - - public SparseNdArrayHydrator(AbstractSparseNdArray array) { - this.array = array; - } - - @Override - public Scalars byScalars(long... coordinates) { - return new ScalarsImpl(coordinates); - } - - @Override - public Vectors byVectors(long... coordinates) { - return new VectorsImpl(coordinates); - } - - @Override - public Elements byElements(long... coordinates) { - return new ElementsImpl(coordinates); - } - - private class ScalarsImpl implements Scalars { - - @Override - public Scalars at(long... coordinates) { - this.coordinates = Helpers.validateCoordinates(array, coordinates, 0); - return this; - } - - @Override - public Scalars put(T scalar) { - if (scalar == null) { - throw new IllegalArgumentException("Scalar cannot be null"); - } - if (scalar != array.getDefaultValue()) { - array.getValues().setObject(scalar, index); - array.getIndices().set(NdArrays.vectorOf(coordinates), index++); - } - array.dimensions().incrementCoordinates(coordinates); - return this; - } - - protected ScalarsImpl(long[] coordinates) { - this.coordinates = Helpers.validateCoordinates(array, coordinates, 0); - } - - protected long[] coordinates; - } - - private class VectorsImpl implements Vectors { - - @Override - public Vectors at(long... coordinates) { - this.coordinates = Helpers.validateCoordinates(array, coordinates, 1); - return this; - } - - @Override - public Vectors put(T... vector) { - if (vector.length == 0 || vector.length > array.shape().get(-1)) { - throw new IllegalArgumentException("Vector cannot be null nor exceed " + array.shape().get(-1) + " elements"); - } - for (T value : vector) { - if (value != array.getDefaultValue()) { - array.getValues().setObject(value, index); - array.getIndices().set(NdArrays.vectorOf(coordinates), index++); - } - array.dimensions().incrementCoordinates(coordinates); - } - return this; - } - - protected VectorsImpl(long[] coordinates) { - this.coordinates = Helpers.validateCoordinates(array, coordinates, 0); - } - - protected long[] coordinates; - } - - private class ElementsImpl implements Elements { - - @Override - public Elements at(long... coordinates) { - this.coordinates = Helpers.validateCoordinates(array, coordinates, coordinates.length - 1); - return this; - } - - @Override - public Elements put(NdArray element) { - if (element == null) { - throw new IllegalArgumentException("Array cannot be null"); - } - element.scalars().forEach(s -> { - T value = s.getObject(); - if (value != array.getDefaultValue()) { - array.getValues().setObject(value, index); - array.getIndices().set(NdArrays.vectorOf(coordinates), index++); - } - array.dimensions().incrementCoordinates(coordinates); - }); - return this; - } - - protected ElementsImpl(long[] coordinates) { - this.coordinates = Helpers.validateCoordinates(array, coordinates, coordinates.length - 1); - } - - protected long[] coordinates; - } - - private final AbstractSparseNdArray array; - private long index = 0; -} diff --git a/ndarray/src/main/java/org/tensorflow/ndarray/impl/sparse/initializer/BaseSparseNdArrayInitializer.java b/ndarray/src/main/java/org/tensorflow/ndarray/impl/sparse/initializer/BaseSparseNdArrayInitializer.java new file mode 100644 index 0000000..b54b833 --- /dev/null +++ b/ndarray/src/main/java/org/tensorflow/ndarray/impl/sparse/initializer/BaseSparseNdArrayInitializer.java @@ -0,0 +1,133 @@ +/* + Copyright 2022 The TensorFlow Authors. All Rights Reserved. + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. + ======================================================================= + */ +package org.tensorflow.ndarray.impl.sparse.initializer; + +import org.tensorflow.ndarray.NdArray; +import org.tensorflow.ndarray.NdArrays; +import org.tensorflow.ndarray.impl.initializer.AbstractNdArrayInitializer; +import org.tensorflow.ndarray.impl.sparse.AbstractSparseNdArray; +import org.tensorflow.ndarray.initializer.BaseNdArrayInitializer; + +import java.util.Arrays; +import java.util.Collection; + +abstract class BaseSparseNdArrayInitializer, V extends AbstractSparseNdArray> extends AbstractNdArrayInitializer implements BaseNdArrayInitializer { + + @Override + public Scalars byScalars(long... coordinates) { + return new ScalarsImpl(coordinates); + } + + @Override + public Elements byElements(int dimensionIdx, long... coordinates) { + return new ElementsImpl(dimensionIdx, coordinates); + } + + class ScalarsImpl implements Scalars { + + @Override + public Scalars to(long... coordinates) { + jumpTo(coordinates); + valueCoords = Arrays.copyOf(coordinates, array.shape().numDimensions()); + return this; + } + + @Override + public Scalars put(T value) { + if (value == null) { + throw new IllegalArgumentException("Scalar cannot be null"); + } + addValue(value); + next(); + return this; + } + + protected ScalarsImpl(long[] coordinates) { + resetTo(validateRankCoords(0, coordinates)); + } + } + + class VectorsImpl implements Vectors { + + @Override + public Vectors to(long... coordinates) { + jumpTo(coordinates); + valueCoords = Arrays.copyOf(coordinates, array.shape().numDimensions()); + return this; + } + + @Override + public Vectors put(Collection values) { + validateVectorLength(values.size()); + for (T value : values) { + addValue(value); + } + next(); + return this; + } + + protected VectorsImpl(long[] coordinates) { + resetTo(validateRankCoords(1, coordinates)); + } + } + + class ElementsImpl implements Elements { + + @Override + public Elements to(long... coordinates) { + jumpTo(coordinates); + valueCoords = Arrays.copyOf(coordinates, array.shape().numDimensions()); + return this; + } + + @Override + public Elements put(NdArray values) { + if (values.rank() != elementRank) { + throw new IllegalArgumentException("Values must be of element rank " + elementRank); + } + values.scalars().forEach(s -> { + addValue(s.getObject()); + }); + next(); + return this; + } + + protected ElementsImpl(int dimensionIdx, long[] coordinates) { + resetTo(validateDimensionCoords(dimensionIdx, coordinates)); + elementRank = array.shape().numDimensions() - dimensionIdx - 1; + } + + private final int elementRank; + } + + protected long valueCount = 0; + + protected long[] valueCoords; + + protected void addValue(T value) { + if (value != array.getDefaultValue()) { + array.getValues().setObject(value, valueCount); + array.getIndices().set(NdArrays.vectorOf(valueCoords), valueCount++); + } + array.dimensions().increment(valueCoords); + } + + BaseSparseNdArrayInitializer(V array) { + super(array); + this.valueCoords = new long[array.shape().numDimensions()]; + } +} diff --git a/ndarray/src/main/java/org/tensorflow/ndarray/impl/sparse/initializer/DoubleSparseNdArrayInitializer.java b/ndarray/src/main/java/org/tensorflow/ndarray/impl/sparse/initializer/DoubleSparseNdArrayInitializer.java new file mode 100644 index 0000000..f7dcb84 --- /dev/null +++ b/ndarray/src/main/java/org/tensorflow/ndarray/impl/sparse/initializer/DoubleSparseNdArrayInitializer.java @@ -0,0 +1,100 @@ +/* + Copyright 2022 The TensorFlow Authors. All Rights Reserved. + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. + ======================================================================= + */ +package org.tensorflow.ndarray.impl.sparse.initializer; + +import org.tensorflow.ndarray.DoubleNdArray; +import org.tensorflow.ndarray.NdArrays; +import org.tensorflow.ndarray.impl.sparse.DoubleSparseNdArray; +import org.tensorflow.ndarray.initializer.DoubleNdArrayInitializer; + +import java.util.Collection; + +public class DoubleSparseNdArrayInitializer extends BaseSparseNdArrayInitializer implements DoubleNdArrayInitializer { + + public DoubleSparseNdArrayInitializer(DoubleSparseNdArray array) { + super(array); + } + + @Override + public DoubleNdArrayInitializer.Scalars byScalars(long... coordinates) { + return new DoubleScalarsImpl(coordinates); + } + + @Override + public DoubleNdArrayInitializer.Vectors byVectors(long... coordinates) { + return new DoubleVectorsImpl(coordinates); + } + + private class DoubleScalarsImpl extends ScalarsImpl implements DoubleNdArrayInitializer.Scalars { + + @Override + public DoubleNdArrayInitializer.Scalars to(long... coordinates) { + return (DoubleScalarsImpl) super.to(coordinates); + } + + @Override + public DoubleNdArrayInitializer.Scalars put(Double value) { + return (DoubleScalarsImpl) super.put(value); + } + + @Override + public DoubleNdArrayInitializer.Scalars put(double scalar) { + addDoubleValue(scalar); + next(); + return this; + } + + private DoubleScalarsImpl(long[] coordinates) { + super(coordinates); + } + } + + private class DoubleVectorsImpl extends VectorsImpl implements DoubleNdArrayInitializer.Vectors { + + @Override + public DoubleNdArrayInitializer.Vectors to(long... coordinates) { + return (DoubleVectorsImpl) super.to(coordinates); + } + + @Override + public DoubleNdArrayInitializer.Vectors put(Collection values) { + return (DoubleVectorsImpl) super.put(values); + } + + @Override + public DoubleNdArrayInitializer.Vectors put(double... values) { + validateVectorLength(values.length); + for (double value : values) { + addDoubleValue(value); + } + next(); + return this; + } + + private DoubleVectorsImpl(long[] coordinates) { + super(coordinates); + } + } + + private void addDoubleValue(double value) { + if (value != array.getDefaultValue()) { + array.getValues().setDouble(value, valueCount); + array.getIndices().set(NdArrays.vectorOf(valueCoords), valueCount++); + } + array.dimensions().increment(valueCoords); + } +} diff --git a/ndarray/src/main/java/org/tensorflow/ndarray/impl/sparse/initializer/SparseNdArrayInitializer.java b/ndarray/src/main/java/org/tensorflow/ndarray/impl/sparse/initializer/SparseNdArrayInitializer.java new file mode 100644 index 0000000..59182ac --- /dev/null +++ b/ndarray/src/main/java/org/tensorflow/ndarray/impl/sparse/initializer/SparseNdArrayInitializer.java @@ -0,0 +1,62 @@ +/* + Copyright 2022 The TensorFlow Authors. All Rights Reserved. + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. + ======================================================================= + */ +package org.tensorflow.ndarray.impl.sparse.initializer; + +import org.tensorflow.ndarray.NdArray; +import org.tensorflow.ndarray.impl.sparse.AbstractSparseNdArray; +import org.tensorflow.ndarray.initializer.NdArrayInitializer; + +import java.util.Collection; + +public class SparseNdArrayInitializer, V extends AbstractSparseNdArray> extends BaseSparseNdArrayInitializer implements NdArrayInitializer { + + public SparseNdArrayInitializer(V array) { + super(array); + } + + @Override + public NdArrayInitializer.Vectors byVectors(long... coordinates) { + return new ObjectVectorsImpl(coordinates); + } + + class ObjectVectorsImpl extends VectorsImpl implements NdArrayInitializer.Vectors { + + @Override + public NdArrayInitializer.Vectors to(long... coordinates) { + return (ObjectVectorsImpl) super.to(coordinates); + } + + @Override + public NdArrayInitializer.Vectors put(Collection values) { + return (ObjectVectorsImpl) super.put(values); + } + + @Override + public NdArrayInitializer.Vectors put(T... values) { + validateVectorLength(values.length); + for (T value : values) { + addValue(value); + } + next(); + return this; + } + + ObjectVectorsImpl(long[] coordinates) { + super(coordinates); + } + } +} diff --git a/ndarray/src/main/java/org/tensorflow/ndarray/initializer/BaseNdArrayInitializer.java b/ndarray/src/main/java/org/tensorflow/ndarray/initializer/BaseNdArrayInitializer.java new file mode 100644 index 0000000..b999ff1 --- /dev/null +++ b/ndarray/src/main/java/org/tensorflow/ndarray/initializer/BaseNdArrayInitializer.java @@ -0,0 +1,223 @@ +/* + Copyright 2022 The TensorFlow Authors. All Rights Reserved. + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. + ======================================================================= + */ +package org.tensorflow.ndarray.initializer; + +import org.tensorflow.ndarray.NdArray; +import org.tensorflow.ndarray.buffer.DataBuffer; + +import java.util.Collection; + +/** + * Interface for initializing the data of a {@link NdArray} that has just been allocated. + * + *

The initializer API focuses on relative per-element initialization of the data of a newly allocated + * NdArray, which is more idiomatic than using output methods such as + * {@link NdArray#write(DataBuffer)} or {@link NdArray#copyTo(NdArray)}.

+ * + *

It also allows the initialization of {@link org.tensorflow.ndarray.SparseNdArray sparse arrays} before they + * become read-only.

+ * + * @param the type of data of the {@link NdArray} to initialize + */ +public interface BaseNdArrayInitializer { + + /** + * An API for initializing an {@link NdArray} using scalar values + * + * @param the type of data of the {@link NdArray} to initialize + */ + interface Scalars { + + /** + * Reset the position of the initializer in the NdArray so that the next values provided are + * written starting from the given {@code coordinates}. + * + *

Note that it is not possible to move backward within the array, {@code coordinates} must be equal or greater + * than the actual position of the initializer.

+ * + * @param coordinates position in the array + * @return this object + * @throws IllegalArgumentException if {@code coordinates} are empty or are not one of a scalar + */ + Scalars to(long... coordinates); + + /** + * Sets the next scalar value in the array. + * + * @param value next scalar value + * @return this object + */ + Scalars put(T value); + } + + /** + * An API for initializing an {@link NdArray} using vectors. + * + * @param the type of data of the {@link NdArray} to initialize + */ + interface Vectors { + + /** + * Reset the position of the initializer in the NdArray so that the next vectors provided are written + * starting from the given {@code coordinates}. + * + *

Note that it is not possible to move backward within the array, {@code coordinates} must be equal or greater + * than the actual position of the initializer.

+ * + * @param coordinates position in the array + * @return this object + * @throws IllegalArgumentException if {@code coordinates} are empty or are not one of a vector + */ + Vectors to(long... coordinates); + + /** + * Sets the next vector values in the array. + * + * @param values next vector values + * @return this object + * @throws IllegalArgumentException if {@code vector.length > array.shape().get(-1)} + */ + Vectors put(Collection values); + } + + /** + * An API for initializing an {@link NdArray} using n-dimensional elements (sub-arrays). + * + * @param the type of data of the {@link NdArray} to initialize + */ + interface Elements { + + /** + * Reset the position of the initializer in the NdArray so that the next elements provided are written + * starting from the given {@code coordinates}. + * + *

Note that it is not possible to move backward within the array, {@code coordinates} must be equal or greater + * than the actual position of the initializer.

+ * + * @param coordinates position in the array + * @return this object + * @throws IllegalArgumentException if {@code coordinates} are empty or are of a different dimension + */ + Elements to(long... coordinates); + + /** + * Sets the next element values in the array. + * + * @param values array containing the next element values + * @return this object + * @throws IllegalArgumentException if {@code element} is null or of the wrong rank + */ + Elements put(NdArray values); + } + + /** + * Per-scalar initialization of an {@link NdArray}. + * + *

Scalar initialization writes sequentially to an NdArray each individual values provided. Position + * can be reset to any scalar, across all dimensions.

+ * + *

If no {@code coordinates} are provided, the start position is the first scalar of this array.

+ * + * Example of usage: + *
{@code
+   *    NdArray array = NdArrays.ofObjects(String.class, Shape.of(3, 2), initializer -> {
+   *        initializer.byScalars()
+   *          .put("Cat")
+   *          .put("Dog")
+   *          .put("House")
+   *          .to(2, 1)
+   *          .put("Apple");
+   *    });
+   *    // -> [["Cat", "Dog"], ["House", null], [null, "Apple"]]
+   * }
+ * + * @param coordinates position of a scalar in the array to start initialization from, none for first scalar + * @return a {@link Scalars} instance + * @throws IllegalArgumentException if {@code coordinates} are set but are not one of a scalar + */ + Scalars byScalars(long... coordinates); + + /** + * Per-vector initialization of an {@link NdArray}. + * + *

Vector initialization writes sequentially provided values to vectors at the dimension n - 1 + * of an NdArray of rank n. The NdArray must therefore be of rank + * {@code > 0} (non-scalar).

+ * + *

Like in standard Java multidimensional arrays, it is possible to initialize partially a vector + * (i.e. having a number of values {@code < array.shape().get(-1)}). + * In such case, only the first values of the vector in the array will be initialized and the remaining will be + * left untouched (mostly defaulting to 0, depending on the type of buffer used to create the NdArray). + *

+ * + *

If no {@code coordinates} are provided, the start position is the the first vector of this array.

+ * + * Example of usage: + *
{@code
+   *    NdArray array = NdArrays.ofObjects(String.class, Shape.of(3, 2), initializer -> {
+   *        initializer.byVectors()
+   *          .put("Cat", "Dog")
+   *          .put("House") // partial initialization
+   *          .to(2)
+   *          .put("Orange", "Apple");
+   *    });
+   *    // -> [["Cat", "Dog"], ["House", null], ["Orange", "Apple"]]
+   * }
+ * + * @param coordinates position of a vector in the array to start initialization from, none for first vector + * @return a {@link Vectors} instance + * @throws IllegalArgumentException if array is of rank-0 or if {@code coordinates} are set but are not one of a vector + */ + Vectors byVectors(long... coordinates); + + /** + * Per-element initialization of an {@link NdArray}. + * + *

Element initialization writes sequentially values of provided arrays to elements at the dimension + * dimensionIdx of an NdArray. The provided arrays must be all the same rank, which matches + * the rank of the elements of the NdArray elements at this dimension.

+ * + *

If no {@code coordinates} are provided, the start position is the first element of dimension + * dimensionIdx of the array.

+ * + * Example of usage: + *
{@code
+   *    NdArray matrix = StdArrays.ndCopyOf(new String[][] {
+   *      { "Cat", "Apple" }, { "Dog", "Orange" }
+   *    });
+   *
+   *    NdArray array = NdArrays.ofObjects(String.class, Shape.of(4, 2, 2), initializer -> {
+   *        initializer.byElements(0)
+   *          .put(matrix)
+   *          .put(matrix)
+   *          .to(3)
+   *          .put(matrix);
+   *    });
+   *    // -> [[["Cat", "Apple"], ["Dog", "Orange"]],
+   *    //     [["Cat", "Apple"], ["Dog", "Orange"]],
+   *    //     [[null, null], [null, null]],
+   *    //     [["Cat", "Apple"], ["Dog", "Orange"]]]
+   * }
+ * + * @param dimensionIdx the index of the dimension being initialized + * @param coordinates position in the array to start from + * @return a {@link Elements} instance + * @throws IllegalArgumentException if {@code coordinates} are set but are not one of an element of the array or + * if {@code dimensionIdx >= array.shape().numDimensions()} + */ + Elements byElements(int dimensionIdx, long... coordinates); +} diff --git a/ndarray/src/main/java/org/tensorflow/ndarray/initializer/DoubleNdArrayInitializer.java b/ndarray/src/main/java/org/tensorflow/ndarray/initializer/DoubleNdArrayInitializer.java new file mode 100644 index 0000000..04693f5 --- /dev/null +++ b/ndarray/src/main/java/org/tensorflow/ndarray/initializer/DoubleNdArrayInitializer.java @@ -0,0 +1,76 @@ +/* + Copyright 2022 The TensorFlow Authors. All Rights Reserved. + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. + ======================================================================= + */ +package org.tensorflow.ndarray.initializer; + +import org.tensorflow.ndarray.DoubleNdArray; + +import java.util.Collection; + +/** + * Specialization of the {@link BaseNdArrayInitializer} API for initializing arrays of doubles. + * + * @see BaseNdArrayInitializer + */ +public interface DoubleNdArrayInitializer extends BaseNdArrayInitializer { + + /** + * An API for initializing an {@link DoubleNdArray} using scalar values + */ + interface Scalars extends BaseNdArrayInitializer.Scalars { + + @Override + Scalars to(long... coordinates); + + @Override + Scalars put(Double value); + + /** + * Set the next double value in the array. + * + * @param value next scalar value + * @return this object + */ + Scalars put(double value); + } + + /** + * An API for initializing an {@link DoubleNdArray} using vectors. + */ + interface Vectors extends BaseNdArrayInitializer.Vectors { + + @Override + Vectors to(long... coordinates); + + @Override + Vectors put(Collection values); + + /** + * Set the next vector double values in the array. + * + * @param values next vector values + * @return this object + * @throws IllegalArgumentException if {@code vector.length > array.shape().get(-1)} + */ + Vectors put(double... values); + } + + @Override + Scalars byScalars(long... coordinates); + + @Override + Vectors byVectors(long... coordinates); +} diff --git a/ndarray/src/main/java/org/tensorflow/ndarray/initializer/NdArrayInitializer.java b/ndarray/src/main/java/org/tensorflow/ndarray/initializer/NdArrayInitializer.java new file mode 100644 index 0000000..60346b0 --- /dev/null +++ b/ndarray/src/main/java/org/tensorflow/ndarray/initializer/NdArrayInitializer.java @@ -0,0 +1,55 @@ +/* + Copyright 2022 The TensorFlow Authors. All Rights Reserved. + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. + ======================================================================= + */ +package org.tensorflow.ndarray.initializer; + +import java.util.Collection; + +/** + * Specialization of the {@link BaseNdArrayInitializer} API for initializing arrays of objects. + * + * @see BaseNdArrayInitializer + * @param type of objects to initialize + */ +public interface NdArrayInitializer extends BaseNdArrayInitializer { + + /** + * {@inheritDoc} + */ + interface Vectors extends BaseNdArrayInitializer.Vectors { + + @Override + Vectors to(long... coordinates); + + @Override + Vectors put(Collection values); + + /** + * Set the next vector values in the array. + * + * @param values next vector values + * @return this object + * @throws IllegalArgumentException if {@code vector.length > array.shape().get(-1)} + */ + Vectors put(T... values); + } + + /** + * {@inheritDoc} + */ + @Override + Vectors byVectors(long... coordinates); +} diff --git a/ndarray/src/test/java/org/tensorflow/ndarray/impl/dense/hydrator/DoubleDenseNdArrayHydratorTest.java b/ndarray/src/test/java/org/tensorflow/ndarray/impl/dense/hydrator/DoubleDenseNdArrayHydratorTest.java deleted file mode 100644 index 63badf4..0000000 --- a/ndarray/src/test/java/org/tensorflow/ndarray/impl/dense/hydrator/DoubleDenseNdArrayHydratorTest.java +++ /dev/null @@ -1,24 +0,0 @@ -package org.tensorflow.ndarray.impl.dense.hydrator; - -import static org.junit.jupiter.api.Assertions.assertEquals; - -import java.util.function.Consumer; - -import org.junit.jupiter.api.Test; -import org.tensorflow.ndarray.DoubleNdArray; -import org.tensorflow.ndarray.NdArray; -import org.tensorflow.ndarray.NdArrays; -import org.tensorflow.ndarray.Shape; -import org.tensorflow.ndarray.StdArrays; -import org.tensorflow.ndarray.hydrator.DoubleNdArrayHydrator; -import org.tensorflow.ndarray.hydrator.DoubleNdArrayHydratorTestBase; -import org.tensorflow.ndarray.hydrator.NdArrayHydrator; -import org.tensorflow.ndarray.impl.dense.DoubleDenseNdArray; - -public class DoubleDenseNdArrayHydratorTest extends DoubleNdArrayHydratorTestBase { - - @Override - protected DoubleNdArray newArray(Shape shape, long numValues, Consumer hydrate) { - return NdArrays.ofDoubles(shape, hydrate); - } -} diff --git a/ndarray/src/test/java/org/tensorflow/ndarray/impl/dense/initializer/DoubleDenseNdArrayInitializerTest.java b/ndarray/src/test/java/org/tensorflow/ndarray/impl/dense/initializer/DoubleDenseNdArrayInitializerTest.java new file mode 100644 index 0000000..ef2c5a8 --- /dev/null +++ b/ndarray/src/test/java/org/tensorflow/ndarray/impl/dense/initializer/DoubleDenseNdArrayInitializerTest.java @@ -0,0 +1,33 @@ +/* + Copyright 2022 The TensorFlow Authors. All Rights Reserved. + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. + ======================================================================= + */ +package org.tensorflow.ndarray.impl.dense.initializer; + +import org.tensorflow.ndarray.DoubleNdArray; +import org.tensorflow.ndarray.NdArrays; +import org.tensorflow.ndarray.Shape; +import org.tensorflow.ndarray.initializer.DoubleNdArrayInitializer; +import org.tensorflow.ndarray.impl.initializer.DoubleNdArrayInitializerTestBase; + +import java.util.function.Consumer; + +public class DoubleDenseNdArrayInitializerTest extends DoubleNdArrayInitializerTestBase { + + @Override + protected DoubleNdArray newArray(Shape shape, long numValues, Consumer init) { + return NdArrays.ofDoubles(shape, init); + } +} diff --git a/ndarray/src/test/java/org/tensorflow/ndarray/impl/dense/initializer/StringDenseNdArrayInitializerTest.java b/ndarray/src/test/java/org/tensorflow/ndarray/impl/dense/initializer/StringDenseNdArrayInitializerTest.java new file mode 100644 index 0000000..2e9012a --- /dev/null +++ b/ndarray/src/test/java/org/tensorflow/ndarray/impl/dense/initializer/StringDenseNdArrayInitializerTest.java @@ -0,0 +1,33 @@ +/* + Copyright 2022 The TensorFlow Authors. All Rights Reserved. + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. + ======================================================================= + */ +package org.tensorflow.ndarray.impl.dense.initializer; + +import org.tensorflow.ndarray.NdArray; +import org.tensorflow.ndarray.NdArrays; +import org.tensorflow.ndarray.Shape; +import org.tensorflow.ndarray.initializer.NdArrayInitializer; +import org.tensorflow.ndarray.impl.initializer.StringNdArrayInitializerTestBase; + +import java.util.function.Consumer; + +public class StringDenseNdArrayInitializerTest extends StringNdArrayInitializerTestBase { + + @Override + protected NdArray newArray(Shape shape, long numValues, Consumer> init) { + return NdArrays.ofObjects(String.class, shape, init); + } +} diff --git a/ndarray/src/test/java/org/tensorflow/ndarray/impl/initializer/AbstractNdArrayInitializerTest.java b/ndarray/src/test/java/org/tensorflow/ndarray/impl/initializer/AbstractNdArrayInitializerTest.java new file mode 100644 index 0000000..38a7ac1 --- /dev/null +++ b/ndarray/src/test/java/org/tensorflow/ndarray/impl/initializer/AbstractNdArrayInitializerTest.java @@ -0,0 +1,68 @@ +package org.tensorflow.ndarray.impl.initializer; + +import org.junit.jupiter.api.Test; + +import static org.junit.jupiter.api.Assertions.assertThrows; + +public class AbstractNdArrayInitializerTest { + + @Test + public void newCoordsMovingForwardAreValid() { + AbstractNdArrayInitializer.validateNewCoords(new long[]{0, 0, 0}, new long[]{1, 0, 0}); + AbstractNdArrayInitializer.validateNewCoords(new long[]{0, 0, 0}, new long[]{0, 1, 0}); + AbstractNdArrayInitializer.validateNewCoords(new long[]{0, 0, 0}, new long[]{0, 0, 1}); + AbstractNdArrayInitializer.validateNewCoords(new long[]{0, 0, 0}, new long[]{1, 0, 1}); + AbstractNdArrayInitializer.validateNewCoords(new long[]{1, 0, 0}, new long[]{1, 0, 1}); + AbstractNdArrayInitializer.validateNewCoords(new long[]{1, 1, 0}, new long[]{1, 2, 0}); + AbstractNdArrayInitializer.validateNewCoords(new long[]{1, 1, 0}, new long[]{1, 1, 1}); + } + + @Test + public void newCoordsOfLowerRankMovingForwardAreValid() { + AbstractNdArrayInitializer.validateNewCoords(new long[]{0, 0, 0}, new long[]{1, 0}); + AbstractNdArrayInitializer.validateNewCoords(new long[]{0, 0}, new long[]{1}); + AbstractNdArrayInitializer.validateNewCoords(new long[]{1, 1, 0}, new long[]{1, 2}); + } + + @Test + public void newCoordsOfHigherRankMovingForwardAreValid() { + AbstractNdArrayInitializer.validateNewCoords(new long[]{0, 0}, new long[]{1, 0, 0}); + AbstractNdArrayInitializer.validateNewCoords(new long[]{0, 0}, new long[]{0, 1, 0}); + AbstractNdArrayInitializer.validateNewCoords(new long[]{1}, new long[]{1, 2}); + AbstractNdArrayInitializer.validateNewCoords(new long[]{1}, new long[]{2, 0}); + } + + @Test + public void newCoordsEqualsToActualAreValid() { + AbstractNdArrayInitializer.validateNewCoords(new long[]{1, 0, 0}, new long[]{1, 0, 0}); + AbstractNdArrayInitializer.validateNewCoords(new long[]{1, 0, 0}, new long[]{1, 0}); + AbstractNdArrayInitializer.validateNewCoords(new long[]{1, 0, 0}, new long[]{1}); + AbstractNdArrayInitializer.validateNewCoords(new long[]{1}, new long[]{1, 0}); + AbstractNdArrayInitializer.validateNewCoords(new long[]{0}, new long[]{0, 0, 0}); + } + + @Test + public void newCoordsMovingBackwardAreInvalid() { + assertThrows(IllegalArgumentException.class, () -> AbstractNdArrayInitializer.validateNewCoords(new long[]{1, 0, 0}, new long[]{0, 0, 0})); + assertThrows(IllegalArgumentException.class, () -> AbstractNdArrayInitializer.validateNewCoords(new long[]{0, 1, 0}, new long[]{0, 0, 0})); + assertThrows(IllegalArgumentException.class, () -> AbstractNdArrayInitializer.validateNewCoords(new long[]{0, 0, 1}, new long[]{0, 0, 0})); + assertThrows(IllegalArgumentException.class, () -> AbstractNdArrayInitializer.validateNewCoords(new long[]{1, 0, 1}, new long[]{0, 0, 0})); + assertThrows(IllegalArgumentException.class, () -> AbstractNdArrayInitializer.validateNewCoords(new long[]{1, 0, 1}, new long[]{1, 0, 0})); + assertThrows(IllegalArgumentException.class, () -> AbstractNdArrayInitializer.validateNewCoords(new long[]{1, 2, 0}, new long[]{1, 1, 0})); + assertThrows(IllegalArgumentException.class, () -> AbstractNdArrayInitializer.validateNewCoords(new long[]{1, 1, 1}, new long[]{1, 1, 0})); + } + + @Test + public void newCoordsLowerRankMovingBackwardAreInvalid() { + assertThrows(IllegalArgumentException.class, () -> AbstractNdArrayInitializer.validateNewCoords(new long[]{1, 0, 0}, new long[]{0, 0})); + assertThrows(IllegalArgumentException.class, () -> AbstractNdArrayInitializer.validateNewCoords(new long[]{0, 1, 0}, new long[]{0})); + assertThrows(IllegalArgumentException.class, () -> AbstractNdArrayInitializer.validateNewCoords(new long[]{2, 2, 0}, new long[]{2})); + } + + @Test + public void newCoordsHigherRankMovingBackwardAreInvalid() { + assertThrows(IllegalArgumentException.class, () -> AbstractNdArrayInitializer.validateNewCoords(new long[]{1}, new long[]{0, 0})); + assertThrows(IllegalArgumentException.class, () -> AbstractNdArrayInitializer.validateNewCoords(new long[]{0, 1}, new long[]{0, 0, 0})); + assertThrows(IllegalArgumentException.class, () -> AbstractNdArrayInitializer.validateNewCoords(new long[]{2, 2}, new long[]{2, 0, 0})); + } +} diff --git a/ndarray/src/test/java/org/tensorflow/ndarray/hydrator/DoubleNdArrayHydratorTestBase.java b/ndarray/src/test/java/org/tensorflow/ndarray/impl/initializer/DoubleNdArrayInitializerTestBase.java similarity index 55% rename from ndarray/src/test/java/org/tensorflow/ndarray/hydrator/DoubleNdArrayHydratorTestBase.java rename to ndarray/src/test/java/org/tensorflow/ndarray/impl/initializer/DoubleNdArrayInitializerTestBase.java index 69b38f9..48d59c5 100644 --- a/ndarray/src/test/java/org/tensorflow/ndarray/hydrator/DoubleNdArrayHydratorTestBase.java +++ b/ndarray/src/test/java/org/tensorflow/ndarray/impl/initializer/DoubleNdArrayInitializerTestBase.java @@ -1,24 +1,40 @@ -package org.tensorflow.ndarray.hydrator; +/* + Copyright 2022 The TensorFlow Authors. All Rights Reserved. -import static org.junit.jupiter.api.Assertions.assertEquals; -import static org.junit.jupiter.api.Assertions.fail; + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at -import java.util.function.Consumer; + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. + ======================================================================= + */ +package org.tensorflow.ndarray.impl.initializer; import org.junit.jupiter.api.Test; import org.tensorflow.ndarray.DoubleNdArray; import org.tensorflow.ndarray.NdArrays; import org.tensorflow.ndarray.Shape; import org.tensorflow.ndarray.StdArrays; +import org.tensorflow.ndarray.initializer.DoubleNdArrayInitializer; + +import java.util.function.Consumer; + +import static org.junit.jupiter.api.Assertions.assertEquals; -public abstract class DoubleNdArrayHydratorTestBase { +public abstract class DoubleNdArrayInitializerTestBase { - protected abstract DoubleNdArray newArray(Shape shape, long numValues, Consumer hydrate); + protected abstract DoubleNdArray newArray(Shape shape, long numValues, Consumer init); @Test - public void hydrateNdArrayByScalars() { - DoubleNdArray array = newArray(Shape.of(3, 2, 3), 14, hydrator -> { - hydrator + public void initializeNdArrayByScalars() { + DoubleNdArray array = newArray(Shape.of(3, 2, 3), 15, init -> { + init .byScalars() .put(0.0) .put(0.1) @@ -29,7 +45,7 @@ public void hydrateNdArrayByScalars() { .put(1.0) .put(1.1) .put(1.2) - .at(2, 0, 0) + .to(2, 0, 0) .put(2.0) .put(2.1) .put(2.2) @@ -44,13 +60,13 @@ public void hydrateNdArrayByScalars() { {{2.0, 2.1, 2.2}, {2.3, 2.4, 2.5}} }), array); - array = newArray(Shape.of(3, 2), 4, hydrator -> { - hydrator + array = newArray(Shape.of(3, 2), 4, init -> { + init .byScalars() .put(10.0) .put(20.0) .put(30.0) - .at(2, 1) + .to(2, 1) .put(40.0); }); @@ -58,14 +74,14 @@ public void hydrateNdArrayByScalars() { } @Test - public void hydrateNdArrayByVectors() { - DoubleNdArray array = newArray(Shape.of(3, 2, 3), 14, hydrator -> { - hydrator + public void initializeNdArrayByVectors() { + DoubleNdArray array = newArray(Shape.of(3, 2, 3), 15, init -> { + init .byVectors() .put(0.0, 0.1, 0.2) .put(0.3, 0.4, 0.5) .put(1.0, 1.1, 1.2) - .at(2, 0) + .to(2, 0) .put(2.0, 2.1, 2.2) .put(2.3, 2.4, 2.5); }); @@ -76,12 +92,12 @@ public void hydrateNdArrayByVectors() { {{2.0, 2.1, 2.2}, {2.3, 2.4, 2.5}} }), array); - array = newArray(Shape.of(3, 2), 5, hydrator -> { - hydrator + array = newArray(Shape.of(3, 2), 5, init -> { + init .byVectors() .put(10.0, 20.0) .put(30.0) - .at(2) + .to(2) .put(40.0, 50.0); }); @@ -89,27 +105,15 @@ public void hydrateNdArrayByVectors() { } @Test - public void vectorCannotBeEmpty() { - try { - newArray(Shape.of(3, 2), 1, hydrator -> hydrator.byVectors().put()); - fail(); - } catch (IllegalArgumentException e) { - // ok - } - } - - @Test - public void hydrateNdArrayByElements() { - DoubleNdArray array = newArray(Shape.of(3, 2, 3), 14, hydrator -> { - hydrator - .byElements() + public void initializeNdArrayByElements() { + DoubleNdArray array = newArray(Shape.of(3, 2, 3), 12, init -> { + init + .byElements(0) .put(StdArrays.ndCopyOf(new double[][]{ {0.0, 0.1, 0.2}, {0.3, 0.4, 0.5} })) - .at(1, 0) - .put(NdArrays.vectorOf(1.0, 1.1, 1.2)) - .at(2) + .to(2) .put(StdArrays.ndCopyOf(new double[][]{ {2.0, 2.1, 2.2}, {2.3, 2.4, 2.5} @@ -118,24 +122,27 @@ public void hydrateNdArrayByElements() { assertEquals(StdArrays.ndCopyOf(new double[][][]{ {{0.0, 0.1, 0.2}, {0.3, 0.4, 0.5}}, - {{1.0, 1.1, 1.2}, {0.0, 0.0, 0.0}}, + {{0.0, 0.0, 0.0}, {0.0, 0.0, 0.0}}, {{2.0, 2.1, 2.2}, {2.3, 2.4, 2.5}} }), array); DoubleNdArray vector = NdArrays.vectorOf(10.0, 20.0); - DoubleNdArray scalar = NdArrays.scalarOf(30.0); - array = newArray(Shape.of(4, 2), 7, hydrator -> { - hydrator - .byElements() + array = newArray(Shape.of(4, 2, 2), 8, init -> { + init + .byElements(1) .put(vector) .put(vector) - .at(2, 1) - .put(scalar) - .at(3) + .put(vector) + .to(3, 1) .put(vector); }); - assertEquals(StdArrays.ndCopyOf(new double[][]{{10.0, 20.0}, {10.0, 20.0}, {0.0, 30.0}, {10.0, 20.0}}), array); + assertEquals(StdArrays.ndCopyOf(new double[][][]{ + {{10.0, 20.0}, {10.0, 20.0}}, + {{10.0, 20.0}, {0.0, 0.0}}, + {{0.0, 0.0}, {0.0, 0.0}}, + {{0.0, 0.0}, {10.0, 20.0}} + }), array); } } diff --git a/ndarray/src/test/java/org/tensorflow/ndarray/impl/initializer/StringNdArrayInitializerTestBase.java b/ndarray/src/test/java/org/tensorflow/ndarray/impl/initializer/StringNdArrayInitializerTestBase.java new file mode 100644 index 0000000..1bfd7dd --- /dev/null +++ b/ndarray/src/test/java/org/tensorflow/ndarray/impl/initializer/StringNdArrayInitializerTestBase.java @@ -0,0 +1,146 @@ +/* + Copyright 2022 The TensorFlow Authors. All Rights Reserved. + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. + ======================================================================= + */ +package org.tensorflow.ndarray.impl.initializer; + +import org.junit.jupiter.api.Test; +import org.tensorflow.ndarray.NdArray; +import org.tensorflow.ndarray.NdArrays; +import org.tensorflow.ndarray.Shape; +import org.tensorflow.ndarray.StdArrays; +import org.tensorflow.ndarray.initializer.NdArrayInitializer; + +import java.util.function.Consumer; + +import static org.junit.jupiter.api.Assertions.assertEquals; + +public abstract class StringNdArrayInitializerTestBase { + + protected abstract NdArray newArray(Shape shape, long numValues, Consumer> init); + + @Test + public void initializeNdArrayByScalars() { + var array = newArray(Shape.of(3, 2, 3), 15, init -> { + init + .byScalars() + .put("0.0") + .put("0.1") + .put("0.2") + .put("0.3") + .put("0.4") + .put("0.5") + .put("1.0") + .put("1.1") + .put("1.2") + .to(2, 0, 0) + .put("2.0") + .put("2.1") + .put("2.2") + .put("2.3") + .put("2.4") + .put("2.5"); + }); + + assertEquals(StdArrays.ndCopyOf(new String[][][]{ + {{"0.0", "0.1", "0.2"}, {"0.3", "0.4", "0.5"}}, + {{"1.0", "1.1", "1.2"}, {null, null, null}}, + {{"2.0", "2.1", "2.2"}, {"2.3", "2.4", "2.5"}} + }), array); + + array = newArray(Shape.of(3, 2), 4, init -> { + init + .byScalars() + .put("10.0") + .put("20.0") + .put("30.0") + .to(2, 1) + .put("40.0"); + }); + + assertEquals(StdArrays.ndCopyOf(new String[][]{{"10.0", "20.0"}, {"30.0", null}, {null, "40.0"}}), array); + } + + @Test + public void initializeNdArrayByVectors() { + var array = newArray(Shape.of(3, 2, 3), 15, init -> { + init + .byVectors() + .put("0.0", "0.1", "0.2").put("0.3", "0.4", "0.5") + .put("1.0", "1.1", "1.2") + .to(2, 0) + .put("2.0", "2.1", "2.2").put("2.3", "2.4", "2.5"); + }); + + assertEquals(StdArrays.ndCopyOf(new String[][][]{ + {{"0.0", "0.1", "0.2"}, {"0.3", "0.4", "0.5"}}, + {{"1.0", "1.1", "1.2"}, {null, null, null}}, + {{"2.0", "2.1", "2.2"}, {"2.3", "2.4", "2.5"}} + }), array); + + array = newArray(Shape.of(3, 2), 5, init -> { + init + .byVectors() + .put("10.0", "20.0") + .put("30.0") + .to(2) + .put("40.0", "50.0"); + }); + + assertEquals(StdArrays.ndCopyOf(new String[][]{{"10.0", "20.0"}, {"30.0", null}, {"40.0", "50.0"}}), array); + } + + @Test + public void initializeNdArrayByElements() { + var array = newArray(Shape.of(3, 2, 3), 12, init -> { + init + .byElements(0) + .put(StdArrays.ndCopyOf(new String[][]{ + {"0.0", "0.1", "0.2"}, + {"0.3", "0.4", "0.5"} + })) + .to(2) + .put(StdArrays.ndCopyOf(new String[][]{ + {"2.0", "2.1", "2.2"}, + {"2.3", "2.4", "2.5"} + })); + }); + + assertEquals(StdArrays.ndCopyOf(new String[][][]{ + {{"0.0", "0.1", "0.2"}, {"0.3", "0.4", "0.5"}}, + {{null, null, null}, {null, null, null}}, + {{"2.0", "2.1", "2.2"}, {"2.3", "2.4", "2.5"}} + }), array); + + var vector = NdArrays.vectorOfObjects("10.0", "20.0"); + + array = newArray(Shape.of(4, 2, 2), 8, init -> { + init + .byElements(1) + .put(vector) + .put(vector) + .put(vector) + .to(3, 1) + .put(vector); + }); + + assertEquals(StdArrays.ndCopyOf(new String[][][]{ + {{"10.0", "20.0"}, {"10.0", "20.0"}}, + {{"10.0", "20.0"}, {null, null}}, + {{null, null}, {null, null}}, + {{null, null}, {"10.0", "20.0"}} + }), array); + } +} diff --git a/ndarray/src/test/java/org/tensorflow/ndarray/impl/sparse/hydrator/DoubleSparseNdArrayHydratorTest.java b/ndarray/src/test/java/org/tensorflow/ndarray/impl/sparse/hydrator/DoubleSparseNdArrayHydratorTest.java deleted file mode 100644 index d76922d..0000000 --- a/ndarray/src/test/java/org/tensorflow/ndarray/impl/sparse/hydrator/DoubleSparseNdArrayHydratorTest.java +++ /dev/null @@ -1,23 +0,0 @@ -package org.tensorflow.ndarray.impl.sparse.hydrator; - -import static org.junit.jupiter.api.Assertions.assertEquals; - -import java.util.function.Consumer; - -import org.junit.jupiter.api.Test; -import org.tensorflow.ndarray.DoubleNdArray; -import org.tensorflow.ndarray.NdArray; -import org.tensorflow.ndarray.NdArrays; -import org.tensorflow.ndarray.Shape; -import org.tensorflow.ndarray.StdArrays; -import org.tensorflow.ndarray.hydrator.DoubleNdArrayHydrator; -import org.tensorflow.ndarray.hydrator.DoubleNdArrayHydratorTestBase; -import org.tensorflow.ndarray.hydrator.NdArrayHydrator; - -public class DoubleSparseNdArrayHydratorTest extends DoubleNdArrayHydratorTestBase { - - @Override - protected DoubleNdArray newArray(Shape shape, long numValues, Consumer hydrate) { - return NdArrays.sparseOfDoubles(shape, numValues, hydrate); - } -} diff --git a/ndarray/src/test/java/org/tensorflow/ndarray/impl/sparse/hydrator/SparseNdArrayHydratorTest.java b/ndarray/src/test/java/org/tensorflow/ndarray/impl/sparse/hydrator/SparseNdArrayHydratorTest.java deleted file mode 100644 index 8ae36f5..0000000 --- a/ndarray/src/test/java/org/tensorflow/ndarray/impl/sparse/hydrator/SparseNdArrayHydratorTest.java +++ /dev/null @@ -1,85 +0,0 @@ -package org.tensorflow.ndarray.impl.sparse.hydrator; - -import static org.junit.jupiter.api.Assertions.assertEquals; - -import org.junit.jupiter.api.Test; -import org.tensorflow.ndarray.DoubleNdArray; -import org.tensorflow.ndarray.NdArrays; -import org.tensorflow.ndarray.Shape; -import org.tensorflow.ndarray.StdArrays; - -public class SparseNdArrayHydratorTest { - - @Test - public void hydrateNdArrayByScalars() { - DoubleNdArray array = NdArrays.sparseOfDoubles(15, Shape.of(3, 2, 3), hydrator -> { - hydrator - .byScalars() - .put(0.0) - .put(0.1) - .put(0.2) - .put(0.3) - .put(0.4) - .put(0.5) - .put(1.0) - .put(1.1) - .put(1.2) - .at(2, 0, 0) - .put(2.0) - .put(2.1) - .put(2.2) - .put(2.3) - .put(2.4) - .put(2.5); - }); - - assertEquals(StdArrays.ndCopyOf(new double[][][] { - {{ 0.0, 0.1, 0.2 }, { 0.3, 0.4, 0.5 }}, - {{ 1.0, 1.1, 1.2 }, { 0.0, 0.0, 0.0 }}, - {{ 2.0, 2.1, 2.2 }, { 2.3, 2.4, 2.5 }} - }), array); - } - - @Test - public void hydrateNdArrayByVectors() { - DoubleNdArray array = NdArrays.sparseOfDoubles(15, Shape.of(3, 2, 3), hydrator -> { - hydrator.byVectors() - .put(0.0, 0.1, 0.2) - .put(0.3, 0.4, 0.5) - .put(1.0, 1.1, 1.2) - .at(2, 0) - .put(2.0, 2.1, 2.2) - .put(2.3, 2.4, 2.5); - }); - - assertEquals(StdArrays.ndCopyOf(new double[][][] { - {{ 0.0, 0.1, 0.2 }, { 0.3, 0.4, 0.5 }}, - {{ 1.0, 1.1, 1.2 }, { 0.0, 0.0, 0.0 }}, - {{ 2.0, 2.1, 2.2 }, { 2.3, 2.4, 2.5 }} - }), array); - } - - @Test - public void hydrateNdArrayByElements() { - DoubleNdArray array = NdArrays.sparseOfDoubles(15, Shape.of(3, 2, 3), hydrator -> { - hydrator.byElements() - .put(StdArrays.ndCopyOf(new double[][] { - { 0.0, 0.1, 0.2 }, - { 0.3, 0.4, 0.5 } - })) - .at(1, 0) - .put(NdArrays.vectorOf(1.0, 1.1, 1.2)) - .at(2) - .put(StdArrays.ndCopyOf(new double[][] { - { 2.0, 2.1, 2.2 }, - { 2.3, 2.4, 2.5 } - })); - }); - - assertEquals(StdArrays.ndCopyOf(new double[][][] { - {{ 0.0, 0.1, 0.2 }, { 0.3, 0.4, 0.5 }}, - {{ 1.0, 1.1, 1.2 }, { 0.0, 0.0, 0.0 }}, - {{ 2.0, 2.1, 2.2 }, { 2.3, 2.4, 2.5 }} - }), array); - } -} diff --git a/ndarray/src/test/java/org/tensorflow/ndarray/impl/sparse/initializer/DoubleSparseNdArrayInitializerTest.java b/ndarray/src/test/java/org/tensorflow/ndarray/impl/sparse/initializer/DoubleSparseNdArrayInitializerTest.java new file mode 100644 index 0000000..b6b57cd --- /dev/null +++ b/ndarray/src/test/java/org/tensorflow/ndarray/impl/sparse/initializer/DoubleSparseNdArrayInitializerTest.java @@ -0,0 +1,33 @@ +/* + Copyright 2022 The TensorFlow Authors. All Rights Reserved. + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. + ======================================================================= + */ +package org.tensorflow.ndarray.impl.sparse.initializer; + +import org.tensorflow.ndarray.DoubleNdArray; +import org.tensorflow.ndarray.NdArrays; +import org.tensorflow.ndarray.Shape; +import org.tensorflow.ndarray.initializer.DoubleNdArrayInitializer; +import org.tensorflow.ndarray.impl.initializer.DoubleNdArrayInitializerTestBase; + +import java.util.function.Consumer; + +public class DoubleSparseNdArrayInitializerTest extends DoubleNdArrayInitializerTestBase { + + @Override + protected DoubleNdArray newArray(Shape shape, long numValues, Consumer init) { + return NdArrays.sparseOfDoubles(shape, numValues, init); + } +} diff --git a/ndarray/src/test/java/org/tensorflow/ndarray/impl/sparse/initializer/StringSparseNdArrayInitializerTest.java b/ndarray/src/test/java/org/tensorflow/ndarray/impl/sparse/initializer/StringSparseNdArrayInitializerTest.java new file mode 100644 index 0000000..0f2e8fa --- /dev/null +++ b/ndarray/src/test/java/org/tensorflow/ndarray/impl/sparse/initializer/StringSparseNdArrayInitializerTest.java @@ -0,0 +1,33 @@ +/* + Copyright 2022 The TensorFlow Authors. All Rights Reserved. + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. + ======================================================================= + */ +package org.tensorflow.ndarray.impl.sparse.initializer; + +import org.tensorflow.ndarray.NdArray; +import org.tensorflow.ndarray.NdArrays; +import org.tensorflow.ndarray.Shape; +import org.tensorflow.ndarray.initializer.NdArrayInitializer; +import org.tensorflow.ndarray.impl.initializer.StringNdArrayInitializerTestBase; + +import java.util.function.Consumer; + +public class StringSparseNdArrayInitializerTest extends StringNdArrayInitializerTestBase { + + @Override + protected NdArray newArray(Shape shape, long numValues, Consumer> init) { + return NdArrays.sparseOfObjects(String.class, shape, numValues, init); + } +}