Skip to content

NdArray Initializers API #14

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Draft
wants to merge 3 commits into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions ndarray/src/main/java/module-info.java
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
exports org.tensorflow.ndarray.buffer;
exports org.tensorflow.ndarray.buffer.layout;
exports org.tensorflow.ndarray.index;
exports org.tensorflow.ndarray.initializer;

// Expose all implementions of our interfaces, so consumers can write custom
// implementations easily by extending from them
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -33,12 +33,12 @@
public interface NdArraySequence<T extends NdArray<?>> extends Iterable<T> {

/**
* Visit each elements of this iteration and their respective coordinates.
* Visit each element of this iteration and their respective coordinates.
*
* <p><i>Important: the consumer method should not keep a reference to the coordinates
* as they might be mutable and reused during the iteration to improve performance.</i>
*
* @param consumer method to invoke for each elements
* @param consumer method to invoke for each element
*/
void forEachIndexed(BiConsumer<long[], T> consumer);

Expand All @@ -60,7 +60,7 @@ public interface NdArraySequence<T extends NdArray<?>> extends Iterable<T> {
* ndArray.elements(0).asSlices().forEach(e -> vectors::add); // Safe, each `e` is a distinct NdArray instance
* }</pre>
*
* @return a sequence that returns each elements iterated as a new slice
* @return a sequence that returns each element iterated as a new slice
* @see DataBufferWindow
*/
NdArraySequence<T> asSlices();
Expand Down
77 changes: 77 additions & 0 deletions ndarray/src/main/java/org/tensorflow/ndarray/NdArrays.java
Original file line number Diff line number Diff line change
Expand Up @@ -33,14 +33,23 @@
import org.tensorflow.ndarray.impl.dense.IntDenseNdArray;
import org.tensorflow.ndarray.impl.dense.LongDenseNdArray;
import org.tensorflow.ndarray.impl.dense.ShortDenseNdArray;
import org.tensorflow.ndarray.impl.dense.initializer.DenseNdArrayInitializer;
import org.tensorflow.ndarray.impl.dense.initializer.DoubleDenseNdArrayInitializer;
import org.tensorflow.ndarray.impl.dimension.DimensionalSpace;
import org.tensorflow.ndarray.impl.sparse.AbstractSparseNdArray;
import org.tensorflow.ndarray.impl.sparse.BooleanSparseNdArray;
import org.tensorflow.ndarray.impl.sparse.ByteSparseNdArray;
import org.tensorflow.ndarray.impl.sparse.DoubleSparseNdArray;
import org.tensorflow.ndarray.impl.sparse.FloatSparseNdArray;
import org.tensorflow.ndarray.impl.sparse.IntSparseNdArray;
import org.tensorflow.ndarray.impl.sparse.LongSparseNdArray;
import org.tensorflow.ndarray.impl.sparse.ShortSparseNdArray;
import org.tensorflow.ndarray.impl.sparse.initializer.DoubleSparseNdArrayInitializer;
import org.tensorflow.ndarray.impl.sparse.initializer.SparseNdArrayInitializer;
import org.tensorflow.ndarray.initializer.DoubleNdArrayInitializer;
import org.tensorflow.ndarray.initializer.NdArrayInitializer;

import java.util.function.Consumer;

/** Utility class for instantiating {@link NdArray} objects. */
public final class NdArrays {
Expand Down Expand Up @@ -555,6 +564,20 @@ public static DoubleNdArray ofDoubles(Shape shape) {
return wrap(shape, DataBuffers.ofDoubles(shape.size()));
}

/**
* Creates an N-dimensional array of doubles of the given shape, initializing its data after allocation.
*
* @param shape shape of the array
* @param init invoked to initialize the data of the allocated array
* @return new double N-dimensional array
* @throws IllegalArgumentException if shape is null or has unknown dimensions
*/
public static DoubleNdArray ofDoubles(Shape shape, Consumer<DoubleNdArrayInitializer> init) {
DoubleDenseNdArray array = (DoubleDenseNdArray)ofDoubles(shape);
init.accept(new DoubleDenseNdArrayInitializer(array));
return array;
}

/**
* Wraps a buffer in a double N-dimensional array of a given shape.
*
Expand All @@ -568,6 +591,23 @@ public static DoubleNdArray wrap(Shape shape, DoubleDataBuffer buffer) {
return DoubleDenseNdArray.create(buffer, shape);
}

/**
* Creates an Sparse array of doubles of the given shape, hydrating it with data after its allocation
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This says "hydrating" and the one above says "initializing". Similarly for the rest of the javadoc in this class, and also the parameter names.

*
* @param shape shape of the array
* @param numValues number of double value actually set in the array, others defaulting to the zero value
* @param hydrate initialize the data of the created array, using a hydrator
* @return new double N-dimensional array
* @throws IllegalArgumentException if shape is null or has unknown dimensions
*/
public static DoubleSparseNdArray sparseOfDoubles(Shape shape, long numValues, Consumer<DoubleNdArrayInitializer> hydrate) {
LongNdArray indices = ofLongs(Shape.of(numValues, shape.numDimensions()));
DoubleNdArray values = ofDoubles(Shape.of(numValues));
DoubleSparseNdArray array = DoubleSparseNdArray.create(indices, values, DimensionalSpace.create(shape));
hydrate.accept(new DoubleSparseNdArrayInitializer(array));
return array;
}

/**
* Creates a Sparse array of double values with a default value of zero
*
Expand Down Expand Up @@ -756,6 +796,22 @@ public static <T> NdArray<T> ofObjects(Class<T> clazz, Shape shape) {
return wrap(shape, DataBuffers.ofObjects(clazz, shape.size()));
}

/**
* Creates an N-dimensional array of objects of the given shape, hydrating it with data after its allocation
*
* @param clazz class of the data to be stored in this array
* @param shape shape of the array
* @param hydrate initialize the data of the created array, using a hydrator
* @param <T> type of object to store in this array
* @return new N-dimensional array
* @throws IllegalArgumentException if shape is null or has unknown dimensions
*/
public static <T> NdArray<T> ofObjects(Class<T> clazz, Shape shape, Consumer<NdArrayInitializer<T>> hydrate) {
var array = (DenseNdArray<T>)ofObjects(clazz, shape);
hydrate.accept(new DenseNdArrayInitializer<>(array));
return array;
}

/**
* Wraps a buffer in an N-dimensional array of a given shape.
*
Expand All @@ -770,6 +826,25 @@ public static <T> NdArray<T> wrap(Shape shape, DataBuffer<T> buffer) {
return DenseNdArray.wrap(buffer, shape);
}

/**
* Creates a Sparse array of objects of the given shape, hydrating it with data after its allocation
*
* @param type the class type represented by this sparse array.
* @param shape shape of the array
* @param numValues number of values actually set in the array, others defaulting to the zero value
* @param hydrate initialize the data of the created array, using a hydrator
* @param <T> type of object to store in this array
* @return new N-dimensional array
* @throws IllegalArgumentException if shape is null or has unknown dimensions
*/
public static <T> NdArray<T> sparseOfObjects(Class<T> type, Shape shape, long numValues, Consumer<NdArrayInitializer<T>> hydrate) {
LongNdArray indices = ofLongs(Shape.of(numValues, shape.numDimensions()));
NdArray<T> values = ofObjects(type, Shape.of(numValues));
AbstractSparseNdArray<T, ?> array = (AbstractSparseNdArray<T, ?>)sparseOfObjects(type, indices, values, shape);
hydrate.accept(new SparseNdArrayInitializer<>(array));
return array;
}

/**
* Creates a Sparse array of values with a null default value
*
Expand All @@ -783,6 +858,7 @@ public static <T> NdArray<T> wrap(Shape shape, DataBuffer<T> buffer) {
* values=["one", "two"]} specifies that element {@code [1,3,1]} of the sparse NdArray has a
* value of "one", and element {@code [2,4,0]} of the NdArray has a value of "two"". All other
* values are null.
* @param <T> type of object to store in this array
* @param shape the shape of the dense array represented by this sparse array.
* @return the float sparse array.
*/
Expand All @@ -807,6 +883,7 @@ public static <T> NdArray<T> sparseOfObjects(
* values are null.
* @param defaultValue Scalar value to set for indices not specified in 'indices'
* @param shape the shape of the dense array represented by this sparse array.
* @param <T> type of object to store in this array
* @return the float sparse array.
*/
public static <T> NdArray<T> sparseOfObjects(
Expand Down
14 changes: 3 additions & 11 deletions ndarray/src/main/java/org/tensorflow/ndarray/impl/Validator.java
Original file line number Diff line number Diff line change
Expand Up @@ -16,11 +16,12 @@
*/
package org.tensorflow.ndarray.impl;

import java.nio.BufferOverflowException;
import java.nio.BufferUnderflowException;
import org.tensorflow.ndarray.NdArray;
import org.tensorflow.ndarray.buffer.DataBuffer;

import java.nio.BufferOverflowException;
import java.nio.BufferUnderflowException;

public class Validator {

public static void copyToNdArrayArgs(NdArray<?> ndArray, NdArray<?> otherNdArray) {
Expand All @@ -42,14 +43,5 @@ public static void writeFromBufferArgs(NdArray<?> ndArray, DataBuffer<?> src) {
}
}

private static void copyArrayArgs(int arrayLength, int arrayOffset) {
if (arrayOffset < 0) {
throw new IndexOutOfBoundsException("Offset must be non-negative");
}
if (arrayOffset > arrayLength) {
throw new IndexOutOfBoundsException("Offset must be no larger than array length");
}
}

protected Validator() {}
}
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,20 @@
@SuppressWarnings("unchecked")
public abstract class AbstractDenseNdArray<T, U extends NdArray<T>> extends AbstractNdArray<T, U> {

abstract public DataBuffer<T> buffer();

public NdArraySequence<U> elementsAt(long[] startCoords) {
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The semantics of this method are still confusing to me. It doesn't check that startCoords is in bounds and the way in which it interacts with the starting position is odd (as I don't see why iteration from a particular starting co-ordinate should have fewer dimensions than the thing you're iterating).

DimensionalSpace elemDims = dimensions().from(startCoords.length);
try {
DataBufferWindow<? extends DataBuffer<T>> elemWindow = buffer().window(elemDims.physicalSize());
U element = instantiate(elemWindow.buffer(), elemDims);
return new FastElementSequence<T, U>(this, startCoords, element, elemWindow);
} catch (UnsupportedOperationException e) {
// If buffer windows are not supported, fallback to slicing (and slower) sequence
return new SlicingElementSequence<T, U>(this, startCoords, elemDims);
}
}

@Override
public NdArraySequence<U> elements(int dimensionIdx) {
if (dimensionIdx >= shape().numDimensions()) {
Expand All @@ -40,15 +54,7 @@ public NdArraySequence<U> elements(int dimensionIdx) {
if (rank() == 0 && dimensionIdx < 0) {
return new SingleElementSequence<>(this);
}
DimensionalSpace elemDims = dimensions().from(dimensionIdx + 1);
try {
DataBufferWindow<? extends DataBuffer<T>> elemWindow = buffer().window(elemDims.physicalSize());
U element = instantiate(elemWindow.buffer(), elemDims);
return new FastElementSequence(this, dimensionIdx, element, elemWindow);
} catch (UnsupportedOperationException e) {
// If buffer windows are not supported, fallback to slicing (and slower) sequence
return new SlicingElementSequence<>(this, dimensionIdx, elemDims);
}
return elementsAt(new long[dimensionIdx + 1]);
}

@Override
Expand Down Expand Up @@ -145,8 +151,6 @@ protected AbstractDenseNdArray(DimensionalSpace dimensions) {
super(dimensions);
}

abstract protected DataBuffer<T> buffer();

abstract U instantiate(DataBuffer<T> buffer, DimensionalSpace dimensions);

long positionOf(long[] coords, boolean isValue) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,11 @@ public static BooleanNdArray create(BooleanDataBuffer buffer, Shape shape) {
return new BooleanDenseNdArray(buffer, shape);
}

@Override
public BooleanDataBuffer buffer() {
return buffer;
}

@Override
public boolean getBoolean(long... indices) {
return buffer.getBoolean(positionOf(indices, true));
Expand Down Expand Up @@ -77,11 +82,6 @@ BooleanDenseNdArray instantiate(DataBuffer<Boolean> buffer, DimensionalSpace dim
return new BooleanDenseNdArray((BooleanDataBuffer)buffer, dimensions);
}

@Override
protected BooleanDataBuffer buffer() {
return buffer;
}

private final BooleanDataBuffer buffer;

private BooleanDenseNdArray(BooleanDataBuffer buffer, DimensionalSpace dimensions) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,11 @@ public static ByteNdArray create(ByteDataBuffer buffer, Shape shape) {
return new ByteDenseNdArray(buffer, shape);
}

@Override
public ByteDataBuffer buffer() {
return buffer;
}

@Override
public byte getByte(long... indices) {
return buffer.getByte(positionOf(indices, true));
Expand Down Expand Up @@ -77,11 +82,6 @@ ByteDenseNdArray instantiate(DataBuffer<Byte> buffer, DimensionalSpace dimension
return new ByteDenseNdArray((ByteDataBuffer)buffer, dimensions);
}

@Override
protected ByteDataBuffer buffer() {
return buffer;
}

private final ByteDataBuffer buffer;

private ByteDenseNdArray(ByteDataBuffer buffer, DimensionalSpace dimensions) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,7 @@ DenseNdArray<T> instantiate(DataBuffer<T> buffer, DimensionalSpace dimensions) {
}

@Override
protected DataBuffer<T> buffer() {
public DataBuffer<T> buffer() {
return buffer;
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,11 @@ public static DoubleNdArray create(DoubleDataBuffer buffer, Shape shape) {
return new DoubleDenseNdArray(buffer, shape);
}

@Override
public DoubleDataBuffer buffer() {
return buffer;
}

@Override
public double getDouble(long... indices) {
return buffer.getDouble(positionOf(indices, true));
Expand Down Expand Up @@ -77,11 +82,6 @@ DoubleDenseNdArray instantiate(DataBuffer<Double> buffer, DimensionalSpace dimen
return new DoubleDenseNdArray((DoubleDataBuffer)buffer, dimensions);
}

@Override
protected DoubleDataBuffer buffer() {
return buffer;
}

private final DoubleDataBuffer buffer;

private DoubleDenseNdArray(DoubleDataBuffer buffer, DimensionalSpace dimensions) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,11 @@ public static FloatNdArray create(FloatDataBuffer buffer, Shape shape) {
return new FloatDenseNdArray(buffer, shape);
}

@Override
public FloatDataBuffer buffer() {
return buffer;
}

@Override
public float getFloat(long... indices) {
return buffer.getFloat(positionOf(indices, true));
Expand Down Expand Up @@ -77,11 +82,6 @@ FloatDenseNdArray instantiate(DataBuffer<Float> buffer, DimensionalSpace dimensi
return new FloatDenseNdArray((FloatDataBuffer) buffer, dimensions);
}

@Override
public FloatDataBuffer buffer() {
return buffer;
}

private final FloatDataBuffer buffer;

private FloatDenseNdArray(FloatDataBuffer buffer, DimensionalSpace dimensions) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,11 @@ public static IntNdArray create(IntDataBuffer buffer, Shape shape) {
return new IntDenseNdArray(buffer, shape);
}

@Override
public IntDataBuffer buffer() {
return buffer;
}

@Override
public int getInt(long... indices) {
return buffer.getInt(positionOf(indices, true));
Expand Down Expand Up @@ -77,11 +82,6 @@ IntDenseNdArray instantiate(DataBuffer<Integer> buffer, DimensionalSpace dimensi
return new IntDenseNdArray((IntDataBuffer)buffer, dimensions);
}

@Override
protected IntDataBuffer buffer() {
return buffer;
}

private final IntDataBuffer buffer;

private IntDenseNdArray(IntDataBuffer buffer, DimensionalSpace dimensions) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,11 @@ public static LongNdArray create(LongDataBuffer buffer, Shape shape) {
return new LongDenseNdArray(buffer, shape);
}

@Override
public LongDataBuffer buffer() {
return buffer;
}

@Override
public long getLong(long... indices) {
return buffer.getLong(positionOf(indices, true));
Expand Down Expand Up @@ -77,11 +82,6 @@ LongDenseNdArray instantiate(DataBuffer<Long> buffer, DimensionalSpace dimension
return new LongDenseNdArray((LongDataBuffer)buffer, dimensions);
}

@Override
protected LongDataBuffer buffer() {
return buffer;
}

private final LongDataBuffer buffer;

private LongDenseNdArray(LongDataBuffer buffer, DimensionalSpace dimensions) {
Expand Down
Loading