Skip to content

Added Tensor.dataToString() #272

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 12 commits into
base: master
Choose a base branch
from
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
package org.tensorflow;

import java.util.function.Consumer;
import org.tensorflow.internal.types.Tensors;
import org.tensorflow.ndarray.Shape;
import org.tensorflow.ndarray.Shaped;
import org.tensorflow.ndarray.buffer.ByteDataBuffer;
Expand All @@ -26,9 +27,9 @@
* A statically typed multi-dimensional array.
*
* <p>There are two categories of tensors in TensorFlow Java: {@link TType typed tensors} and
* {@link RawTensor raw tensors}. The former maps the tensor native memory to an
* n-dimensional typed data space, allowing direct I/O operations from the JVM, while the latter
* is only a reference to a native tensor allowing basic operations and flat data access.</p>
* {@link RawTensor raw tensors}. The former maps the tensor native memory to an n-dimensional typed
* data space, allowing direct I/O operations from the JVM, while the latter is only a reference to
* a native tensor allowing basic operations and flat data access.</p>
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

There is a lot of (undesirable?) format change in this PR, can you please revert those and only preserve changes related to the dataToString feature?

*
* <p><b>WARNING:</b> Resources consumed by the Tensor object <b>must</b> be explicitly freed by
* invoking the {@link #close()} method when the object is no longer needed. For example, using a
Expand All @@ -49,15 +50,15 @@ public interface Tensor extends Shaped, AutoCloseable {
* <p>The amount of memory to allocate is derived from the datatype and the shape of the tensor,
* and is left uninitialized.
*
* @param <T> the tensor type
* @param type the tensor type class
* @param <T> the tensor type
* @param type the tensor type class
* @param shape shape of the tensor
* @return an allocated but uninitialized tensor
* @throws IllegalArgumentException if elements of the given {@code type} are of variable length
* (e.g. strings)
* @throws IllegalArgumentException if {@code shape} is totally or partially
* {@link Shape#hasUnknownDimension() unknown}
* @throws IllegalStateException if tensor failed to be allocated
* @throws IllegalArgumentException if {@code shape} is totally or partially {@link
* Shape#hasUnknownDimension() unknown}
* @throws IllegalStateException if tensor failed to be allocated
*/
static <T extends TType> T of(Class<T> type, Shape shape) {
return of(type, shape, -1);
Expand All @@ -67,27 +68,27 @@ static <T extends TType> T of(Class<T> type, Shape shape) {
* Allocates a tensor of a given datatype, shape and size.
*
* <p>This method is identical to {@link #of(Class, Shape)}, except that the final size of the
* tensor can be explicitly set instead of computing it from the datatype and shape, which could be
* larger than the actual space required to store the data but not smaller.
* tensor can be explicitly set instead of computing it from the datatype and shape, which could
* be larger than the actual space required to store the data but not smaller.
*
* @param <T> the tensor type
* @param type the tensor type class
* @param <T> the tensor type
* @param type the tensor type class
* @param shape shape of the tensor
* @param size size in bytes of the tensor or -1 to compute the size from the shape
* @param size size in bytes of the tensor or -1 to compute the size from the shape
* @return an allocated but uninitialized tensor
* @see #of(Class, Shape)
* @throws IllegalArgumentException if {@code size} is smaller than the minimum space required to
* store the tensor data
* @throws IllegalArgumentException if {@code size} is set to -1 but elements of the given
* {@code type} are of variable length (e.g. strings)
* @throws IllegalArgumentException if {@code shape} is totally or partially
* {@link Shape#hasUnknownDimension() unknown}
* @throws IllegalStateException if tensor failed to be allocated
* @throws IllegalArgumentException if {@code size} is set to -1 but elements of the given {@code
* type} are of variable length (e.g. strings)
* @throws IllegalArgumentException if {@code shape} is totally or partially {@link
* Shape#hasUnknownDimension() unknown}
* @throws IllegalStateException if tensor failed to be allocated
* @see #of(Class, Shape)
*/
static <T extends TType> T of(Class<T> type, Shape shape, long size) {
RawTensor tensor = RawTensor.allocate(type, shape, size);
try {
return (T)tensor.asTypedTensor();
return (T) tensor.asTypedTensor();
} catch (Exception e) {
tensor.close();
throw e;
Expand All @@ -111,16 +112,17 @@ static <T extends TType> T of(Class<T> type, Shape shape, long size) {
* <p>If {@code dataInitializer} fails and throws an exception, the allocated tensor will be
* automatically released before rethrowing the same exception.
*
* @param <T> the tensor type
* @param type the tensor type class
* @param shape shape of the tensor
* @param dataInitializer method receiving accessor to the allocated tensor data for initialization
* @param <T> the tensor type
* @param type the tensor type class
* @param shape shape of the tensor
* @param dataInitializer method receiving accessor to the allocated tensor data for
* initialization
* @return an allocated and initialized tensor
* @throws IllegalArgumentException if elements of the given {@code type} are of variable length
* (e.g. strings)
* @throws IllegalArgumentException if {@code shape} is totally or partially
* {@link Shape#hasUnknownDimension() unknown}
* @throws IllegalStateException if tensor failed to be allocated
* @throws IllegalArgumentException if {@code shape} is totally or partially {@link
* Shape#hasUnknownDimension() unknown}
* @throws IllegalStateException if tensor failed to be allocated
*/
static <T extends TType> T of(Class<T> type, Shape shape, Consumer<T> dataInitializer) {
return of(type, shape, -1, dataInitializer);
Expand All @@ -130,27 +132,31 @@ static <T extends TType> T of(Class<T> type, Shape shape, Consumer<T> dataInitia
* Allocates a tensor of a given datatype, shape and size.
*
* <p>This method is identical to {@link #of(Class, Shape, Consumer)}, except that the final
* size for the tensor can be explicitly set instead of being computed from the datatype and shape.
* size for the tensor can be explicitly set instead of being computed from the datatype and
* shape.
*
* <p>This could be useful for tensor types that stores data but also metadata in the tensor memory,
* <p>This could be useful for tensor types that stores data but also metadata in the tensor
* memory,
* such as the lookup table in a tensor of strings.
*
* @param <T> the tensor type
* @param type the tensor type class
* @param shape shape of the tensor
* @param size size in bytes of the tensor or -1 to compute the size from the shape
* @param dataInitializer method receiving accessor to the allocated tensor data for initialization
* @param <T> the tensor type
* @param type the tensor type class
* @param shape shape of the tensor
* @param size size in bytes of the tensor or -1 to compute the size from the shape
* @param dataInitializer method receiving accessor to the allocated tensor data for
* initialization
* @return an allocated and initialized tensor
* @see #of(Class, Shape, long, Consumer)
* @throws IllegalArgumentException if {@code size} is smaller than the minimum space required to
* store the tensor data
* @throws IllegalArgumentException if {@code size} is set to -1 but elements of the given
* {@code type} are of variable length (e.g. strings)
* @throws IllegalArgumentException if {@code shape} is totally or partially
* {@link Shape#hasUnknownDimension() unknown}
* @throws IllegalStateException if tensor failed to be allocated
* @throws IllegalArgumentException if {@code size} is set to -1 but elements of the given {@code
* type} are of variable length (e.g. strings)
* @throws IllegalArgumentException if {@code shape} is totally or partially {@link
* Shape#hasUnknownDimension() unknown}
* @throws IllegalStateException if tensor failed to be allocated
* @see #of(Class, Shape, long, Consumer)
*/
static <T extends TType> T of(Class<T> type, Shape shape, long size, Consumer<T> dataInitializer) {
static <T extends TType> T of(Class<T> type, Shape shape, long size,
Consumer<T> dataInitializer) {
T tensor = of(type, shape, size);
try {
dataInitializer.accept(tensor);
Expand All @@ -167,18 +173,19 @@ static <T extends TType> T of(Class<T> type, Shape shape, long size, Consumer<T>
* <p>Data must have been encoded into {@code data} as per the specification of the TensorFlow <a
* href="https://www.tensorflow.org/code/tensorflow/c/c_api.h">C API</a>.
*
* @param <T> the tensor type
* @param type the tensor type class
* @param shape the tensor shape.
* @param <T> the tensor type
* @param type the tensor type class
* @param shape the tensor shape.
* @param rawData a buffer containing the tensor raw data.
* @throws IllegalArgumentException if {@code rawData} is not large enough to contain the tensor
* data
* @throws IllegalArgumentException if {@code shape} is totally or partially
* {@link Shape#hasUnknownDimension() unknown}
* @throws IllegalStateException if tensor failed to be allocated with the given parameters
* @throws IllegalArgumentException if {@code shape} is totally or partially {@link
* Shape#hasUnknownDimension() unknown}
* @throws IllegalStateException if tensor failed to be allocated with the given parameters
*/
static <T extends TType> T of(Class<T> type, Shape shape, ByteDataBuffer rawData) {
return of(type, shape, rawData.size(), t -> rawData.copyTo(t.asRawTensor().data(), rawData.size()));
return of(type, shape, rawData.size(),
t -> rawData.copyTo(t.asRawTensor().data(), rawData.size()));
}

/**
Expand All @@ -191,6 +198,33 @@ static <T extends TType> T of(Class<T> type, Shape shape, ByteDataBuffer rawData
*/
long numBytes();

/**
* Returns the String representation of elements stored in the tensor.
*
* @param options overrides the default configuration
* @return the String representation of the tensor
* @throws IllegalStateException if this is an operand of a graph
*/
default String dataToString(ToStringOptions... options) {
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I don't think the use of a vararg to handle the optional presence of options is a good idea. Having a second method accepting no parameter would be better.

We use vararg options in the op wrappers because we want to limit the number of methods that ending up in the *Ops classes, which is already more than a thousand. But here it's fine "duplicating" it.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I agree with @karllessard

Integer maxWidth = null;
if (options != null) {
for (ToStringOptions opts : options) {
if (opts.maxWidth != null) {
maxWidth = opts.maxWidth;
}
}
}
return Tensors.toString(this, maxWidth);
}

/**
* @param maxWidth the maximum width of the output in characters ({@code null} if unlimited). This
* limit may surpassed if the first or last element are too long.
*/
static ToStringOptions maxWidth(Integer maxWidth) {
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'm not sure this method is in the right place. Maybe move it in ToStringOptions? Also you need to describe the returned value in the javadoc or the checks might complain.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I do it similar to this:
Layers.Options.create().inputShape(Shape.of(2,2))
without the ellipsis in the CTORs.

return new ToStringOptions().maxWidth(maxWidth);
}

/**
* Returns the shape of the tensor.
*/
Expand All @@ -212,4 +246,23 @@ static <T extends TType> T of(Class<T> type, Shape shape, ByteDataBuffer rawData
*/
@Override
void close();

public static class ToStringOptions {

/**
* Sets the maximum width of the output in characters.
*
* @param maxWidth the maximum width of the output in characters ({@code null} if unlimited).
* This limit may surpassed if the first or last element are too long.
*/
public ToStringOptions maxWidth(Integer maxWidth) {
this.maxWidth = maxWidth;
return this;
}

private Integer maxWidth;

private ToStringOptions() {
}
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,158 @@
package org.tensorflow.internal.types;

import java.util.ArrayList;
import java.util.Iterator;
import java.util.List;
import java.util.StringJoiner;
import org.tensorflow.Tensor;
import org.tensorflow.ndarray.NdArray;
import org.tensorflow.ndarray.Shape;

/**
* Tensor helper methods.
*/
public final class Tensors {

/**
* Prevent construction.
*/
private Tensors() {
}

/**
* Equivalent to {@link #toString(Tensor, Integer) toString(tensor, null)}.
*
* @param tensor a tensor
* @return the String representation of the tensor
*/
public static String toString(Tensor tensor) {
return toString(tensor, null);
}

/**
* @param tensor a tensor
* @param maxWidth the maximum width of the output in characters ({@code null} if unlimited). This
* limit may surpassed if the first or last element are too long.
* @return the String representation of the tensor
*/
public static String toString(Tensor tensor, Integer maxWidth) {
if (!(tensor instanceof NdArray)) {
throw new AssertionError("Expected tensor to extend NdArray");
}
NdArray<?> ndArray = (NdArray<?>) tensor;
Iterator<? extends NdArray<?>> iterator = ndArray.scalars().iterator();
Shape shape = tensor.shape();
if (shape.numDimensions() == 0) {
if (!iterator.hasNext()) {
return "";
}
return String.valueOf(iterator.next().getObject());
}
return toString(iterator, shape, 0, maxWidth);
}

/**
* @param iterator an iterator over the scalars
* @param shape the shape of the tensor
* @param maxWidth the maximum width of the output in characters ({@code null} if unlimited).
* This limit may surpassed if the first or last element are too long.
* @param dimension the current dimension being processed
* @return the String representation of the tensor data at {@code dimension}
*/
private static String toString(Iterator<? extends NdArray<?>> iterator, Shape shape,
int dimension, Integer maxWidth) {
if (dimension < shape.numDimensions() - 1) {
StringJoiner joiner = new StringJoiner(",\n", indent(dimension) + "[\n",
"\n" + indent(dimension) + "]");
for (long i = 0, size = shape.size(dimension); i < size; ++i) {
String element = toString(iterator, shape, dimension + 1, maxWidth);
joiner.add(element);
}
return joiner.toString();
}
if (maxWidth == null) {
StringJoiner joiner = new StringJoiner(", ", indent(dimension) + "[", "]");
for (long i = 0, size = shape.size(dimension); i < size; ++i) {
String element = iterator.next().getObject().toString();
joiner.add(element);
}
return joiner.toString();
}
List<Integer> lengths = new ArrayList<>();
StringJoiner joiner = new StringJoiner(", ", indent(dimension) + "[", "]");
int lengthBefore = "]".length();
for (long i = 0, size = shape.size(dimension); i < size; ++i) {
String element = iterator.next().getObject().toString();
joiner.add(element);
int addedLength = joiner.length() - lengthBefore;
lengths.add(addedLength);
lengthBefore += addedLength;
}
return truncateWidth(joiner.toString(), maxWidth, lengths);
}

/**
* @param input the input to truncate
* @param maxWidth the maximum width of the output in characters
* @param lengths the lengths of elements inside input
* @return the (potentially) truncated output
*/
private static String truncateWidth(String input, int maxWidth, List<Integer> lengths) {
if (input.length() <= maxWidth) {
return input;
}
StringBuilder output = new StringBuilder(input);
int midPoint = (maxWidth / 2) - 1;
int width = 0;
int indexOfElementToRemove = lengths.size() - 1;
int widthBeforeElementToRemove = 0;
for (int i = 0, size = lengths.size(); i < size; ++i) {
width += lengths.get(i);
if (width > midPoint) {
indexOfElementToRemove = i;
break;
}
widthBeforeElementToRemove = width;
}
if (indexOfElementToRemove == 0) {
// Cannot remove first element
return input;
}
output.insert(widthBeforeElementToRemove, ", ...");
widthBeforeElementToRemove += ", ...".length();
width = output.length();
while (width > maxWidth) {
if (indexOfElementToRemove == 0) {
// Cannot remove first element
break;
} else if (indexOfElementToRemove == lengths.size() - 1) {
// Cannot remove last element
--indexOfElementToRemove;
continue;
}
Integer length = lengths.remove(indexOfElementToRemove);
output.delete(widthBeforeElementToRemove, widthBeforeElementToRemove + length);
width = output.length();
}
if (output.length() < input.length()) {
return output.toString();
}
// Do not insert ellipses if it increases the length
return input;
}

/**
* @param level the level of indent
* @return the indentation string
*/
public static String indent(int level) {
if (level <= 0) {
return "";
}
StringBuilder result = new StringBuilder(level * 2);
for (int i = 0; i < level; ++i) {
result.append(" ");
}
return result.toString();
}
}
Loading