diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/constraints/Constraint.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/constraints/Constraint.java index d3094b5e9e9..ca8a959d667 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/constraints/Constraint.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/constraints/Constraint.java @@ -42,6 +42,7 @@ public Constraint(Ops tf) { * * @param weights the weights * @return the constrained weights + * @param the date type for the weights and return value */ public abstract Operand call(Operand weights); diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/data/Dataset.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/data/Dataset.java index 7ac73f616e2..e221d829ec9 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/data/Dataset.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/data/Dataset.java @@ -23,16 +23,16 @@ import org.tensorflow.framework.data.impl.TakeDataset; import org.tensorflow.framework.data.impl.TensorSliceDataset; import org.tensorflow.framework.data.impl.TextLineDataset; +import org.tensorflow.ndarray.Shape; import org.tensorflow.op.Op; import org.tensorflow.op.Ops; -import org.tensorflow.ndarray.Shape; +import org.tensorflow.types.family.TType; import java.util.ArrayList; import java.util.Arrays; import java.util.Iterator; import java.util.List; import java.util.function.Function; -import org.tensorflow.types.family.TType; /** * Represents a potentially large list of independent elements (samples), and allows iteration and @@ -40,12 +40,25 @@ */ public abstract class Dataset implements Iterable>> { protected Ops tf; - private Operand variant; - private List> outputTypes; - private List outputShapes; + private final Operand variant; + private final List> outputTypes; + private final List outputShapes; + /** + * Creates a Dataset + * + * @param tf The TensorFlow Ops + * @param variant the Operand that represents the dataset. + * @param outputTypes A list of classes corresponding to the tensor type of each component of a + * dataset element. + * @param outputShapes A list of `Shape` objects corresponding to the shapes of each component of + * a dataset element. + */ public Dataset( - Ops tf, Operand variant, List> outputTypes, List outputShapes) { + Ops tf, + Operand variant, + List> outputTypes, + List outputShapes) { if (tf == null) { throw new IllegalArgumentException("Ops accessor cannot be null."); } @@ -61,6 +74,11 @@ public Dataset( this.outputShapes = outputShapes; } + /** + * Creates a dataset from another dataset + * + * @param other the other dataset + */ protected Dataset(Dataset other) { this.tf = other.tf; this.variant = other.variant; @@ -68,6 +86,53 @@ protected Dataset(Dataset other) { this.outputShapes = other.outputShapes; } + /** + * Creates an in-memory `Dataset` whose elements are slices of the given tensors. Each element of + * this dataset will be a {@code List>}, representing slices (e.g. batches) of the + * provided tensors. + * + * @param tf Ops Accessor + * @param tensors A list of {@code Operand} representing components of this dataset (e.g. + * features, labels) + * @param outputTypes A list of tensor type classes representing the data type of each component + * of this dataset. + * @return A new `Dataset` + */ + public static Dataset fromTensorSlices( + Ops tf, List> tensors, List> outputTypes) { + return new TensorSliceDataset(tf, tensors, outputTypes); + } + + /** + * Creates a Dataset comprising records from one or more TFRecord files. + * + * @param tf the TensorFlow Ops + * @param filename the name of the file containing the TFRecords + * @param compressionType the compression type, either "" (no compression), "ZLIB", or "GZIP" + * @param bufferSize the number of bytes in the read buffer + * @return A Dataset comprising records from a TFRecord file. + */ + public static Dataset tfRecordDataset( + Ops tf, String filename, String compressionType, long bufferSize) { + return new TFRecordDataset( + tf, tf.constant(filename), tf.constant(compressionType), tf.constant(bufferSize)); + } + + /** + * Creates a Dataset comprising lines from one or more text files. + * + * @param tf the TensorFlow Ops + * @param filename the name of the file containing the text linea + * @param compressionType the compression type, either "" (no compression), "ZLIB", or "GZIP" + * @param bufferSize the number of bytes in the read buffer + * @return A Dataset comprising lines from a text file. + */ + public static Dataset textLineDataset( + Ops tf, String filename, String compressionType, long bufferSize) { + return new TextLineDataset( + tf, tf.constant(filename), tf.constant(compressionType), tf.constant(bufferSize)); + } + /** * Groups elements of this dataset into batches. * @@ -127,11 +192,12 @@ public final Dataset take(long count) { * Returns a new Dataset which maps a function across all elements from this dataset, on a single * component of each element. * - *

For example, suppose each element is a {@code List>} with 2 components: (features, - * labels). + *

For example, suppose each element is a {@code List>} with 2 components: + * (features, labels). * - *

Calling {@code dataset.mapOneComponent(0, features -> tf.math.mul(features, tf.constant(2)))} will - * map the function over the `features` component of each element, multiplying each by 2. + *

Calling {@code dataset.mapOneComponent(0, features -> tf.math.mul(features, + * tf.constant(2)))} will map the function over the `features` component of each element, + * multiplying each by 2. * * @param index The index of the component to transform. * @param mapper The function to apply to the target component. @@ -150,8 +216,8 @@ public Dataset mapOneComponent(int index, Function, Operand> mappe * Returns a new Dataset which maps a function across all elements from this dataset, on all * components of each element. * - *

For example, suppose each element is a {@code List>} with 2 components: (features, - * labels). + *

For example, suppose each element is a {@code List>} with 2 components: + * (features, labels). * *

Calling {@code dataset.mapAllComponents(component -> tf.math.mul(component, * tf.constant(2)))} will map the function over the both the `features` and `labels` components of @@ -172,8 +238,8 @@ public Dataset mapAllComponents(Function, Operand> mapper) { /** * Returns a new Dataset which maps a function over all elements returned by this dataset. * - *

For example, suppose each element is a {@code List>} with 2 components: (features, - * labels). + *

For example, suppose each element is a {@code List>} with 2 components: + * (features, labels). * *

Calling * @@ -254,53 +320,42 @@ public DatasetIterator makeOneShotIterator() { } /** - * Creates an in-memory `Dataset` whose elements are slices of the given tensors. Each element of - * this dataset will be a {@code List>}, representing slices (e.g. batches) of the - * provided tensors. + * Gets the variant tensor representing this dataset. * - * @param tf Ops Accessor - * @param tensors A list of {@code Operand} representing components of this dataset (e.g. - * features, labels) - * @param outputTypes A list of tensor type classes representing the data type of each component of - * this dataset. - * @return A new `Dataset` + * @return the variant tensor representing this dataset. */ - public static Dataset fromTensorSlices( - Ops tf, List> tensors, List> outputTypes) { - return new TensorSliceDataset(tf, tensors, outputTypes); - } - - public static Dataset tfRecordDataset( - Ops tf, String filename, String compressionType, long bufferSize) { - return new TFRecordDataset( - tf, tf.constant(filename), tf.constant(compressionType), tf.constant(bufferSize)); - } - - public static Dataset textLineDataset( - Ops tf, String filename, String compressionType, long bufferSize) { - return new TextLineDataset( - tf, tf.constant(filename), tf.constant(compressionType), tf.constant(bufferSize)); - } - - /** Get the variant tensor representing this dataset. */ public Operand getVariant() { return variant; } - /** Get a list of output types for each component of this dataset. */ + /** + * Gets a list of output types for each component of this dataset. + * + * @return the list of output types for each component of this dataset. + */ public List> getOutputTypes() { return this.outputTypes; } - /** Get a list of shapes for each component of this dataset. */ + /** + * Gets a list of shapes for each component of this dataset. + * + * @return the list of shapes for each component of this dataset. + */ public List getOutputShapes() { return this.outputShapes; } + /** + * Gets the TensorFlow Ops Instance + * + * @return the TensorFlow Ops Instance + */ public Ops getOpsInstance() { return this.tf; } + /** {@inheritDoc} */ @Override public String toString() { return "Dataset{" diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/data/DatasetIterator.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/data/DatasetIterator.java index a3aa290a8c8..95ba4973aac 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/data/DatasetIterator.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/data/DatasetIterator.java @@ -17,14 +17,14 @@ import org.tensorflow.Graph; import org.tensorflow.Operand; +import org.tensorflow.ndarray.Shape; import org.tensorflow.op.Op; import org.tensorflow.op.Ops; -import org.tensorflow.ndarray.Shape; +import org.tensorflow.types.family.TType; import java.util.ArrayList; import java.util.Iterator; import java.util.List; -import org.tensorflow.types.family.TType; /** * Represents the state of an iteration through a tf.data Datset. DatasetIterator is not a @@ -102,21 +102,21 @@ public class DatasetIterator implements Iterable>> { public static final String EMPTY_SHARED_NAME = ""; protected Ops tf; - - private Operand iteratorResource; - private Op initializer; - protected List> outputTypes; protected List outputShapes; + private final Operand iteratorResource; + private Op initializer; /** + * Creates a DatasetIterator + * * @param tf Ops accessor corresponding to the same `ExecutionEnvironment` as the * `iteratorResource`. * @param iteratorResource An Operand representing the iterator (e.g. constructed from * `tf.data.iterator` or `tf.data.anonymousIterator`) * @param initializer An `Op` that should be run to initialize this iterator - * @param outputTypes A list of classes corresponding to the tensor type of each component of - * a dataset element. + * @param outputTypes A list of classes corresponding to the tensor type of each component of a + * dataset element. * @param outputShapes A list of `Shape` objects corresponding to the shapes of each component of * a dataset element. */ @@ -134,6 +134,18 @@ public DatasetIterator( this.outputShapes = outputShapes; } + /** + * Creates a DatasetIterator + * + * @param tf Ops accessor corresponding to the same `ExecutionEnvironment` as the + * `iteratorResource`. + * @param iteratorResource An Operand representing the iterator (e.g. constructed from + * `tf.data.iterator` or `tf.data.anonymousIterator`) + * @param outputTypes A list of classes corresponding to the tensor type of each component of a + * dataset element. + * @param outputShapes A list of `Shape` objects corresponding to the shapes of each component of + * a dataset element. + */ public DatasetIterator( Ops tf, Operand iteratorResource, @@ -145,6 +157,11 @@ public DatasetIterator( this.outputShapes = outputShapes; } + /** + * Creates a DatasetIterator from another DatasetIterator + * + * @param other the other DatasetIterator + */ protected DatasetIterator(DatasetIterator other) { this.tf = other.tf; this.iteratorResource = other.iteratorResource; @@ -153,6 +170,26 @@ protected DatasetIterator(DatasetIterator other) { this.outputShapes = other.outputShapes; } + /** + * Creates a new iterator from a "structure" defined by `outputShapes` and `outputTypes`. + * + * @param tf Ops accessor + * @param outputTypes A list of classes repesenting the tensor type of each component of a dataset + * element. + * @param outputShapes A list of Shape objects representing the shape of each component of a + * dataset element. + * @return A new DatasetIterator + */ + public static DatasetIterator fromStructure( + Ops tf, List> outputTypes, List outputShapes) { + Operand iteratorResource = + tf.scope().env() instanceof Graph + ? tf.data.iterator(EMPTY_SHARED_NAME, "", outputTypes, outputShapes) + : tf.data.anonymousIterator(outputTypes, outputShapes).handle(); + + return new DatasetIterator(tf, iteratorResource, outputTypes, outputShapes); + } + /** * Returns a list of {@code Operand} representing the components of the next dataset element. * @@ -226,37 +263,33 @@ public Op makeInitializer(Dataset dataset) { } /** - * Creates a new iterator from a "structure" defined by `outputShapes` and `outputTypes`. + * Gets the iteratorResource * - * @param tf Ops accessor - * @param outputTypes A list of classes repesenting the tensor type of each component of a - * dataset element. - * @param outputShapes A list of Shape objects representing the shape of each component of a - * dataset element. - * @return A new DatasetIterator + * @return the iteratorResource */ - public static DatasetIterator fromStructure( - Ops tf, List> outputTypes, List outputShapes) { - Operand iteratorResource = - tf.scope().env() instanceof Graph - ? tf.data.iterator(EMPTY_SHARED_NAME, "", outputTypes, outputShapes) - : tf.data.anonymousIterator(outputTypes, outputShapes).handle(); - - return new DatasetIterator(tf, iteratorResource, outputTypes, outputShapes); - } - public Operand getIteratorResource() { return iteratorResource; } + /** + * Gets the initializer + * + * @return the initializer + */ public Op getInitializer() { return initializer; } + /** + * Gets the TensorFlow Ops Instance + * + * @return the TensorFlow Ops Instance + */ public Ops getOpsInstance() { return tf; } + /** {@inheritDoc} */ @Override public Iterator>> iterator() { diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/data/DatasetOptional.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/data/DatasetOptional.java index 6617c33eaf7..8b394ce4282 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/data/DatasetOptional.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/data/DatasetOptional.java @@ -16,13 +16,13 @@ package org.tensorflow.framework.data; import org.tensorflow.Operand; -import org.tensorflow.op.Ops; import org.tensorflow.ndarray.Shape; +import org.tensorflow.op.Ops; import org.tensorflow.types.TBool; +import org.tensorflow.types.family.TType; import java.util.ArrayList; import java.util.List; -import org.tensorflow.types.family.TType; /** * An optional represents the result of a dataset getNext operation that may fail, when the end of @@ -30,23 +30,35 @@ */ public class DatasetOptional { protected Ops tf; - - public Operand getOptionalVariant() { - return optionalVariant; - } - - private Operand optionalVariant; - private List> outputTypes; - private List outputShapes; - + private final Operand optionalVariant; + private final List> outputTypes; + private final List outputShapes; + /** + * Creates a DatasetOptional + * + * @param tf the TensorFlow Ops + * @param optionalVariant the optional Operand that represents the dataset. + * @param outputTypes A list of classes corresponding to the tensor type of each component of a + * dataset element. + * @param outputShapes A list of `Shape` objects corresponding to the shapes of each component of + * a dataset element. + */ public DatasetOptional( - Ops tf, Operand optionalVariant, List> outputTypes, List outputShapes) { + Ops tf, + Operand optionalVariant, + List> outputTypes, + List outputShapes) { this.tf = tf; this.optionalVariant = optionalVariant; this.outputTypes = outputTypes; this.outputShapes = outputShapes; } + /** + * Creates a DatasetOptional from another DatasetOptional + * + * @param other the other DatasetOptional + */ protected DatasetOptional(DatasetOptional other) { this.tf = other.tf; this.optionalVariant = other.optionalVariant; @@ -54,14 +66,44 @@ protected DatasetOptional(DatasetOptional other) { this.outputShapes = other.outputShapes; } + /** + * Creates a DatasetOptional from a list of components + * + * @param tf the TensorFlow Ops + * @param components the components + * @param outputTypes A list of classes corresponding to the tensor type of each component of a + * dataset element. + * @param outputShapes A list of `Shape` objects corresponding to the shapes of each component of + * a dataset element. + * @return a DatasetOptional initialized with the components + */ + public static DatasetOptional fromComponents( + Ops tf, + List> components, + List> outputTypes, + List outputShapes) { + Operand optionalVariant = tf.data.optionalFromValue(components); + return new DatasetOptional(tf, optionalVariant, outputTypes, outputShapes); + } + public Operand getOptionalVariant() { + return optionalVariant; + } - /** Whether this optional has a value. */ + /** + * Gets the operand indicating whether this optional has a value. + * + * @return the operand indicating whether this optional has a value. + */ public Operand hasValue() { return tf.data.optionalHasValue(optionalVariant).hasValue(); } - /** Returns the value of the dataset element represented by this optional, if it exists. */ + /** + * Gets the value of the dataset element represented by this optional, if it exists. + * + * @return the value of the dataset element represented by this optional, if it exists. + */ public List> getValue() { List> components = new ArrayList<>(); tf.data @@ -72,15 +114,6 @@ public List> getValue() { return components; } - public static DatasetOptional fromComponents( - Ops tf, - List> components, - List> outputTypes, - List outputShapes) { - Operand optionalVariant = tf.data.optionalFromValue(components); - return new DatasetOptional(tf, optionalVariant, outputTypes, outputShapes); - } - public Ops getOpsInstance() { return tf; } diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/data/impl/BatchDataset.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/data/impl/BatchDataset.java index f0561b2e61e..ed563f59c69 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/data/impl/BatchDataset.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/data/impl/BatchDataset.java @@ -17,16 +17,31 @@ import org.tensorflow.Operand; import org.tensorflow.framework.data.Dataset; +import org.tensorflow.ndarray.Shape; import org.tensorflow.op.Ops; import org.tensorflow.op.core.Constant; -import org.tensorflow.ndarray.Shape; import org.tensorflow.types.TBool; import org.tensorflow.types.TInt64; +import org.tensorflow.types.family.TType; import java.util.List; -import org.tensorflow.types.family.TType; +/** Combines consecutive elements of a dataset into batches. */ public class BatchDataset extends Dataset { + + /** + * Creates a batched dataset + * + * @param tf The TensorFlow Ops. + * @param variant the Operand that represents the dataset. + * @param batchSize The number of desired elements per batch + * @param dropRemainder Whether to leave out the final batch if it has fewer than `batchSize` * + * elements. + * @param outputTypes A list of classes corresponding to the tensor type of each component of a + * dataset element. + * @param outputShapes A list of `Shape` objects corresponding to the shapes of each component of + * a dataset element. + */ public BatchDataset( Ops tf, Operand variant, diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/data/impl/MapDataset.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/data/impl/MapDataset.java index 18fb49173e6..82cfc458d5c 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/data/impl/MapDataset.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/data/impl/MapDataset.java @@ -22,19 +22,28 @@ import java.util.List; import java.util.function.Function; +/** Maps the elements in one dataset to another using mapper Operands. */ public class MapDataset extends Dataset { private final Function>, List>> mapper; + /** + * Creates a MapDataset from another dataset based on the mapper operations + * + * @param other the other dataset + * @param mapper the mapper operations + */ public MapDataset(Dataset other, Function>, List>> mapper) { super(other); this.mapper = mapper; } + /** {@inheritDoc} */ @Override public DatasetIterator makeOneShotIterator() { return new MapIterator(super.makeOneShotIterator(), mapper); } + /** {@inheritDoc} */ @Override public DatasetIterator makeInitializeableIterator() { return new MapIterator(super.makeInitializeableIterator(), mapper); diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/data/impl/MapIterator.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/data/impl/MapIterator.java index 2e494066e5f..3c9677ebf26 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/data/impl/MapIterator.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/data/impl/MapIterator.java @@ -18,27 +18,32 @@ import org.tensorflow.Operand; import org.tensorflow.framework.data.DatasetIterator; import org.tensorflow.framework.data.DatasetOptional; -import org.tensorflow.op.Ops; import java.util.List; -import java.util.function.BiFunction; import java.util.function.Function; +/** A dataset iterator that applies mapper operands across the elements of a dataset. */ public class MapIterator extends DatasetIterator { private final Function>, List>> mapper; - public MapIterator( - DatasetIterator source, - Function>, List>> mapper) { + /** + * Creates a MapIterator + * + * @param source the data source iterator to apply the mapper operands + * @param mapper the mapper operands + */ + public MapIterator(DatasetIterator source, Function>, List>> mapper) { super(source); this.mapper = mapper; } + /** {@inheritDoc} */ @Override public List> getNext() { return mapper.apply(super.getNext()); } + /** {@inheritDoc} */ @Override public DatasetOptional getNextAsOptional() { DatasetOptional optional = super.getNextAsOptional(); diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/data/impl/MapOptional.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/data/impl/MapOptional.java index d8f9d3b6924..c8aec3ec72b 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/data/impl/MapOptional.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/data/impl/MapOptional.java @@ -21,14 +21,24 @@ import java.util.List; import java.util.function.Function; +/** + * An optional represents the result of a MapDataset getNext operation that may fail, when the end + * of the dataset has been reached. + */ public class MapOptional extends DatasetOptional { private final Function>, List>> mapper; + /** + * Creates a MapOptional + * + * @param optional The source dataset optional + * @param mapper the mapper Operands + */ MapOptional(DatasetOptional optional, Function>, List>> mapper) { super(optional); this.mapper = mapper; } - + /** {@inheritDoc} */ @Override public List> getValue() { return mapper.apply(super.getValue()); diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/data/impl/SkipDataset.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/data/impl/SkipDataset.java index 63b4208480b..3cf06d387b7 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/data/impl/SkipDataset.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/data/impl/SkipDataset.java @@ -17,16 +17,29 @@ import org.tensorflow.Operand; import org.tensorflow.framework.data.Dataset; +import org.tensorflow.ndarray.Shape; import org.tensorflow.op.Ops; import org.tensorflow.op.core.Constant; -import org.tensorflow.ndarray.Shape; import org.tensorflow.types.TInt64; +import org.tensorflow.types.family.TType; import java.util.List; -import org.tensorflow.types.family.TType; +/** A Dataset that skips count elements from this dataset. */ public class SkipDataset extends Dataset { + /** + * Creates a Dataset that skips count elements from this dataset. + * + * @param tf The TensorFlow Ops + * @param variant the Operand that represents the dataset. + * @param count the number of elements of this dataset that should be skipped to form the new + * dataset. + * @param outputTypes A list of classes corresponding to the tensor type of each component of a + * dataset element. + * @param outputShapes A list of `Shape` objects corresponding to the shapes of each component of + * a dataset element. + */ public SkipDataset( Ops tf, Operand variant, diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/data/impl/TFRecordDataset.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/data/impl/TFRecordDataset.java index 00297152e90..b6aa59ee4fa 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/data/impl/TFRecordDataset.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/data/impl/TFRecordDataset.java @@ -17,15 +17,24 @@ import org.tensorflow.Operand; import org.tensorflow.framework.data.Dataset; -import org.tensorflow.op.Ops; import org.tensorflow.ndarray.Shape; +import org.tensorflow.op.Ops; import org.tensorflow.types.TInt64; import org.tensorflow.types.TString; import java.util.Collections; +/** A Dataset comprising records from one or more TFRecord files. */ public class TFRecordDataset extends Dataset { + /** + * Creates a Dataset comprising records from one or more TFRecord files. + * + * @param tf the TensorFlow Ops + * @param filenames the names of one or more files containing TFRecords + * @param compressionType the compression type, either "" (no compression), "ZLIB", or "GZIP" + * @param bufferSize the number of bytes in the read buffer + */ public TFRecordDataset( Ops tf, Operand filenames, diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/data/impl/TakeDataset.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/data/impl/TakeDataset.java index 39ca9759e74..49465ab5fd2 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/data/impl/TakeDataset.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/data/impl/TakeDataset.java @@ -17,16 +17,29 @@ import org.tensorflow.Operand; import org.tensorflow.framework.data.Dataset; +import org.tensorflow.ndarray.Shape; import org.tensorflow.op.Ops; import org.tensorflow.op.core.Constant; -import org.tensorflow.ndarray.Shape; import org.tensorflow.types.TInt64; +import org.tensorflow.types.family.TType; import java.util.List; -import org.tensorflow.types.family.TType; +/** A Dataset with at most count elements from this dataset. */ public class TakeDataset extends Dataset { + /** + * Creates a Dataset with at most count elements from this dataset. + * + * @param tf The TensorFlow Ops + * @param variant the Operand that represents the dataset. + * @param count the number of elements of this dataset that should be skipped to form the new + * dataset. + * @param outputTypes A list of classes corresponding to the tensor type of each component of a + * dataset element. + * @param outputShapes A list of `Shape` objects corresponding to the shapes of each component of + * a dataset element. + */ public TakeDataset( Ops tf, Operand variant, diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/data/impl/TensorSliceDataset.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/data/impl/TensorSliceDataset.java index 46639ea2aad..dcc703e3ac1 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/data/impl/TensorSliceDataset.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/data/impl/TensorSliceDataset.java @@ -17,23 +17,54 @@ import org.tensorflow.Operand; import org.tensorflow.framework.data.Dataset; -import org.tensorflow.op.Ops; import org.tensorflow.ndarray.Shape; +import org.tensorflow.op.Ops; +import org.tensorflow.types.family.TType; import java.util.List; import java.util.stream.Collectors; -import org.tensorflow.types.family.TType; +/** + * A Dataset whose elements are slices of the given tensors. + * + *

The given tensors are sliced along their first dimension. This operation preserves the + * structure of the input tensors, removing the first dimension of each tensor and using it as the + * dataset dimension. All input tensors must have the same size in their first dimensions. + */ public class TensorSliceDataset extends Dataset { - public TensorSliceDataset(Ops tf, List> components, List> outputTypes) { + /** + * Creates a Dataset whose elements are slices of the given tensors. + * + * @param tf the TensorFlow Ops + * @param components the conpoents to slice + * @param outputTypes A list of classes corresponding to the tensor type of each component of a + * dataset element. + */ + public TensorSliceDataset( + Ops tf, List> components, List> outputTypes) { super(tf, makeVariant(tf, components, outputTypes), outputTypes, outputShapes(components)); } + /** + * Gets the list of Shapes for the components. + * + * @param components the list of components + * @return the output shapes for the components. + */ private static List outputShapes(List> components) { return components.stream().map(c -> c.shape().tail()).collect(Collectors.toList()); } + /** + * Makes the variant Operand from the components + * + * @param tf the TensorFlow Ops + * @param components the list of components + * @param outputTypes list of classes corresponding to the tensor type of each component of a * + * dataset element. + * @return the variant Operand that represents this dataset. + */ private static Operand makeVariant( Ops tf, List> components, List> outputTypes) { if (!(components.size() == outputTypes.size())) { diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/data/impl/TextLineDataset.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/data/impl/TextLineDataset.java index c9a26304778..e9967c746f1 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/data/impl/TextLineDataset.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/data/impl/TextLineDataset.java @@ -17,15 +17,24 @@ import org.tensorflow.Operand; import org.tensorflow.framework.data.Dataset; -import org.tensorflow.op.Ops; import org.tensorflow.ndarray.Shape; +import org.tensorflow.op.Ops; import org.tensorflow.types.TInt64; import org.tensorflow.types.TString; import java.util.Collections; +/** A Dataset comprising lines from one or more text files. */ public class TextLineDataset extends Dataset { + /** + * Creates a Dataset comprising lines from one or more text files. + * + * @param tf the TensorFlow Ops + * @param filenames the names of one or more files containing the text lines + * @param compressionType the compression type, either "" (no compression), "ZLIB", or "GZIP" + * @param bufferSize the number of bytes in the read buffer + */ public TextLineDataset( Ops tf, Operand filenames,