diff --git a/CONTRIBUTING.md b/CONTRIBUTING.md
index ad189bb59ff..24ede7001b4 100644
--- a/CONTRIBUTING.md
+++ b/CONTRIBUTING.md
@@ -170,10 +170,11 @@ For dependencies, we can use anything compliant with [this list](https://opensou
### Code generation
-Code generation for `Ops` and related classes is done during `tensorflow-core-api`'s `compile` phase, using the annotation processor in
-`tensorflow-core-generator`. If you change or add any operator classes (annotated with `org.tensorflow.op.annotation.Operator`), endpoint methods (
-annotated with `org.tensorflow.op.annotation.Endpoint`), or change the annotation processor, be sure to re-run a
-`mvn install` in `tensorflow-core-api` (`-Pdev` is fine for this, it just needs to run the annotation processor).
+Code generation for `Ops` and related classes is done during `tensorflow-core-api` and `tensorflow-core-kotlin`'s `compile` phase,
+using the annotation processors in `tensorflow-core-generator` and `tensorflow-kotlin-generator`, respectively. If you change or add any
+operator classes (annotated with `org.tensorflow.op.annotation.Operator`), endpoint methods (annotated with `org.tensorflow.op.annotation.Endpoint`),
+or change the annotation processor, be sure to re-run a `mvn compile` in `tensorflow-core-api` **and** `tensorflow-core-kotlin`
+(`-Pdev` is fine for this, it just needs to run the annotation processor).
### Working with Bazel generation
@@ -189,6 +190,19 @@ bazel-out/k8-opt/bin/external/org_tensorflow/tensorflow/libtensorflow_cc.so --ou
(called in `tensorflow-core-api`).
+### Kotlin API
+
+The Kotlin API should be kept to a thin wrapper of the Java API, using extension functions and codegen wherever possible.
+We do not want to get into a situation where we are maintaining two separate but related APIs.
+
+The codegen (`tensorflow-core-kotlin-generator`) is an annotation processor that reads the `@Operator` classes from the `tensorflow-core-api` Java sources.
+If you add operators or re-generate them from the native library, be sure to re-run a `mvn install` in `tensorflow-core-kotlin-api`.
+
+#### Formatting
+
+[ktfmt](https://github.com/facebookincubator/ktfmt) is used to format the Kotlin files. This is
+checked and done via maven in the same way as Java formatting. To do the formatting via IntelliJ see
+ktfmt's repo.
## Adding Gradients
diff --git a/README.md b/README.md
index 305fb1e759a..3eab22ea048 100644
--- a/README.md
+++ b/README.md
@@ -21,8 +21,13 @@ The following describes the layout of the repository and its different artifacts
* `tensorflow-core`
* All artifacts that build up the core language bindings of TensorFlow for Java
* Intended audience: projects that provide their own APIs or frameworks on top of
- TensorFlow and just want a thin layer to access the TensorFlow runtime from the JVM
+ TensorFlow and just want a thin layer to access the TensorFlow runtime from the JVM
+* `tensorflow-core-kotlin`
+ * Kotlin API bindings for `tensorflow-core`. These are thin wrappers around the core APIs
+ to make them more idiomatic for use in Kotlin, such as using parameters with default values
+   for operation builders instead of an `Options` vararg.
+
* `tensorflow-framework`
* Primary API for building and training neural networks with TensorFlow
* Intended audience: neural network developers
@@ -112,6 +117,12 @@ the platforms you are targeting. For this purpose the `-platform` artifacts incl
the conventions established on this page:
* [Reducing the Number of Dependencies](https://github.com/bytedeco/javacpp-presets/wiki/Reducing-the-Number-of-Dependencies)
+### Kotlin API
+
+Since the Kotlin API is just a wrapper of the Java API, it uses the Java platform artifacts instead of providing its own.
+To use, follow the instructions above for the Java API, but add `tensorflow-core-kotlin-api`,
+replacing `tensorflow-core-api` if you have explicitly included it.
+
### Snapshots
Snapshots of TensorFlow Java artifacts are automatically distributed after each update in the code. To use them, you need
diff --git a/pom.xml b/pom.xml
index f4f1b18928b..ed123e9228b 100644
--- a/pom.xml
+++ b/pom.xml
@@ -32,6 +32,7 @@
tensorflow-core
+ tensorflow-kotlin-parent
 tensorflow-framework
diff --git a/tensorflow-core/tensorflow-core-api/pom.xml b/tensorflow-core/tensorflow-core-api/pom.xml
index b6f9da1a2bd..142aac1065f 100644
--- a/tensorflow-core/tensorflow-core-api/pom.xml
+++ b/tensorflow-core/tensorflow-core-api/pom.xml
@@ -20,7 +20,7 @@
${native.build.skip}${native.build.skip}org.tensorflow.core.api
- 0.3.3
+ 0.4.0-SNAPSHOT1.0.1
diff --git a/tensorflow-core/tensorflow-core-api/src/gen/annotations/org/tensorflow/op/Ops.java b/tensorflow-core/tensorflow-core-api/src/gen/annotations/org/tensorflow/op/Ops.java
index 223754b0480..f1a3ab1dd79 100644
--- a/tensorflow-core/tensorflow-core-api/src/gen/annotations/org/tensorflow/op/Ops.java
+++ b/tensorflow-core/tensorflow-core-api/src/gen/annotations/org/tensorflow/op/Ops.java
@@ -340,7 +340,7 @@
* }
* }
*/
-public final class Ops {
+public final class Ops implements WithOps {
public final NnOps nn;
public final SummaryOps summary;
@@ -371,10 +371,10 @@ public final class Ops {
public final TpuOps tpu;
- public final AudioOps audio;
-
public final MathOps math;
+ public final AudioOps audio;
+
public final SignalOps signal;
public final TrainOps train;
@@ -400,8 +400,8 @@ public final class Ops {
sparse = new SparseOps(this);
bitwise = new BitwiseOps(this);
tpu = new TpuOps(this);
- audio = new AudioOps(this);
math = new MathOps(this);
+ audio = new AudioOps(this);
signal = new SignalOps(this);
train = new TrainOps(this);
quantization = new QuantizationOps(this);
@@ -8068,11 +8068,15 @@ public ZerosLike zerosLike(Operand x) {
return ZerosLike.create(scope, x);
}
+ @Override
+ public Ops tf() {
+ return this;
+ }
+
/**
- * Returns an API that builds operations with the provided name prefix.
- *
- * @see {@link Scope#withSubScope(String)}
+ * {@inheritDoc}
*/
+ @Override
public Ops withSubScope(String childScopeName) {
return new Ops(scope.withSubScope(childScopeName));
}
@@ -8109,28 +8113,25 @@ public T liftToInitScope(T op) {
}
/**
- * Returns an API that uses the provided name for an op.
- *
- * @see {@link Scope#withName(String)}
+ * {@inheritDoc}
*/
+ @Override
public Ops withName(String opName) {
return new Ops(scope.withName(opName));
}
/**
- * Returns an API that places the created operations on the device(s) matching the provided spec.
- *
- * @see {@link Scope#withDevice(DeviceSpec)}
+ * {@inheritDoc}
*/
+ @Override
public Ops withDevice(DeviceSpec deviceSpec) {
return new Ops(scope.withDevice(deviceSpec));
}
/**
- * Returns an API that adds operations to the graph with the provided control dependencies.
- *
- * @see {@link Scope#withControlDependencies(Iterable>)}
+ * {@inheritDoc}
*/
+ @Override
public Ops withControlDependencies(Iterable controls) {
return new Ops(scope.withControlDependencies(controls));
}
diff --git a/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/ExecutionEnvironment.java b/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/ExecutionEnvironment.java
index 87745138f01..845efa92fb8 100644
--- a/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/ExecutionEnvironment.java
+++ b/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/ExecutionEnvironment.java
@@ -16,10 +16,12 @@
package org.tensorflow;
import org.tensorflow.op.Op;
+import org.tensorflow.op.Ops;
import org.tensorflow.op.Scope;
+import org.tensorflow.op.WithOps;
/** Defines an environment for creating and executing TensorFlow {@link Operation}s. */
-public interface ExecutionEnvironment {
+public interface ExecutionEnvironment extends WithOps {
enum Types {
GRAPH,
@@ -126,4 +128,9 @@ default ExecutionEnvironment initEnv() {
*
Should generally only be used internally.
*/
boolean isInitOp(Operation op);
+
+ @Override
+ default Ops tf() {
+ return Ops.create(this);
+ }
}
diff --git a/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/Signature.java b/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/Signature.java
index 251f5a6e4b3..a171bbe3108 100644
--- a/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/Signature.java
+++ b/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/Signature.java
@@ -1,18 +1,18 @@
/* Copyright 2020-2021 The TensorFlow Authors. All Rights Reserved.
- Licensed under the Apache License, Version 2.0 (the "License");
- you may not use this file except in compliance with the License.
- You may obtain a copy of the License at
-
- http://www.apache.org/licenses/LICENSE-2.0
-
- Unless required by applicable law or agreed to in writing, software
- distributed under the License is distributed on an "AS IS" BASIS,
- WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- See the License for the specific language governing permissions and
- limitations under the License.
- =======================================================================
- */
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+=======================================================================
+*/
package org.tensorflow;
import java.util.Collections;
@@ -161,7 +161,7 @@ private static TensorInfo toTensorInfo(Output> operand) {
Shape shape = operand.shape();
TensorShapeProto.Builder tensorShapeBuilder = TensorShapeProto.newBuilder();
for (int i = 0; i < shape.numDimensions(); ++i) {
- tensorShapeBuilder.addDim(Dim.newBuilder().setSize(shape.size(i)));
+ tensorShapeBuilder.addDim(Dim.newBuilder().setSize(shape.get(i)));
}
return TensorInfo.newBuilder()
.setDtype(operand.dataType())
diff --git a/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/op/WithOps.java b/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/op/WithOps.java
new file mode 100644
index 00000000000..474127b4ca1
--- /dev/null
+++ b/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/op/WithOps.java
@@ -0,0 +1,73 @@
+/*
+ Copyright 2021 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+=======================================================================
+
+*/
+package org.tensorflow.op;
+
+import java.util.Arrays;
+import org.tensorflow.DeviceSpec;
+
+/** A context that provides a TensorFlow op builder. */
+public interface WithOps {
+
+ /** Get the op builder for this context. */
+ Ops tf();
+
+ /**
+ * Returns an API that builds operations with the provided name prefix.
+ *
+ * @see Scope#withSubScope(String)
+ */
+ default WithOps withSubScope(String childScopeName) {
+ return tf().withSubScope(childScopeName);
+ }
+
+ /**
+ * Returns an API that uses the provided name for an op.
+ *
+ * @see Scope#withName(String)
+ */
+ default WithOps withName(String opName) {
+ return tf().withName(opName);
+ }
+
+ /**
+ * Returns an API that places the created operations on the device(s) matching the provided spec.
+ *
+ * @see Scope#withDevice(DeviceSpec)
+ */
+ default WithOps withDevice(DeviceSpec deviceSpec) {
+ return tf().withDevice(deviceSpec);
+ }
+
+ /**
+ * Returns an API that adds operations to the graph with the provided control dependencies.
+ *
+ * @see Scope#withControlDependencies(Iterable)
+ */
+ default WithOps withControlDependencies(Iterable controls) {
+ return tf().withControlDependencies(controls);
+ }
+
+ /**
+ * Returns an API that adds operations to the graph with the provided control dependencies.
+ *
+ * @see Scope#withControlDependencies(Iterable)
+ */
+ default WithOps withControlDependencies(Op... controls) {
+ return withControlDependencies(Arrays.asList(controls));
+ }
+}
diff --git a/tensorflow-core/tensorflow-core-api/src/test/java/org/tensorflow/EagerOperationTest.java b/tensorflow-core/tensorflow-core-api/src/test/java/org/tensorflow/EagerOperationTest.java
index e2dc82f4c48..fbed7861ed9 100644
--- a/tensorflow-core/tensorflow-core-api/src/test/java/org/tensorflow/EagerOperationTest.java
+++ b/tensorflow-core/tensorflow-core-api/src/test/java/org/tensorflow/EagerOperationTest.java
@@ -55,8 +55,8 @@ public void outputDataTypeAndShape() {
.setAttr("value", t)
.build();
assertEquals(DataType.DT_INT32, op.dtype(0));
- assertEquals(2, op.shape(0).size(0));
- assertEquals(3, op.shape(0).size(1));
+ assertEquals(2, op.shape(0).get(0));
+ assertEquals(3, op.shape(0).get(1));
}
}
diff --git a/tensorflow-core/tensorflow-core-api/src/test/java/org/tensorflow/GraphOperationBuilderTest.java b/tensorflow-core/tensorflow-core-api/src/test/java/org/tensorflow/GraphOperationBuilderTest.java
index 84e1e56df56..5b9b8d059da 100644
--- a/tensorflow-core/tensorflow-core-api/src/test/java/org/tensorflow/GraphOperationBuilderTest.java
+++ b/tensorflow-core/tensorflow-core-api/src/test/java/org/tensorflow/GraphOperationBuilderTest.java
@@ -144,8 +144,8 @@ public void setAttrShape() {
.build()
.output(0);
assertEquals(2, n.shape().numDimensions());
- assertEquals(-1, n.shape().size(0));
- assertEquals(784, n.shape().size(1));
+ assertEquals(-1, n.shape().get(0));
+ assertEquals(784, n.shape().get(1));
assertEquals(DataType.DT_FLOAT, n.dataType());
}
}
diff --git a/tensorflow-core/tensorflow-core-api/src/test/java/org/tensorflow/SavedModelBundleTest.java b/tensorflow-core/tensorflow-core-api/src/test/java/org/tensorflow/SavedModelBundleTest.java
index be6f952fb6a..8e3f742b6bb 100644
--- a/tensorflow-core/tensorflow-core-api/src/test/java/org/tensorflow/SavedModelBundleTest.java
+++ b/tensorflow-core/tensorflow-core-api/src/test/java/org/tensorflow/SavedModelBundleTest.java
@@ -200,7 +200,7 @@ public void exportFunctionWithVariables() throws IOException {
assertNotNull(inputInfo);
assertEquals(xyShape.numDimensions(), inputInfo.getTensorShape().getDimCount());
for (int i = 0; i < xyShape.numDimensions(); ++i) {
- assertEquals(xyShape.size(i), inputInfo.getTensorShape().getDim(i).getSize());
+ assertEquals(xyShape.get(i), inputInfo.getTensorShape().getDim(i).getSize());
}
TensorInfo outputInfo = signatureDef.getOutputsMap().get("reducedSum");
diff --git a/tensorflow-core/tensorflow-core-api/src/test/java/org/tensorflow/TensorTest.java b/tensorflow-core/tensorflow-core-api/src/test/java/org/tensorflow/TensorTest.java
index 9415a986222..0d3015d0445 100644
--- a/tensorflow-core/tensorflow-core-api/src/test/java/org/tensorflow/TensorTest.java
+++ b/tensorflow-core/tensorflow-core-api/src/test/java/org/tensorflow/TensorTest.java
@@ -66,7 +66,7 @@ public void createWithRawData() {
Shape strings_shape = Shape.scalar();
byte[] strings_; // raw TF_STRING
try (TString t = TString.tensorOf(NdArrays.scalarOfObject(strings))) {
- strings_ = new byte[(int)t.numBytes()];
+ strings_ = new byte[(int) t.numBytes()];
t.asRawTensor().data().read(strings_);
}
@@ -86,8 +86,11 @@ public void createWithRawData() {
// validate creating a tensor using a direct byte buffer (in host order)
{
- DoubleBuffer buf = ByteBuffer.allocateDirect(8 * doubles.length).order(ByteOrder.nativeOrder())
- .asDoubleBuffer().put(doubles);
+ DoubleBuffer buf =
+ ByteBuffer.allocateDirect(8 * doubles.length)
+ .order(ByteOrder.nativeOrder())
+ .asDoubleBuffer()
+ .put(doubles);
try (TFloat64 t = TFloat64.tensorOf(doubles_shape, d -> d.write(DataBuffers.of(buf)))) {
double[] actual = new double[doubles.length];
t.read(DataBuffers.of(actual));
@@ -140,10 +143,10 @@ public void createFromBufferWithNonNativeByteOrder() {
@Test
public void createWithTypedBuffer() {
- IntBuffer ints = IntBuffer.wrap(new int[]{1, 2, 3, 4});
- FloatBuffer floats = FloatBuffer.wrap(new float[]{1f, 2f, 3f, 4f});
- DoubleBuffer doubles = DoubleBuffer.wrap(new double[]{1d, 2d, 3d, 4d});
- LongBuffer longs = LongBuffer.wrap(new long[]{1L, 2L, 3L, 4L});
+ IntBuffer ints = IntBuffer.wrap(new int[] {1, 2, 3, 4});
+ FloatBuffer floats = FloatBuffer.wrap(new float[] {1f, 2f, 3f, 4f});
+ DoubleBuffer doubles = DoubleBuffer.wrap(new double[] {1d, 2d, 3d, 4d});
+ LongBuffer longs = LongBuffer.wrap(new long[] {1L, 2L, 3L, 4L});
// validate creating a tensor using a typed buffer
{
@@ -243,7 +246,7 @@ public void readFromRawData() {
// validate the use of direct buffers
{
ByteBuffer bbuf =
- ByteBuffer.allocateDirect((int)tdoubles.numBytes()).order(ByteOrder.nativeOrder());
+ ByteBuffer.allocateDirect((int) tdoubles.numBytes()).order(ByteOrder.nativeOrder());
tdoubles.asRawTensor().data().copyTo(DataBuffers.of(bbuf), tdoubles.numBytes());
assertEquals(doubles[0], bbuf.asDoubleBuffer().get(0), EPSILON);
}
@@ -251,13 +254,17 @@ public void readFromRawData() {
// validate byte order conversion
{
DoubleBuffer foreignBuf =
- ByteBuffer.allocate((int)tdoubles.numBytes())
+ ByteBuffer.allocate((int) tdoubles.numBytes())
.order(
ByteOrder.nativeOrder() == ByteOrder.LITTLE_ENDIAN
? ByteOrder.BIG_ENDIAN
: ByteOrder.LITTLE_ENDIAN)
.asDoubleBuffer();
- tdoubles.asRawTensor().data().asDoubles().copyTo(DataBuffers.of(foreignBuf), foreignBuf.capacity());
+ tdoubles
+ .asRawTensor()
+ .data()
+ .asDoubles()
+ .copyTo(DataBuffers.of(foreignBuf), foreignBuf.capacity());
double[] actual = new double[foreignBuf.remaining()];
foreignBuf.get(actual);
assertArrayEquals(doubles, actual, EPSILON);
@@ -320,51 +327,55 @@ public void scalars() {
@Test
public void nDimensional() {
- DoubleNdArray vector = StdArrays.ndCopyOf(new double[]{1.414, 2.718, 3.1415});
+ DoubleNdArray vector = StdArrays.ndCopyOf(new double[] {1.414, 2.718, 3.1415});
try (TFloat64 t = TFloat64.tensorOf(vector)) {
assertEquals(TFloat64.class, t.type());
assertEquals(DataType.DT_DOUBLE, t.dataType());
assertEquals(1, t.shape().numDimensions());
- assertEquals(3, t.shape().size(0));
+ assertEquals(3, t.shape().get(0));
assertEquals(vector, t);
}
- IntNdArray matrix = StdArrays.ndCopyOf(new int[][]{{1, 2, 3}, {4, 5, 6}});
+ IntNdArray matrix = StdArrays.ndCopyOf(new int[][] {{1, 2, 3}, {4, 5, 6}});
try (TInt32 t = TInt32.tensorOf(matrix)) {
assertEquals(TInt32.class, t.type());
assertEquals(DataType.DT_INT32, t.dataType());
assertEquals(2, t.shape().numDimensions());
- assertEquals(2, t.shape().size(0));
- assertEquals(3, t.shape().size(1));
+ assertEquals(2, t.shape().get(0));
+ assertEquals(3, t.shape().get(1));
assertEquals(matrix, t);
}
- LongNdArray threeD = StdArrays.ndCopyOf(new long[][][]{
- {{1}, {3}, {5}, {7}, {9}}, {{2}, {4}, {6}, {8}, {0}},
- });
+ LongNdArray threeD =
+ StdArrays.ndCopyOf(
+ new long[][][] {
+ {{1}, {3}, {5}, {7}, {9}}, {{2}, {4}, {6}, {8}, {0}},
+ });
try (TInt64 t = TInt64.tensorOf(threeD)) {
assertEquals(TInt64.class, t.type());
assertEquals(DataType.DT_INT64, t.dataType());
assertEquals(3, t.shape().numDimensions());
- assertEquals(2, t.shape().size(0));
- assertEquals(5, t.shape().size(1));
- assertEquals(1, t.shape().size(2));
+ assertEquals(2, t.shape().get(0));
+ assertEquals(5, t.shape().get(1));
+ assertEquals(1, t.shape().get(2));
assertEquals(threeD, t);
}
- BooleanNdArray fourD = StdArrays.ndCopyOf(new boolean[][][][]{
- {{{false, false, false, true}, {false, false, true, false}}},
- {{{false, false, true, true}, {false, true, false, false}}},
- {{{false, true, false, true}, {false, true, true, false}}},
- });
+ BooleanNdArray fourD =
+ StdArrays.ndCopyOf(
+ new boolean[][][][] {
+ {{{false, false, false, true}, {false, false, true, false}}},
+ {{{false, false, true, true}, {false, true, false, false}}},
+ {{{false, true, false, true}, {false, true, true, false}}},
+ });
try (TBool t = TBool.tensorOf(fourD)) {
assertEquals(TBool.class, t.type());
assertEquals(DataType.DT_BOOL, t.dataType());
assertEquals(4, t.shape().numDimensions());
- assertEquals(3, t.shape().size(0));
- assertEquals(1, t.shape().size(1));
- assertEquals(2, t.shape().size(2));
- assertEquals(4, t.shape().size(3));
+ assertEquals(3, t.shape().get(0));
+ assertEquals(1, t.shape().get(1));
+ assertEquals(2, t.shape().get(2));
+ assertEquals(4, t.shape().get(3));
assertEquals(fourD, t);
}
}
@@ -381,19 +392,21 @@ public void testNDimensionalStringTensor() {
assertEquals(TString.class, t.type());
assertEquals(DataType.DT_STRING, t.dataType());
assertEquals(2, t.shape().numDimensions());
- assertEquals(4, t.shape().size(0));
- assertEquals(3, t.shape().size(1));
+ assertEquals(4, t.shape().get(0));
+ assertEquals(3, t.shape().get(1));
assertEquals(matrix, t);
}
NdArray byteMatrix = NdArrays.ofObjects(byte[].class, matrix.shape());
- matrix.scalars().forEachIndexed((i, s) -> byteMatrix.setObject(s.getObject().getBytes(UTF_8), i));
+ matrix
+ .scalars()
+ .forEachIndexed((i, s) -> byteMatrix.setObject(s.getObject().getBytes(UTF_8), i));
try (TString t = TString.tensorOfBytes(byteMatrix)) {
assertEquals(TString.class, t.type());
assertEquals(DataType.DT_STRING, t.dataType());
assertEquals(2, t.shape().numDimensions());
- assertEquals(4, t.shape().size(0));
- assertEquals(3, t.shape().size(1));
+ assertEquals(4, t.shape().get(0));
+ assertEquals(3, t.shape().get(1));
assertEquals(byteMatrix, t.asBytes());
assertEquals(matrix, t);
}
@@ -406,7 +419,7 @@ public void testUint8TensorFromArray() {
assertEquals(TUint8.class, t.type());
assertEquals(DataType.DT_UINT8, t.dataType());
assertEquals(1, t.shape().numDimensions());
- assertEquals(4, t.shape().size(0));
+ assertEquals(4, t.shape().get(0));
byte[] got = new byte[4];
t.read(DataBuffers.of(got));
@@ -421,7 +434,7 @@ public void testCreateFromArrayOfBoxed() {
assertEquals(TInt32.class, t.type());
assertEquals(DataType.DT_INT32, t.dataType());
assertEquals(1, t.shape().numDimensions());
- assertEquals(4, t.shape().size(0));
+ assertEquals(4, t.shape().get(0));
Integer[] got = new Integer[4];
t.read(DataBuffers.ofObjects(got));
@@ -512,9 +525,10 @@ public void fromHandle() {
//
// An exception is made for this test, where the pitfalls of this is avoided by not calling
// close() on both Tensors.
- final FloatNdArray matrix = StdArrays.ndCopyOf(new float[][]{{1, 2, 3}, {4, 5, 6}});
+ final FloatNdArray matrix = StdArrays.ndCopyOf(new float[][] {{1, 2, 3}, {4, 5, 6}});
try (TFloat32 src = TFloat32.tensorOf(matrix)) {
- TFloat32 cpy = (TFloat32)RawTensor.fromHandle(src.asRawTensor().nativeHandle()).asTypedTensor();
+ TFloat32 cpy =
+ (TFloat32) RawTensor.fromHandle(src.asRawTensor().nativeHandle()).asTypedTensor();
assertEquals(src.type(), cpy.type());
assertEquals(src.dataType(), cpy.dataType());
assertEquals(src.shape().numDimensions(), cpy.shape().numDimensions());
diff --git a/tensorflow-core/tensorflow-core-generator/src/main/java/org/tensorflow/Names.java b/tensorflow-core/tensorflow-core-generator/src/main/java/org/tensorflow/Names.java
index 7252d258814..958b74de1bf 100644
--- a/tensorflow-core/tensorflow-core-generator/src/main/java/org/tensorflow/Names.java
+++ b/tensorflow-core/tensorflow-core-generator/src/main/java/org/tensorflow/Names.java
@@ -65,6 +65,8 @@ public class Names {
public static final TypeName ArrayOp = ArrayTypeName.of(Op);
public static final TypeName ArrayOperation = ArrayTypeName.of(Operation);
+ public static final ClassName WithOps = ClassName.get(OpPackage, "WithOps");
+
public static final ClassName Operand = ClassName.get(TensorflowPackage, "Operand");
public static final ClassName Output = ClassName.get(TensorflowPackage, "Output");
diff --git a/tensorflow-core/tensorflow-core-generator/src/main/java/org/tensorflow/processor/operator/BaseOperatorProcessor.java b/tensorflow-core/tensorflow-core-generator/src/main/java/org/tensorflow/processor/operator/BaseOperatorProcessor.java
new file mode 100644
index 00000000000..793a7aa7b57
--- /dev/null
+++ b/tensorflow-core/tensorflow-core-generator/src/main/java/org/tensorflow/processor/operator/BaseOperatorProcessor.java
@@ -0,0 +1,557 @@
+/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+package org.tensorflow.processor.operator;
+
+import com.github.javaparser.ast.comments.JavadocComment;
+import com.github.javaparser.javadoc.Javadoc;
+import com.google.common.base.CaseFormat;
+import com.google.common.base.Strings;
+import com.google.common.collect.HashMultimap;
+import com.google.common.collect.Multimap;
+import com.squareup.javapoet.ClassName;
+import com.squareup.javapoet.MethodSpec;
+import com.squareup.javapoet.ParameterSpec;
+import com.squareup.javapoet.ParameterizedTypeName;
+import com.squareup.javapoet.TypeName;
+import com.squareup.javapoet.TypeVariableName;
+import java.util.ArrayList;
+import java.util.Arrays;
+import java.util.Collection;
+import java.util.Collections;
+import java.util.Comparator;
+import java.util.HashMap;
+import java.util.List;
+import java.util.Map;
+import java.util.Set;
+import java.util.regex.Pattern;
+import java.util.stream.Collectors;
+import javax.annotation.Nullable;
+import javax.annotation.processing.AbstractProcessor;
+import javax.annotation.processing.Filer;
+import javax.annotation.processing.Messager;
+import javax.annotation.processing.ProcessingEnvironment;
+import javax.annotation.processing.RoundEnvironment;
+import javax.lang.model.SourceVersion;
+import javax.lang.model.element.AnnotationMirror;
+import javax.lang.model.element.AnnotationValue;
+import javax.lang.model.element.Element;
+import javax.lang.model.element.ExecutableElement;
+import javax.lang.model.element.Modifier;
+import javax.lang.model.element.Name;
+import javax.lang.model.element.TypeElement;
+import javax.lang.model.element.TypeParameterElement;
+import javax.lang.model.element.VariableElement;
+import javax.lang.model.type.NoType;
+import javax.lang.model.type.TypeMirror;
+import javax.lang.model.type.TypeVariable;
+import javax.lang.model.util.ElementFilter;
+import javax.lang.model.util.Elements;
+import javax.lang.model.util.Types;
+import javax.tools.Diagnostic.Kind;
+
+/**
+ * A compile-time Processor that aggregates classes annotated with {@code
+ * org.tensorflow.op.annotation.Operator} and generates the {@code Ops} convenience API. Please
+ * refer to the {@code Operator} annotation for details about the API generated for each annotated
+ * class.
+ *
+ *
+ * <p>Note that this processor can only be invoked once, in a single compilation run that includes
+ * all the {@code Operator} annotated source classes. The reason is that the {@code Ops} API is an
+ * "aggregating" API, and annotation processing does not permit modifying an already generated
+ * class.
+ */
+public abstract class BaseOperatorProcessor extends AbstractProcessor {
+
+ @Override
+ public SourceVersion getSupportedSourceVersion() {
+ return SourceVersion.latest();
+ }
+
+ @Override
+ public synchronized void init(ProcessingEnvironment processingEnv) {
+ super.init(processingEnv);
+ messager = processingEnv.getMessager();
+ filer = processingEnv.getFiler();
+ elements = processingEnv.getElementUtils();
+ types = processingEnv.getTypeUtils();
+ }
+
+ @Override
+ public boolean process(Set extends TypeElement> annotations, RoundEnvironment roundEnv) {
+ // Nothing needs to be done at the end of all rounds.
+ if (roundEnv.processingOver()) {
+ return false;
+ }
+
+ // Nothing to look at in this round.
+ if (annotations.size() == 0) {
+ return false;
+ }
+
+ // We expect to be registered for exactly one annotation.
+ if (annotations.size() != 1) {
+ throw new IllegalStateException(
+ "Unexpected - multiple annotations registered: " + annotations);
+ }
+ TypeElement annotation = annotations.iterator().next();
+ Set extends Element> annotated = roundEnv.getElementsAnnotatedWith(annotation);
+
+ // If there are no annotated elements, claim the annotation but do nothing.
+ if (annotated.size() == 0) {
+ return true;
+ }
+
+ // This processor has to aggregate all op classes in one round, as it generates a single Ops
+ // API class which cannot be modified once generated. If we find an annotation after we've
+ // generated our code, flag the location of each such class.
+ if (hasRun) {
+ for (Element e : annotated) {
+ error(
+ e,
+ "The Operator processor has already processed @Operator annotated sources\n"
+ + "and written out an Ops API. It cannot process additional @Operator sources.\n"
+ + "One reason this can happen is if other annotation processors generate\n"
+ + "new @Operator source files.");
+ }
+ return true;
+ }
+
+ // Collect all classes tagged with our annotation.
+ Multimap groupedMethods = HashMultimap.create();
+ if (!collectOpsMethods(roundEnv, groupedMethods, annotation)) {
+ return true;
+ }
+
+ // Nothing to do when there are no tagged classes.
+ if (groupedMethods.isEmpty()) {
+ return true;
+ }
+
+ // Validate operator classes and generate Op API.
+ writeApi(groupedMethods);
+
+ hasRun = true;
+ return true;
+ }
+
+ @Override
+ public Set getSupportedAnnotationTypes() {
+ return Collections.singleton("org.tensorflow.op.annotation.Operator");
+ }
+
+ protected static class OpsSpec {
+ protected static final Comparator PARAMETER_SPEC_COMPARATOR =
+ (o1, o2) -> {
+ if (o1.javaMethod.parameters.size() > o2.javaMethod.parameters.size()) {
+ return 1;
+ }
+ if (o1.javaMethod.parameters.size() < o2.javaMethod.parameters.size()) {
+ return -1;
+ }
+ List firstParams = o1.javaMethod.parameters;
+ List secondParams = o2.javaMethod.parameters;
+ for (int i = 0; i < firstParams.size(); i++) {
+ ParameterSpec first = firstParams.get(i);
+ ParameterSpec second = secondParams.get(i);
+ int compare = first.name.compareTo(second.name);
+ if (compare != 0) {
+ return compare;
+ }
+ }
+ return 0;
+ };
+ protected static final Comparator METHOD_SPEC_COMPARATOR =
+ Comparator.comparing((OpMethod m) -> m.name).thenComparing(PARAMETER_SPEC_COMPARATOR);
+
+ public final @Nullable OpsSpec parent;
+ public final String groupName;
+ public final String fieldName;
+ public final ClassName className;
+ public final List methods;
+ public final List subGroups = new ArrayList<>();
+
+ OpsSpec(
+ OpsSpec parent,
+ String groupName,
+ String fieldName,
+ ClassName className,
+ Collection methods) {
+ this.parent = parent;
+ this.groupName = groupName;
+ this.fieldName = fieldName;
+ this.className = className;
+ this.methods = new ArrayList<>(methods);
+ this.methods.sort(METHOD_SPEC_COMPARATOR);
+ }
+
+ Iterable javaMethods() {
+ return methods.stream().map(x -> x.javaMethod).collect(Collectors.toList());
+ }
+ }
+
+ protected static final class OpMethod {
+ final String name;
+ final TypeElement opClass;
+ final ExecutableElement endpointMethod;
+ final boolean describeByClass;
+ final boolean deprecated;
+ final MethodSpec javaMethod;
+
+ public OpMethod(
+ String name,
+ TypeElement opClass,
+ ExecutableElement endpointMethod,
+ boolean describeByClass,
+ boolean deprecated,
+ MethodSpec javaMethod) {
+ this.name = name;
+ this.opClass = opClass;
+ this.endpointMethod = endpointMethod;
+ this.describeByClass = describeByClass;
+ this.deprecated = deprecated;
+ this.javaMethod = javaMethod;
+ }
+
+ @Override
+ public boolean equals(Object o) {
+ if (this == o) {
+ return true;
+ }
+ if (!(o instanceof OpMethod)) {
+ return false;
+ }
+
+ OpMethod opMethod = (OpMethod) o;
+
+ return javaMethod.equals(opMethod.javaMethod);
+ }
+
+ @Override
+ public int hashCode() {
+ return javaMethod.hashCode();
+ }
+ }
+
+ protected static final Pattern JAVADOC_TAG_PATTERN =
+ Pattern.compile("@(?:param|return|throws|exception|see|deprecated)\\s+.*");
+ protected static final ClassName T_OP = ClassName.get("org.tensorflow.op", "Op");
+ protected static final ClassName T_OPS = ClassName.get("org.tensorflow.op", "Ops");
+ protected static final TypeName T_ITERABLE_OP =
+ ParameterizedTypeName.get(ClassName.get(Iterable.class), T_OP);
+ protected static final ClassName T_OPERATOR =
+ ClassName.get("org.tensorflow.op.annotation", "Operator");
+ protected static final ClassName T_ENDPOINT =
+ ClassName.get("org.tensorflow.op.annotation", "Endpoint");
+ protected static final ClassName T_SCOPE = ClassName.get("org.tensorflow.op", "Scope");
+ protected static final ClassName T_EXEC_ENV =
+ ClassName.get("org.tensorflow", "ExecutionEnvironment");
+ protected static final ClassName T_EAGER_SESSION =
+ ClassName.get("org.tensorflow", "EagerSession");
+ protected static final ClassName T_STRING = ClassName.get(String.class);
+
+ protected static final String LICENSE =
+ "Copyright 2020 The TensorFlow Authors. All Rights Reserved.\n"
+ + "\n"
+ + "Licensed under the Apache License, Version 2.0 (the \"License\");\n"
+ + "you may not use this file except in compliance with the License.\n"
+ + "You may obtain a copy of the License at\n"
+ + "\n"
+ + " http://www.apache.org/licenses/LICENSE-2.0\n"
+ + "\n"
+ + "Unless required by applicable law or agreed to in writing, software\n"
+ + "distributed under the License is distributed on an \"AS IS\" BASIS,\n"
+ + "WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n"
+ + "See the License for the specific language governing permissions and\n"
+ + "limitations under the License.\n"
+ + "==============================================================================\n";
+
+ protected Filer filer;
+ protected Messager messager;
+ protected Elements elements;
+ protected Types types;
+ protected boolean hasRun = false;
+
+ protected void error(Element e, String message, Object... args) {
+ if (args != null && args.length > 0) {
+ message = String.format(message, args);
+ }
+ messager.printMessage(Kind.ERROR, message, e);
+ }
+
+ protected abstract void write(T spec);
+
+ protected void writeApi(Multimap<String, OpMethod> groupedMethods) {
+ // Build tree of *Ops classes that needs to be generated by this processor. The 'Ops' class
+ // resides at the root of the tree while other classes are nodes.
+ OpsSpec ops = new OpsSpec(null, null, null, T_OPS, groupedMethods.removeAll(""));
+ Collection<OpsSpec> groupOps = collectGroupOps(ops, groupedMethods);
+
+ write(buildTopClass(ops));
+ groupOps.forEach(g -> write(buildGroupClass(g)));
+ }
+
+ protected boolean collectOpsMethods(
+ RoundEnvironment roundEnv,
+ Multimap<String, OpMethod> groupedMethods,
+ TypeElement annotation) {
+ boolean result = true;
+ for (Element e : roundEnv.getElementsAnnotatedWith(annotation)) {
+ // @Operator can only apply to types, so e must be a TypeElement.
+ if (!(e instanceof TypeElement)) {
+ error(
+ e,
+ "@Operator can only be applied to classes, but this is a %s",
+ e.getKind().toString());
+ result = false;
+ continue;
+ }
+ collectOpMethods(groupedMethods, (TypeElement) e, annotation);
+ }
+ return result;
+ }
+
+ protected void collectOpMethods(
+ Multimap<String, OpMethod> groupedMethods, TypeElement opClass, TypeElement annotation) {
+ boolean opClassDeprecated = opClass.getAnnotation(Deprecated.class) != null;
+ AnnotationMirror operatorAnnot = getAnnotationMirror(opClass, annotation.getQualifiedName());
+ if (operatorAnnot == null) {
+ throw new IllegalArgumentException(
+ "Annotation "
+ + annotation.getSimpleName()
+ + " not present on element "
+ + opClass.getSimpleName());
+ }
+ String opGroup = getAnnotationElementValueAsString("group", operatorAnnot);
+ String opName = getAnnotationElementValueAsString("name", operatorAnnot);
+ if (Strings.isNullOrEmpty(opName)) {
+ opName =
+ CaseFormat.UPPER_CAMEL.to(CaseFormat.LOWER_CAMEL, ClassName.get(opClass).simpleName());
+ }
+ // Build an endpoint for each method annotated with @Endpoint, which takes in parameter a scope
+ // and, optionally, a list of arguments
+ for (ExecutableElement opMethod : ElementFilter.methodsIn(opClass.getEnclosedElements())) {
+ AnnotationMirror endpointAnnot =
+ getAnnotationMirror(opMethod, elements.getName(T_ENDPOINT.toString()));
+ if (endpointAnnot != null) {
+ if (!opMethod.getModifiers().containsAll(Arrays.asList(Modifier.STATIC, Modifier.PUBLIC))) {
+ throw new IllegalArgumentException(
+ "Endpoint " + opMethod + " of class " + opClass + " must be static and public");
+ }
+ if (opMethod.getParameters().isEmpty()
+ || !((TypeElement) types.asElement(opMethod.getParameters().get(0).asType()))
+ .getQualifiedName()
+ .equals(elements.getName(T_SCOPE.toString()))) {
+ throw new IllegalArgumentException(
+ "Endpoint "
+ + opMethod
+ + " of class "
+ + opClass
+ + " must take an instance of "
+ + T_SCOPE
+ + " as its first parameter");
+ }
+ String endpointGroup = getAnnotationElementValueAsString("group", endpointAnnot);
+ if (endpointGroup.isEmpty()) {
+ endpointGroup = opGroup;
+ }
+ String endpointName = getAnnotationElementValueAsString("name", endpointAnnot);
+ if (endpointName.isEmpty()) {
+ endpointName = opName;
+ }
+ boolean describeByClass =
+ getAnnotationElementValueAsBoolean("describeByClass", endpointAnnot, false);
+ boolean deprecated = opMethod.getAnnotation(Deprecated.class) != null || opClassDeprecated;
+ OpMethod method =
+ buildOpMethod(endpointName, opClass, opMethod, describeByClass, deprecated);
+ groupedMethods.put(endpointGroup, method);
+ }
+ }
+ }
+
+ protected OpMethod buildOpMethod(
+ String methodName,
+ TypeElement opClass,
+ ExecutableElement endpointMethod,
+ boolean describeByClass,
+ boolean deprecated) {
+ MethodSpec.Builder builder =
+ MethodSpec.methodBuilder(methodName)
+ .addModifiers(Modifier.PUBLIC)
+ .returns(TypeName.get(endpointMethod.getReturnType()))
+ .varargs(endpointMethod.isVarArgs())
+ .addJavadoc(
+ "$L", buildOpMethodJavadoc(opClass, endpointMethod, describeByClass).toText());
+
+ if (deprecated) {
+ builder.addAnnotation(Deprecated.class);
+ }
+ for (TypeParameterElement tp : endpointMethod.getTypeParameters()) {
+ TypeVariableName tvn = TypeVariableName.get((TypeVariable) tp.asType());
+ builder.addTypeVariable(tvn);
+ }
+ for (TypeMirror thrownType : endpointMethod.getThrownTypes()) {
+ builder.addException(TypeName.get(thrownType));
+ }
+ StringBuilder call = new StringBuilder();
+ if (!NoType.class.isAssignableFrom(endpointMethod.getReturnType().getClass())) {
+ call.append("return ");
+ }
+ call.append("$T.").append(endpointMethod.getSimpleName()).append("(scope");
+ boolean first = true;
+ for (VariableElement param : endpointMethod.getParameters()) {
+ ParameterSpec p = ParameterSpec.get(param);
+ if (first) {
+ first = false;
+ continue;
+ }
+ call.append(", ");
+ call.append(p.name);
+ builder.addParameter(p);
+ }
+ call.append(")");
+ builder.addStatement(call.toString(), ClassName.get(opClass));
+ return new OpMethod(
+ methodName, opClass, endpointMethod, describeByClass, deprecated, builder.build());
+ }
+
+ protected Javadoc buildOpMethodJavadoc(
+ TypeElement opClass, ExecutableElement endpointMethod, boolean copyClassDescription) {
+ Javadoc methodJavadoc = parseJavadoc(endpointMethod);
+
+ Javadoc javadoc;
+
+ if (!copyClassDescription) {
+ javadoc = new Javadoc(methodJavadoc.getDescription());
+ } else {
+ javadoc = parseJavadoc(opClass);
+ }
+
+ // Copy all endpoint method tags to the description, except for the `scope` parameter which
+ // will be inferred by the Ops class
+ methodJavadoc
+ .getBlockTags()
+ .forEach(
+ t -> {
+ if (!(t.getTagName().equals("param")
+ && t.getName().map(s -> s.equals("scope")).orElse(false))) {
+ javadoc.addBlockTag(t);
+ }
+ });
+
+ return javadoc;
+ }
+
+ protected static Collection<OpsSpec> collectGroupOps(
+ OpsSpec ops, Multimap<String, OpMethod> groupedMethods) {
+ Map<String, OpsSpec> groups = new HashMap<>();
+
+ // The `group` label added in the `@Operator` annotation has the same syntax as a package name,
+ // which (in most
+ // case) consists of a simple label but could also be a deeper tree, like `linalg.sparse`. In
+ // this case,
+ // the `LinalgSparseOps` group should be added as the `sparse` field of the `LinalgOps` group,
+ // and the latter
+ // should be added as the `linalg` field of the `Ops` root class.
+ groupedMethods
+ .keys()
+ .forEach(
+ group -> {
+ OpsSpec parentClass = ops;
+ int startPos = 0;
+ do {
+ int delimiterPos = group.indexOf('.', startPos);
+ String groupName = delimiterPos < 0 ? group : group.substring(0, delimiterPos);
+ OpsSpec groupOps = groups.get(groupName);
+
+ // Create spec for this group if we have not encountered it yet in our iteration
+ if (groupOps == null) {
+ String fieldName =
+ delimiterPos < 0
+ ? group.substring(startPos)
+ : group.substring(startPos, delimiterPos);
+ ClassName className =
+ ClassName.get(
+ "org.tensorflow.op",
+ CaseFormat.LOWER_UNDERSCORE.to(
+ CaseFormat.UPPER_CAMEL, groupName.replace('.', '_'))
+ + "Ops");
+ groupOps =
+ new OpsSpec(
+ parentClass,
+ groupName,
+ fieldName,
+ className,
+ groupedMethods.get(groupName));
+ parentClass.subGroups.add(groupOps);
+ groups.put(groupName, groupOps);
+ }
+ parentClass = groupOps;
+ startPos = delimiterPos + 1;
+ } while (startPos > 0);
+ });
+
+ return groups.values();
+ }
+
+ protected abstract T buildGroupClass(OpsSpec spec);
+
+ protected abstract T buildTopClass(OpsSpec spec);
+
+ protected static AnnotationMirror getAnnotationMirror(Element element, Name annotationName) {
+ for (AnnotationMirror am : element.getAnnotationMirrors()) {
+ if (((TypeElement) am.getAnnotationType().asElement())
+ .getQualifiedName()
+ .equals(annotationName)) {
+ return am;
+ }
+ }
+ return null;
+ }
+
+ protected static AnnotationValue getAnnotationElementValue(
+ String elementName, AnnotationMirror am) {
+ for (Map.Entry<? extends ExecutableElement, ? extends AnnotationValue> entry :
+ am.getElementValues().entrySet()) {
+ if (entry.getKey().getSimpleName().contentEquals(elementName)) {
+ return entry.getValue();
+ }
+ }
+ return null;
+ }
+
+ protected static String getAnnotationElementValueAsString(
+ String elementName, AnnotationMirror am) {
+ AnnotationValue value = getAnnotationElementValue(elementName, am);
+ return value != null ? value.getValue().toString() : "";
+ }
+
+ protected static boolean getAnnotationElementValueAsBoolean(
+ String elementName, AnnotationMirror am, boolean defaultValue) {
+ AnnotationValue value = getAnnotationElementValue(elementName, am);
+ return value != null ? Boolean.parseBoolean(value.toString()) : defaultValue;
+ }
+
+ protected Javadoc parseJavadoc(Element element) {
+ String docComment = elements.getDocComment(element);
+ JavadocComment javadocComment;
+ if (docComment != null) {
+ javadocComment = new JavadocComment(docComment);
+ } else {
+ javadocComment = new JavadocComment();
+ }
+ return javadocComment.parse();
+ }
+}
diff --git a/tensorflow-core/tensorflow-core-generator/src/main/java/org/tensorflow/processor/operator/OperatorProcessor.java b/tensorflow-core/tensorflow-core-generator/src/main/java/org/tensorflow/processor/operator/OperatorProcessor.java
index 99277e8fe24..b07029a48e8 100644
--- a/tensorflow-core/tensorflow-core-generator/src/main/java/org/tensorflow/processor/operator/OperatorProcessor.java
+++ b/tensorflow-core/tensorflow-core-generator/src/main/java/org/tensorflow/processor/operator/OperatorProcessor.java
@@ -15,53 +15,14 @@
*/
package org.tensorflow.processor.operator;
-import com.github.javaparser.ast.comments.JavadocComment;
-import com.github.javaparser.javadoc.Javadoc;
-import com.google.common.base.CaseFormat;
-import com.google.common.base.Strings;
-import com.google.common.collect.HashMultimap;
-import com.google.common.collect.Multimap;
-import com.squareup.javapoet.ClassName;
import com.squareup.javapoet.FieldSpec;
import com.squareup.javapoet.JavaFile;
import com.squareup.javapoet.MethodSpec;
-import com.squareup.javapoet.ParameterSpec;
-import com.squareup.javapoet.TypeName;
import com.squareup.javapoet.TypeSpec;
import com.squareup.javapoet.TypeVariableName;
import java.io.IOException;
-import java.util.ArrayList;
-import java.util.Arrays;
-import java.util.Collection;
-import java.util.Collections;
-import java.util.Comparator;
-import java.util.HashMap;
import java.util.List;
-import java.util.Map;
-import java.util.Set;
-import java.util.regex.Pattern;
-import javax.annotation.processing.AbstractProcessor;
-import javax.annotation.processing.Filer;
-import javax.annotation.processing.Messager;
-import javax.annotation.processing.ProcessingEnvironment;
-import javax.annotation.processing.RoundEnvironment;
-import javax.lang.model.SourceVersion;
-import javax.lang.model.element.AnnotationMirror;
-import javax.lang.model.element.AnnotationValue;
-import javax.lang.model.element.Element;
-import javax.lang.model.element.ExecutableElement;
import javax.lang.model.element.Modifier;
-import javax.lang.model.element.Name;
-import javax.lang.model.element.TypeElement;
-import javax.lang.model.element.TypeParameterElement;
-import javax.lang.model.element.VariableElement;
-import javax.lang.model.type.NoType;
-import javax.lang.model.type.TypeMirror;
-import javax.lang.model.type.TypeVariable;
-import javax.lang.model.util.ElementFilter;
-import javax.lang.model.util.Elements;
-import javax.lang.model.util.Types;
-import javax.tools.Diagnostic.Kind;
import org.tensorflow.Names;
/**
@@ -75,159 +36,10 @@
* "aggregating" API, and annotation processing does not permit modifying an already generated
* class.
*/
-public final class OperatorProcessor extends AbstractProcessor {
+public final class OperatorProcessor extends BaseOperatorProcessor<TypeSpec> {
@Override
- public SourceVersion getSupportedSourceVersion() {
- return SourceVersion.latest();
- }
-
- @Override
- public synchronized void init(ProcessingEnvironment processingEnv) {
- super.init(processingEnv);
- messager = processingEnv.getMessager();
- filer = processingEnv.getFiler();
- elements = processingEnv.getElementUtils();
- types = processingEnv.getTypeUtils();
- }
-
- @Override
- public boolean process(Set<? extends TypeElement> annotations, RoundEnvironment roundEnv) {
- // Nothing needs to be done at the end of all rounds.
- if (roundEnv.processingOver()) {
- return false;
- }
-
- // Nothing to look at in this round.
- if (annotations.size() == 0) {
- return false;
- }
-
- // We expect to be registered for exactly one annotation.
- if (annotations.size() != 1) {
- throw new IllegalStateException(
- "Unexpected - multiple annotations registered: " + annotations);
- }
- TypeElement annotation = annotations.iterator().next();
- Set<? extends Element> annotated = roundEnv.getElementsAnnotatedWith(annotation);
-
- // If there are no annotated elements, claim the annotation but do nothing.
- if (annotated.size() == 0) {
- return true;
- }
-
- // This processor has to aggregate all op classes in one round, as it generates a single Ops
- // API class which cannot be modified once generated. If we find an annotation after we've
- // generated our code, flag the location of each such class.
- if (hasRun) {
- for (Element e : annotated) {
- error(
- e,
- "The Operator processor has already processed @Operator annotated sources\n"
- + "and written out an Ops API. It cannot process additional @Operator sources.\n"
- + "One reason this can happen is if other annotation processors generate\n"
- + "new @Operator source files.");
- }
- return true;
- }
-
- // Collect all classes tagged with our annotation.
- Multimap<String, MethodSpec> groupedMethods = HashMultimap.create();
- if (!collectOpsMethods(roundEnv, groupedMethods, annotation)) {
- return true;
- }
-
- // Nothing to do when there are no tagged classes.
- if (groupedMethods.isEmpty()) {
- return true;
- }
-
- // Validate operator classes and generate Op API.
- writeApi(groupedMethods);
-
- hasRun = true;
- return true;
- }
-
- @Override
- public Set<String> getSupportedAnnotationTypes() {
- return Collections.singleton("org.tensorflow.op.annotation.Operator");
- }
-
- private static class OpsSpec {
-
- private static final Comparator<MethodSpec> PARAMETER_SPEC_COMPARATOR =
- (o1, o2) -> {
- if (o1.parameters.size() > o2.parameters.size()) {
- return 1;
- }
- if (o1.parameters.size() < o2.parameters.size()) {
- return -1;
- }
- List<ParameterSpec> firstParams = o1.parameters;
- List<ParameterSpec> secondParams = o2.parameters;
- for (int i = 0; i < firstParams.size(); i++) {
- ParameterSpec first = firstParams.get(i);
- ParameterSpec second = secondParams.get(i);
- int compare = first.name.compareTo(second.name);
- if (compare != 0) {
- return compare;
- }
- }
- return 0;
- };
- private static final Comparator<MethodSpec> METHOD_SPEC_COMPARATOR =
- Comparator.comparing((MethodSpec m) -> m.name).thenComparing(PARAMETER_SPEC_COMPARATOR);
-
- final String groupName;
- final String fieldName;
- final ClassName className;
- final List<MethodSpec> methods;
- final List<OpsSpec> subGroups = new ArrayList<>();
-
- OpsSpec(
- String groupName, String fieldName, ClassName className, Collection<MethodSpec> methods) {
- this.groupName = groupName;
- this.fieldName = fieldName;
- this.className = className;
- this.methods = new ArrayList<>(methods);
- this.methods.sort(METHOD_SPEC_COMPARATOR);
- }
- }
-
- private static final Pattern JAVADOC_TAG_PATTERN =
- Pattern.compile("@(?:param|return|throws|exception|see|deprecated)\\s+.*");
-
- private static final String LICENSE =
- "Copyright 2020 The TensorFlow Authors. All Rights Reserved.\n"
- + "\n"
- + "Licensed under the Apache License, Version 2.0 (the \"License\");\n"
- + "you may not use this file except in compliance with the License.\n"
- + "You may obtain a copy of the License at\n"
- + "\n"
- + " http://www.apache.org/licenses/LICENSE-2.0\n"
- + "\n"
- + "Unless required by applicable law or agreed to in writing, software\n"
- + "distributed under the License is distributed on an \"AS IS\" BASIS,\n"
- + "WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n"
- + "See the License for the specific language governing permissions and\n"
- + "limitations under the License.\n"
- + "==============================================================================\n";
-
- private Filer filer;
- private Messager messager;
- private Elements elements;
- private Types types;
- private boolean hasRun = false;
-
- private void error(Element e, String message, Object... args) {
- if (args != null && args.length > 0) {
- message = String.format(message, args);
- }
- messager.printMessage(Kind.ERROR, message, e);
- }
-
- private void write(TypeSpec spec) {
+ protected void write(TypeSpec spec) {
try {
JavaFile.builder("org.tensorflow.op", spec)
.addFileComment(LICENSE)
@@ -240,213 +52,8 @@ private void write(TypeSpec spec) {
}
}
- private void writeApi(Multimap<String, MethodSpec> groupedMethods) {
- // Build tree of *Ops classes that needs to be generated by this processor. The 'Ops' class
- // resides at the root of the tree while other classes are nodes.
- OpsSpec ops = new OpsSpec(null, null, Names.Ops, groupedMethods.removeAll(""));
- Collection<OpsSpec> groupOps = collectGroupOps(ops, groupedMethods);
-
- write(buildTopClass(ops));
- groupOps.forEach(g -> write(buildGroupClass(g)));
- }
-
- private boolean collectOpsMethods(
- RoundEnvironment roundEnv,
- Multimap<String, MethodSpec> groupedMethods,
- TypeElement annotation) {
- boolean result = true;
- for (Element e : roundEnv.getElementsAnnotatedWith(annotation)) {
- // @Operator can only apply to types, so e must be a TypeElement.
- if (!(e instanceof TypeElement)) {
- error(
- e,
- "@Operator can only be applied to classes, but this is a %s",
- e.getKind().toString());
- result = false;
- continue;
- }
- collectOpMethods(groupedMethods, (TypeElement) e, annotation);
- }
- return result;
- }
-
- private void collectOpMethods(
- Multimap<String, MethodSpec> groupedMethods, TypeElement opClass, TypeElement annotation) {
- boolean opClassDeprecated = opClass.getAnnotation(Deprecated.class) != null;
- AnnotationMirror operatorAnnot = getAnnotationMirror(opClass, annotation.getQualifiedName());
- if (operatorAnnot == null) {
- throw new IllegalArgumentException(
- "Annotation "
- + annotation.getSimpleName()
- + " not present on element "
- + opClass.getSimpleName());
- }
- String opGroup = getAnnotationElementValueAsString("group", operatorAnnot);
- String opName = getAnnotationElementValueAsString("name", operatorAnnot);
- if (Strings.isNullOrEmpty(opName)) {
- opName =
- CaseFormat.UPPER_CAMEL.to(CaseFormat.LOWER_CAMEL, ClassName.get(opClass).simpleName());
- }
- // Build an endpoint for each method annotated with @Endpoint, which takes in parameter a scope
- // and, optionally, a list of arguments
- for (ExecutableElement opMethod : ElementFilter.methodsIn(opClass.getEnclosedElements())) {
- AnnotationMirror endpointAnnot =
- getAnnotationMirror(opMethod, elements.getName(Names.Endpoint.toString()));
- if (endpointAnnot != null) {
- if (!opMethod.getModifiers().containsAll(Arrays.asList(Modifier.STATIC, Modifier.PUBLIC))) {
- throw new IllegalArgumentException(
- "Endpoint " + opMethod + " of class " + opClass + " must be static and public");
- }
- if (opMethod.getParameters().isEmpty()
- || !((TypeElement) types.asElement(opMethod.getParameters().get(0).asType()))
- .getQualifiedName()
- .equals(elements.getName(Names.Scope.toString()))) {
- throw new IllegalArgumentException(
- "Endpoint "
- + opMethod
- + " of class "
- + opClass
- + " must take an instance of "
- + Names.Scope
- + " as its first parameter");
- }
- String endpointGroup = getAnnotationElementValueAsString("group", endpointAnnot);
- if (endpointGroup.isEmpty()) {
- endpointGroup = opGroup;
- }
- String endpointName = getAnnotationElementValueAsString("name", endpointAnnot);
- if (endpointName.isEmpty()) {
- endpointName = opName;
- }
- boolean describeByClass =
- getAnnotationElementValueAsBoolean("describeByClass", endpointAnnot, false);
- boolean deprecated = opMethod.getAnnotation(Deprecated.class) != null || opClassDeprecated;
- MethodSpec method =
- buildOpMethod(endpointName, opClass, opMethod, describeByClass, deprecated);
- groupedMethods.put(endpointGroup, method);
- }
- }
- }
-
- private MethodSpec buildOpMethod(
- String methodName,
- TypeElement opClass,
- ExecutableElement endpointMethod,
- boolean describeByClass,
- boolean deprecated) {
- MethodSpec.Builder builder =
- MethodSpec.methodBuilder(methodName)
- .addModifiers(Modifier.PUBLIC)
- .returns(TypeName.get(endpointMethod.getReturnType()))
- .varargs(endpointMethod.isVarArgs())
- .addJavadoc("$L", buildOpMethodJavadoc(opClass, endpointMethod, describeByClass));
-
- if (deprecated) {
- builder.addAnnotation(Deprecated.class);
- }
- for (TypeParameterElement tp : endpointMethod.getTypeParameters()) {
- TypeVariableName tvn = TypeVariableName.get((TypeVariable) tp.asType());
- builder.addTypeVariable(tvn);
- }
- for (TypeMirror thrownType : endpointMethod.getThrownTypes()) {
- builder.addException(TypeName.get(thrownType));
- }
- StringBuilder call = new StringBuilder();
- if (!NoType.class.isAssignableFrom(endpointMethod.getReturnType().getClass())) {
- call.append("return ");
- }
- call.append("$T.").append(endpointMethod.getSimpleName()).append("(scope");
- boolean first = true;
- for (VariableElement param : endpointMethod.getParameters()) {
- ParameterSpec p = ParameterSpec.get(param);
- if (first) {
- first = false;
- continue;
- }
- call.append(", ");
- call.append(p.name);
- builder.addParameter(p);
- }
- call.append(")");
- builder.addStatement(call.toString(), ClassName.get(opClass));
- return builder.build();
- }
-
- private String buildOpMethodJavadoc(
- TypeElement opClass, ExecutableElement endpointMethod, boolean copyClassDescription) {
- Javadoc methodJavadoc = parseJavadoc(endpointMethod);
-
- Javadoc javadoc;
-
- if (!copyClassDescription) {
- javadoc = new Javadoc(methodJavadoc.getDescription());
- } else {
- javadoc = parseJavadoc(opClass);
- }
-
- // Copy all endpoint method tags to the description, except for the `scope` parameter which
- // will be inferred by the Ops class
- methodJavadoc
- .getBlockTags()
- .forEach(
- t -> {
- if (!(t.getTagName().equals("param")
- && t.getName().map(s -> s.equals("scope")).orElse(false))) {
- javadoc.addBlockTag(t);
- }
- });
-
- return javadoc.toText();
- }
-
- private static Collection<OpsSpec> collectGroupOps(
- OpsSpec ops, Multimap<String, MethodSpec> groupedMethods) {
- Map<String, OpsSpec> groups = new HashMap<>();
-
- // The `group` label added in the `@Operator` annotation has the same syntax as a package name,
- // which (in most
- // case) consists of a simple label but could also be a deeper tree, like `linalg.sparse`. In
- // this case,
- // the `LinalgSparseOps` group should be added as the `sparse` field of the `LinalgOps` group,
- // and the latter
- // should be added as the `linalg` field of the `Ops` root class.
- groupedMethods
- .keys()
- .forEach(
- group -> {
- OpsSpec parentClass = ops;
- int startPos = 0;
- do {
- int delimiterPos = group.indexOf('.', startPos);
- String groupName = delimiterPos < 0 ? group : group.substring(0, delimiterPos);
- OpsSpec groupOps = groups.get(groupName);
-
- // Create spec for this group if we have not encountered it yet in our iteration
- if (groupOps == null) {
- String fieldName =
- delimiterPos < 0
- ? group.substring(startPos)
- : group.substring(startPos, delimiterPos);
- ClassName className =
- ClassName.get(
- "org.tensorflow.op",
- CaseFormat.LOWER_UNDERSCORE.to(
- CaseFormat.UPPER_CAMEL, groupName.replace('.', '_'))
- + "Ops");
- groupOps =
- new OpsSpec(groupName, fieldName, className, groupedMethods.get(groupName));
- parentClass.subGroups.add(groupOps);
- groups.put(groupName, groupOps);
- }
- parentClass = groupOps;
- startPos = delimiterPos + 1;
- } while (startPos > 0);
- });
-
- return groups.values();
- }
-
- private static TypeSpec buildGroupClass(OpsSpec spec) {
+ @Override
+ protected TypeSpec buildGroupClass(OpsSpec spec) {
// System.out.println("Generating " + spec.className + " class");
MethodSpec.Builder ctorBuilder =
@@ -464,7 +71,7 @@ private static TypeSpec buildGroupClass(OpsSpec spec) {
spec.groupName,
Names.Op,
Names.Ops)
- .addMethods(spec.methods);
+ .addMethods(spec.javaMethods());
MethodSpec.Builder opsBuilder =
MethodSpec.methodBuilder("ops")
@@ -490,7 +97,8 @@ private static TypeSpec buildGroupClass(OpsSpec spec) {
return builder.build();
}
- private static TypeSpec buildTopClass(OpsSpec spec) {
+ @Override
+ protected TypeSpec buildTopClass(OpsSpec spec) {
// System.out.println("Generating " + spec.className + " class");
MethodSpec.Builder ctorBuilder =
@@ -500,6 +108,7 @@ private static TypeSpec buildTopClass(OpsSpec spec) {
TypeSpec.Builder opsBuilder =
TypeSpec.classBuilder("Ops")
+ .addSuperinterface(Names.WithOps)
.addModifiers(Modifier.PUBLIC, Modifier.FINAL)
.addJavadoc(
"An API for building operations as {@link $T Op}s\n<p>\n"
@@ -532,22 +141,28 @@ private static TypeSpec buildTopClass(OpsSpec spec) {
+ "}\n",
Names.Op,
Names.Operator)
- .addMethods(spec.methods);
+ .addMethods(spec.javaMethods());
addGroupFields(opsBuilder, ctorBuilder, spec.subGroups, true);
opsBuilder.addMethod(ctorBuilder.build());
+ opsBuilder.addMethod(
+ MethodSpec.methodBuilder("tf")
+ .addModifiers(Modifier.PUBLIC)
+ .addAnnotation(Override.class)
+ .returns(Names.Ops)
+ .addStatement("return this")
+ .build());
+
opsBuilder.addMethod(
MethodSpec.methodBuilder("withSubScope")
.addModifiers(Modifier.PUBLIC)
+ .addAnnotation(Override.class)
.addParameter(Names.String, "childScopeName")
.returns(Names.Ops)
.addStatement("return new $T(scope.withSubScope(childScopeName))", Names.Ops)
- .addJavadoc(
- "Returns an API that builds operations with the provided name prefix.\n"
- + "\n@see {@link $T#withSubScope(String)}\n",
- Names.Scope)
+ .addJavadoc("{@inheritDoc}")
.build());
String initScopeComment =
@@ -586,37 +201,31 @@ private static TypeSpec buildTopClass(OpsSpec spec) {
opsBuilder.addMethod(
MethodSpec.methodBuilder("withName")
.addModifiers(Modifier.PUBLIC)
+ .addAnnotation(Override.class)
.addParameter(Names.String, "opName")
.returns(Names.Ops)
.addStatement("return new Ops(scope.withName(opName))")
- .addJavadoc(
- "Returns an API that uses the provided name for an op.\n\n"
- + "@see {@link $T#withName(String)}\n",
- Names.Scope)
+ .addJavadoc("{@inheritDoc}")
.build());
opsBuilder.addMethod(
MethodSpec.methodBuilder("withDevice")
.addModifiers(Modifier.PUBLIC)
+ .addAnnotation(Override.class)
.addParameter(Names.DeviceSpec, "deviceSpec")
.returns(Names.Ops)
.addStatement("return new Ops(scope.withDevice(deviceSpec))")
- .addJavadoc(
- "Returns an API that places the created operations on the device(s) matching the provided spec.\n\n"
- + "@see {@link $T#withDevice(DeviceSpec)}\n",
- Names.Scope)
+ .addJavadoc("{@inheritDoc}")
.build());
opsBuilder.addMethod(
MethodSpec.methodBuilder("withControlDependencies")
.addModifiers(Modifier.PUBLIC)
+ .addAnnotation(Override.class)
.addParameter(Names.IterableOp, "controls")
.returns(Names.Ops)
.addStatement("return new Ops(scope.withControlDependencies(controls))")
- .addJavadoc(
- "Returns an API that adds operations to the graph with the provided control dependencies.\n\n"
- + "@see {@link $T#withControlDependencies(Iterable<Op<?>>)}\n",
- Names.Scope)
+ .addJavadoc("{@inheritDoc}")
.build());
opsBuilder.addMethod(
@@ -700,6 +309,8 @@ private static void addGroupFields(
boolean isTopClass) {
groups.forEach(
group -> {
+ System.out.println(
+ "Adding field in " + classBuilder.build().name + ": " + group.fieldName);
classBuilder.addField(
FieldSpec.builder(group.className, group.fieldName)
.addModifiers(Modifier.PUBLIC, Modifier.FINAL)
@@ -712,48 +323,4 @@ private static void addGroupFields(
.build();
});
}
-
- private static AnnotationMirror getAnnotationMirror(Element element, Name annotationName) {
- for (AnnotationMirror am : element.getAnnotationMirrors()) {
- if (((TypeElement) am.getAnnotationType().asElement())
- .getQualifiedName()
- .equals(annotationName)) {
- return am;
- }
- }
- return null;
- }
-
- private static AnnotationValue getAnnotationElementValue(
- String elementName, AnnotationMirror am) {
- for (Map.Entry<? extends ExecutableElement, ? extends AnnotationValue> entry :
- am.getElementValues().entrySet()) {
- if (entry.getKey().getSimpleName().contentEquals(elementName)) {
- return entry.getValue();
- }
- }
- return null;
- }
-
- private static String getAnnotationElementValueAsString(String elementName, AnnotationMirror am) {
- AnnotationValue value = getAnnotationElementValue(elementName, am);
- return value != null ? value.getValue().toString() : "";
- }
-
- private static boolean getAnnotationElementValueAsBoolean(
- String elementName, AnnotationMirror am, boolean defaultValue) {
- AnnotationValue value = getAnnotationElementValue(elementName, am);
- return value != null ? Boolean.parseBoolean(value.toString()) : defaultValue;
- }
-
- private Javadoc parseJavadoc(Element element) {
- String docComment = elements.getDocComment(element);
- JavadocComment javadocComment;
- if (docComment != null) {
- javadocComment = new JavadocComment(docComment);
- } else {
- javadocComment = new JavadocComment();
- }
- return javadocComment.parse();
- }
}
diff --git a/tensorflow-core/tensorflow-core-platform-mkl-gpu/pom.xml b/tensorflow-core/tensorflow-core-platform-mkl-gpu/pom.xml
index 812a53d129b..fed50858f48 100644
--- a/tensorflow-core/tensorflow-core-platform-mkl-gpu/pom.xml
+++ b/tensorflow-core/tensorflow-core-platform-mkl-gpu/pom.xml
@@ -22,7 +22,7 @@
org.tensorflowtensorflow-core
- 0.4.0-SNAPSHOT
+ 0.5.0-SNAPSHOTtensorflow-core-platform-mkl-gpuTensorFlow Core API Library Platform MKL GPU
diff --git a/tensorflow-core/tensorflow-core-platform-mkl/pom.xml b/tensorflow-core/tensorflow-core-platform-mkl/pom.xml
index 9800ff1cb95..0c855068865 100644
--- a/tensorflow-core/tensorflow-core-platform-mkl/pom.xml
+++ b/tensorflow-core/tensorflow-core-platform-mkl/pom.xml
@@ -22,7 +22,7 @@
org.tensorflowtensorflow-core
- 0.4.0-SNAPSHOT
+ 0.5.0-SNAPSHOTtensorflow-core-platform-mklTensorFlow Core API Library Platform MKL
diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/initializers/Identity.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/initializers/Identity.java
index ea73f764a38..e258330df70 100644
--- a/tensorflow-framework/src/main/java/org/tensorflow/framework/initializers/Identity.java
+++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/initializers/Identity.java
@@ -66,8 +66,8 @@ public Operand call(Ops tf, Operand dims, Class type) {
if (shape.numDimensions() != 2) {
throw new IllegalArgumentException("2D matrix required, got " + shape.numDimensions());
}
- boolean isSquare = shape.size(0) == shape.size(1);
- long diagSize = Math.min(shape.size(0), shape.size(1));
+ boolean isSquare = shape.get(0) == shape.get(1);
+ long diagSize = Math.min(shape.get(0), shape.get(1));
Shape diagShape = Shape.of(diagSize);
Operand op;
@@ -79,8 +79,8 @@ public Operand call(Ops tf, Operand dims, Class type) {
tf.linalg.matrixDiag(
diagOnes,
tf.constant(0), // don't cast here, expecting TInt32
- tf.constant((int) shape.size(0)),
- tf.constant((int) shape.size(1)),
+ tf.constant((int) shape.get(0)),
+ tf.constant((int) shape.get(1)),
zero);
} else {
Operand zeroMatrix = tf.zeros(dims, type);
diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/initializers/Orthogonal.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/initializers/Orthogonal.java
index 240d915f97f..a24b791fd47 100644
--- a/tensorflow-framework/src/main/java/org/tensorflow/framework/initializers/Orthogonal.java
+++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/initializers/Orthogonal.java
@@ -91,8 +91,8 @@ public Operand call(Ops tf, Operand dims, Class type) {
}
long numRows = 1;
int i = 0;
- for (; i < dimsShape.numDimensions() - 1; i++) numRows *= dimsShape.size(i);
- long numCols = dimsShape.size(i);
+ for (; i < dimsShape.numDimensions() - 1; i++) numRows *= dimsShape.get(i);
+ long numCols = dimsShape.get(i);
Shape flatShape = Shape.of(Math.max(numRows, numCols), Math.min(numRows, numCols));
long[] seeds = {seed, 0};
Operand op =
diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/Losses.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/Losses.java
index d23059b88fd..f01ce2e75e0 100644
--- a/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/Losses.java
+++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/Losses.java
@@ -572,7 +572,7 @@ public static Operand sparseCategoricalCrossentropy(
tf.reshape(
predictions,
tf.constant(
- new long[] {-1L, predictionsShape.size(predictionsShape.numDimensions() - 1)}));
+ new long[] {-1L, predictionsShape.get(predictionsShape.numDimensions() - 1)}));
}
Operand loss = ftf.nn.sparseSoftmaxCrossEntropyWithLogits(iLabels, predictions);
@@ -648,7 +648,7 @@ private static Operand smoothCategoricalLabels(
Operand smoothing = cast(tf, tf.constant(labelSmoothing), labelType);
Shape labelsShape = labels.shape();
int numDims = labelsShape.numDimensions();
- Operand numClasses = cast(tf, tf.constant(labelsShape.size(numDims - 1)), labelType);
+ Operand numClasses = cast(tf, tf.constant(labelsShape.get(numDims - 1)), labelType);
Operand oneMinusSmoothing = cast(tf, tf.constant(1.f - labelSmoothing), labelType);
return tf.math.add(tf.math.mul(labels, oneMinusSmoothing), tf.math.div(smoothing, numClasses));
}
diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/impl/LossesHelper.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/impl/LossesHelper.java
index f6b0de71b0d..11c838277a4 100644
--- a/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/impl/LossesHelper.java
+++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/impl/LossesHelper.java
@@ -14,6 +14,10 @@
=======================================================================*/
package org.tensorflow.framework.losses.impl;
+import static org.tensorflow.framework.utils.CastHelper.cast;
+
+import java.util.Arrays;
+import java.util.Collections;
import org.tensorflow.Operand;
import org.tensorflow.framework.losses.Reduction;
import org.tensorflow.ndarray.Shape;
@@ -26,11 +30,6 @@
import org.tensorflow.types.TInt32;
import org.tensorflow.types.family.TNumber;
-import java.util.Arrays;
-import java.util.Collections;
-
-import static org.tensorflow.framework.utils.CastHelper.cast;
-
/**
* These are helper methods for Losses and Metrics and will be module private when Java modularity
* is applied to TensorFlow Java. These methods should not be used outside of the losses and metrics
@@ -101,7 +100,7 @@ public static LossTuple squeezeOrExpandDimensions(
long labelsRank = labelsShape.numDimensions();
if (labelsRank != Shape.UNKNOWN_SIZE && predictionsRank != Shape.UNKNOWN_SIZE) {
// Use static rank for 'label' and 'prediction'.
- if (predictionsRank - labelsRank != 1 || predictionsShape.size(-1) == 1) {
+ if (predictionsRank - labelsRank != 1 || predictionsShape.get(-1) == 1) {
lossTuple = removeSqueezableDimensions(tf, labels, predictions);
}
} else { // use dynamic rank
@@ -213,9 +212,9 @@ public static LossTuple removeSqueezableDimensions(
if (predictionsRank != Shape.UNKNOWN_SIZE || labelsRank != Shape.UNKNOWN_SIZE) {
// Use static rank.
int rankDiff = predictionsRank - labelsRank;
- if (rankDiff == expectedRankDiff + 1 && Shape.isCompatible(predictionsShape.size(-1), 1)) {
+ if (rankDiff == expectedRankDiff + 1 && Shape.isCompatible(predictionsShape.get(-1), 1)) {
predictions = tf.squeeze(predictions);
- } else if (rankDiff == expectedRankDiff - 1 && Shape.isCompatible(labelsShape.size(-1), 1)) {
+ } else if (rankDiff == expectedRankDiff - 1 && Shape.isCompatible(labelsShape.get(-1), 1)) {
labels = tf.squeeze(labels);
}
return new LossTuple<>(labels, predictions);
@@ -224,7 +223,7 @@ public static LossTuple removeSqueezableDimensions(
// TODO: hold for lazy select feature,
// Operand rankDiff = tf.math.sub(tf.rank(predictions), tf.rank(labels));
- if (predictionsRank == Shape.UNKNOWN_SIZE && Shape.isCompatible(predictionsShape.size(-1), 1)) {
+ if (predictionsRank == Shape.UNKNOWN_SIZE && Shape.isCompatible(predictionsShape.get(-1), 1)) {
/*
* TODO, if we ever get a select that does lazy evaluation, but for now do the tf.squeeze
* predictions = tf.select( tf.math.equal(tf.constant(expectedRankDiff+1),rankDiff ),
@@ -232,7 +231,7 @@ public static LossTuple removeSqueezableDimensions(
*/
predictions = tf.squeeze(predictions, Squeeze.axis(Collections.singletonList(-1L)));
}
- if (labelsRank == Shape.UNKNOWN_SIZE && Shape.isCompatible(labelsShape.size(-1), 1)) {
+ if (labelsRank == Shape.UNKNOWN_SIZE && Shape.isCompatible(labelsShape.get(-1), 1)) {
/*
* TODO, if we ever get a select that does lazy evaluation labels = tf.select(
* tf.math.equal(tf.constant(expectedRankDiff+1),rankDiff ), tf.squeeze(labels,
diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/impl/MetricsHelper.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/impl/MetricsHelper.java
index 70a81da8d1e..d9e96081233 100644
--- a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/impl/MetricsHelper.java
+++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/impl/MetricsHelper.java
@@ -110,8 +110,8 @@ public static Op assertBroadcastable(
}
for (int i = 0; i < valuesRankStatic; i++) {
- if (valuesShapeStatic.size(i) != weightsShapeStatic.size(i)
- && weightsShapeStatic.size(i) != 1) {
+ if (valuesShapeStatic.get(i) != weightsShapeStatic.get(i)
+ && weightsShapeStatic.get(i) != 1) {
throw new NotBroadcastableException(
String.format(
"%s Mismatch at dim %d. values.shape=%s weights.shape=%s.",
diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/op/nn/SigmoidCrossEntropyWithLogits.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/op/nn/SigmoidCrossEntropyWithLogits.java
index b1e2ce6c928..8bcd38bb7d6 100644
--- a/tensorflow-framework/src/main/java/org/tensorflow/framework/op/nn/SigmoidCrossEntropyWithLogits.java
+++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/op/nn/SigmoidCrossEntropyWithLogits.java
@@ -97,8 +97,8 @@ public static Operand sigmoidCrossEntropyWithLogits(
private static boolean isCompatible(Shape shape, Shape other) {
if (shape.numDimensions() != other.numDimensions()) return false;
for (int i = 0; i < shape.numDimensions(); i++) {
- long aShapeDim = shape.size(i);
- long bShapeDim = other.size(i);
+ long aShapeDim = shape.get(i);
+ long bShapeDim = other.get(i);
if (aShapeDim == bShapeDim
|| (aShapeDim == Shape.UNKNOWN_SIZE || bShapeDim == Shape.UNKNOWN_SIZE)) {
continue;
diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/op/nn/SoftmaxCrossEntropyWithLogits.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/op/nn/SoftmaxCrossEntropyWithLogits.java
index a95110c9a96..5e3ed52a220 100644
--- a/tensorflow-framework/src/main/java/org/tensorflow/framework/op/nn/SoftmaxCrossEntropyWithLogits.java
+++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/op/nn/SoftmaxCrossEntropyWithLogits.java
@@ -21,6 +21,7 @@
import org.tensorflow.types.TInt64;
import org.tensorflow.types.family.TNumber;
+@Operator(group = "nn")
public class SoftmaxCrossEntropyWithLogits {
/**
@@ -137,10 +138,10 @@ public static Operand softmaxCrossEntr
axis = shape.numDimensions() + axis;
}
for (int i = 0; i < axis; i++) {
- newArray[i] = shape.size(i);
+ newArray[i] = shape.get(i);
}
for (int i = axis + 1; i < shape.numDimensions(); i++) {
- newArray[i - 1] = shape.size(i);
+ newArray[i - 1] = shape.get(i);
}
cost = Reshape.create(scope, cost, Constant.vectorOf(scope, newArray));
}
@@ -165,7 +166,7 @@ private static Operand flattenOuterDims(Scope scope, Oper
long product = 1L;
boolean productValid = true;
for (int i = ndims - 2; i >= 0; i--) {
- long d = shape.size(i);
+ long d = shape.get(i);
if (d == Shape.UNKNOWN_SIZE) {
productValid = false;
break;
@@ -173,7 +174,7 @@ private static Operand flattenOuterDims(Scope scope, Oper
product *= d;
}
if (productValid) {
- return Reshape.create(scope, logits, Constant.arrayOf(scope, product, shape.size(-1)));
+ return Reshape.create(scope, logits, Constant.arrayOf(scope, product, shape.get(-1)));
}
}
diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/op/nn/SparseSoftmaxCrossEntropyWithLogits.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/op/nn/SparseSoftmaxCrossEntropyWithLogits.java
index 5299efcce22..3c196641878 100644
--- a/tensorflow-framework/src/main/java/org/tensorflow/framework/op/nn/SparseSoftmaxCrossEntropyWithLogits.java
+++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/op/nn/SparseSoftmaxCrossEntropyWithLogits.java
@@ -19,6 +19,7 @@
import org.tensorflow.types.TInt32;
import org.tensorflow.types.family.TNumber;
+@Operator(group = "nn")
public class SparseSoftmaxCrossEntropyWithLogits {
/**
@@ -139,7 +140,7 @@ Operand sparseSoftmaxCrossEntropyWithLogits(
}
// Reshape logits to 2 dims, labels to 1 dim.
- long numClassses = logitsShape.size(-1);
+ long numClassses = logitsShape.get(-1);
preciseLogits = Reshape.create(scope, preciseLogits, Constant.arrayOf(scope, -1L, numClassses));
labels = Reshape.create(scope, labels, Constant.scalarOf(scope, -1));
diff --git a/tensorflow-framework/src/test/java/org/tensorflow/framework/utils/ND.java b/tensorflow-framework/src/test/java/org/tensorflow/framework/utils/ND.java
index c0c0f12fbf9..f9842e628a0 100644
--- a/tensorflow-framework/src/test/java/org/tensorflow/framework/utils/ND.java
+++ b/tensorflow-framework/src/test/java/org/tensorflow/framework/utils/ND.java
@@ -14,12 +14,11 @@
=======================================================================*/
package org.tensorflow.framework.utils;
-import org.tensorflow.ndarray.*;
-
import java.util.Arrays;
import java.util.concurrent.atomic.AtomicBoolean;
import java.util.concurrent.atomic.AtomicInteger;
import java.util.concurrent.atomic.AtomicReference;
+import org.tensorflow.ndarray.*;
// TODO used in the Callbacks, this should be a part of NDArray?
@@ -75,7 +74,7 @@ private static long[] getCoordinates(Shape shape, long index) {
int numDims = shape.numDimensions();
int i = numDims - 1;
for (; i >= 0; i--) {
- long size = shape.size(i);
+ long size = shape.get(i);
long mod = index % size;
coordinates[i] = mod;
index -= mod;
@@ -676,7 +675,7 @@ public static FloatNdArray sum(FloatNdArray a, int axis, boolean keepDims) {
int nDims = shape.numDimensions();
int xis = nDims - 1 - axis;
long totalSize = shape.size();
- long axisSize = shape.size(xis);
+ long axisSize = shape.get(xis);
final float[] sums = new float[(int) axisSize];
a.scalars()
@@ -767,7 +766,7 @@ public static DoubleNdArray sum(DoubleNdArray a, int axis, boolean keepDims) {
int nDims = shape.numDimensions();
int xis = nDims - 1 - axis;
long totalSize = shape.size();
- long axisSize = shape.size(xis);
+ long axisSize = shape.get(xis);
final double[] sums = new double[(int) axisSize];
a.scalars()
diff --git a/tensorflow-kotlin-parent/.editorconfig b/tensorflow-kotlin-parent/.editorconfig
new file mode 100644
index 00000000000..c5d853001f9
--- /dev/null
+++ b/tensorflow-kotlin-parent/.editorconfig
@@ -0,0 +1,93 @@
+# This .editorconfig section approximates ktfmt's formatting rules. You can include it in an
+# existing .editorconfig file or use it standalone by copying it to /.editorconfig
+# and making sure your editor is set to read settings from .editorconfig files.
+#
+# It includes editor-specific config options for IntelliJ IDEA.
+#
+# If any option is wrong, PRs are welcome
+
+[{*.kt,*.kts}]
+indent_style = space
+insert_final_newline = true
+max_line_length = 100
+indent_size = 2
+ij_continuation_indent_size = 4
+ij_java_names_count_to_use_import_on_demand = 9999
+ij_kotlin_align_in_columns_case_branch = false
+ij_kotlin_align_multiline_binary_operation = false
+ij_kotlin_align_multiline_extends_list = false
+ij_kotlin_align_multiline_method_parentheses = false
+ij_kotlin_align_multiline_parameters = true
+ij_kotlin_align_multiline_parameters_in_calls = false
+ij_kotlin_allow_trailing_comma = true
+ij_kotlin_allow_trailing_comma_on_call_site = true
+ij_kotlin_assignment_wrap = normal
+ij_kotlin_blank_lines_after_class_header = 0
+ij_kotlin_blank_lines_around_block_when_branches = 0
+ij_kotlin_blank_lines_before_declaration_with_comment_or_annotation_on_separate_line = 1
+ij_kotlin_block_comment_at_first_column = true
+ij_kotlin_call_parameters_new_line_after_left_paren = true
+ij_kotlin_call_parameters_right_paren_on_new_line = false
+ij_kotlin_call_parameters_wrap = on_every_item
+ij_kotlin_catch_on_new_line = false
+ij_kotlin_class_annotation_wrap = split_into_lines
+ij_kotlin_code_style_defaults = KOTLIN_OFFICIAL
+ij_kotlin_continuation_indent_for_chained_calls = true
+ij_kotlin_continuation_indent_for_expression_bodies = true
+ij_kotlin_continuation_indent_in_argument_lists = true
+ij_kotlin_continuation_indent_in_elvis = false
+ij_kotlin_continuation_indent_in_if_conditions = false
+ij_kotlin_continuation_indent_in_parameter_lists = false
+ij_kotlin_continuation_indent_in_supertype_lists = false
+ij_kotlin_else_on_new_line = false
+ij_kotlin_enum_constants_wrap = off
+ij_kotlin_extends_list_wrap = normal
+ij_kotlin_field_annotation_wrap = split_into_lines
+ij_kotlin_finally_on_new_line = false
+ij_kotlin_if_rparen_on_new_line = false
+ij_kotlin_import_nested_classes = false
+ij_kotlin_insert_whitespaces_in_simple_one_line_method = true
+ij_kotlin_keep_blank_lines_before_right_brace = 2
+ij_kotlin_keep_blank_lines_in_code = 2
+ij_kotlin_keep_blank_lines_in_declarations = 2
+ij_kotlin_keep_first_column_comment = true
+ij_kotlin_keep_indents_on_empty_lines = false
+ij_kotlin_keep_line_breaks = true
+ij_kotlin_lbrace_on_next_line = false
+ij_kotlin_line_comment_add_space = false
+ij_kotlin_line_comment_at_first_column = true
+ij_kotlin_method_annotation_wrap = split_into_lines
+ij_kotlin_method_call_chain_wrap = normal
+ij_kotlin_method_parameters_new_line_after_left_paren = true
+ij_kotlin_method_parameters_right_paren_on_new_line = true
+ij_kotlin_method_parameters_wrap = on_every_item
+ij_kotlin_name_count_to_use_star_import = 9999
+ij_kotlin_name_count_to_use_star_import_for_members = 9999
+ij_kotlin_parameter_annotation_wrap = off
+ij_kotlin_space_after_comma = true
+ij_kotlin_space_after_extend_colon = true
+ij_kotlin_space_after_type_colon = true
+ij_kotlin_space_before_catch_parentheses = true
+ij_kotlin_space_before_comma = false
+ij_kotlin_space_before_extend_colon = true
+ij_kotlin_space_before_for_parentheses = true
+ij_kotlin_space_before_if_parentheses = true
+ij_kotlin_space_before_lambda_arrow = true
+ij_kotlin_space_before_type_colon = false
+ij_kotlin_space_before_when_parentheses = true
+ij_kotlin_space_before_while_parentheses = true
+ij_kotlin_spaces_around_additive_operators = true
+ij_kotlin_spaces_around_assignment_operators = true
+ij_kotlin_spaces_around_equality_operators = true
+ij_kotlin_spaces_around_function_type_arrow = true
+ij_kotlin_spaces_around_logical_operators = true
+ij_kotlin_spaces_around_multiplicative_operators = true
+ij_kotlin_spaces_around_range = false
+ij_kotlin_spaces_around_relational_operators = true
+ij_kotlin_spaces_around_unary_operator = false
+ij_kotlin_spaces_around_when_arrow = true
+ij_kotlin_variable_annotation_wrap = off
+ij_kotlin_while_on_new_line = false
+ij_kotlin_wrap_elvis_expressions = 1
+ij_kotlin_wrap_expression_body_functions = 1
+ij_kotlin_wrap_first_method_in_call_chain = false
\ No newline at end of file
diff --git a/tensorflow-kotlin-parent/README.md b/tensorflow-kotlin-parent/README.md
new file mode 100644
index 00000000000..c2c15eebf00
--- /dev/null
+++ b/tensorflow-kotlin-parent/README.md
@@ -0,0 +1,7 @@
+# Kotlin API
+
+This is the home of the Kotlin API for TensorFlow Java. The API lives in `tensorflow-core-kotlin`, and uses the annotation processor in `tensorflow-kotlin-generator`.
+
+There is no framework wrapper yet, as most of the framework classes work fine from Kotlin, but if there is a need one could be added.
+
+For contributing guidelines, see [CONTRIBUTING.md](../CONTRIBUTING.md#kotlin-api).
diff --git a/tensorflow-kotlin-parent/pom.xml b/tensorflow-kotlin-parent/pom.xml
new file mode 100644
index 00000000000..a4997623eb5
--- /dev/null
+++ b/tensorflow-kotlin-parent/pom.xml
@@ -0,0 +1,107 @@
+
+
+
+ 4.0.0
+
+
+ org.tensorflow
+ tensorflow-java
+ 0.5.0-SNAPSHOT
+
+ tensorflow-kotlin-parent
+ pom
+
+ TensorFlow Kotlin Parent
+ Parent POM of TensorFlow Kotlin artifacts
+
+
+ tensorflow-kotlin-generator
+ tensorflow-core-kotlin
+ tensorflow-framework-kotlin
+ tensorflow-kotlin
+ tensorflow-kotlin-jupyter
+ tensorflow-core-kotlin-jupyter
+
+
+
+
+ org.jetbrains.kotlin
+ kotlin-stdlib-jdk8
+ ${kotlin.version}
+
+
+
+
+ 1.6.10
+ 0.11.0-40
+ 0.30
+ 1.8
+
+
+
+
+ jdk11
+
+ 11
+
+
+
+
+
+
+
+ org.jetbrains.kotlin
+ kotlin-maven-plugin
+ ${kotlin.version}
+
+ ${kotlin.jvmTarget}
+
+
+
+
+ compile
+
+ compile
+
+
+
+
+ test-compile
+
+ test-compile
+
+
+
+
+
+ com.diffplug.spotless
+ spotless-maven-plugin
+ ${spotless.version}
+
+
+
+ ${ktfmt.version}
+
+
+
+
+
+
+
+
diff --git a/tensorflow-kotlin-parent/tensorflow-core-kotlin-jupyter/pom.xml b/tensorflow-kotlin-parent/tensorflow-core-kotlin-jupyter/pom.xml
new file mode 100644
index 00000000000..1207e14aae6
--- /dev/null
+++ b/tensorflow-kotlin-parent/tensorflow-core-kotlin-jupyter/pom.xml
@@ -0,0 +1,78 @@
+
+
+
+ 4.0.0
+
+
+ org.tensorflow
+ tensorflow-kotlin-parent
+ 0.5.0-SNAPSHOT
+
+ tensorflow-core-kotlin-jupyter
+ jar
+
+ TensorFlow Core Kotlin Jupyter Integration
+ Kotlin Jupyter integration for tensorflow-core
+
+
+
+ ${project.version}
+
+
+
+
+ org.jetbrains.kotlinx
+ kotlin-jupyter-api
+ ${kotlin_jupyter.version}
+
+
+ org.tensorflow
+ tensorflow-core-kotlin
+ ${project.version}
+
+
+
+
+ ${project.basedir}/src/main/kotlin
+
+
+
+ org.jetbrains.kotlin
+ kotlin-maven-plugin
+ ${kotlin.version}
+
+
+
+
+ compile
+
+ compile
+
+
+
+ -Xopt-in=kotlin.contracts.ExperimentalContracts
+ -Xexplicit-api=strict
+
+
+
+
+
+
+
+
diff --git a/tensorflow-kotlin-parent/tensorflow-core-kotlin-jupyter/src/main/kotlin/org/tensorflow/jupyter/TensorflowKotlinCoreIntegration.kt b/tensorflow-kotlin-parent/tensorflow-core-kotlin-jupyter/src/main/kotlin/org/tensorflow/jupyter/TensorflowKotlinCoreIntegration.kt
new file mode 100644
index 00000000000..b1219be2b9f
--- /dev/null
+++ b/tensorflow-kotlin-parent/tensorflow-core-kotlin-jupyter/src/main/kotlin/org/tensorflow/jupyter/TensorflowKotlinCoreIntegration.kt
@@ -0,0 +1,44 @@
+/*
+ Copyright 2021 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+=======================================================================
+
+*/
+package org.tensorflow.jupyter
+
+import org.jetbrains.kotlinx.jupyter.api.declare
+import org.jetbrains.kotlinx.jupyter.api.libraries.JupyterIntegration
+import org.tensorflow.EagerSession
+import org.tensorflow.Operand
+import org.tensorflow.op.Op
+import org.tensorflow.op.kotlin.tf
+
+public class TensorflowKotlinCoreIntegration : JupyterIntegration() {
+ override fun Builder.onLoaded() {
+ import(
+ "org.tensorflow.*",
+ "org.tensorflow.op.*",
+ "org.tensorflow.op.kotlin.*",
+ "org.tensorflow.types.*",
+ "org.tensorflow.types.family.*",
+ "org.tensorflow.ndarray.*",
+ "org.tensorflow.ndarray.index.*")
+
+ render> { it.asOutput().toString() }
+ render { it.op().toString() }
+
+ // TODO add an implicit receiver of EagerSession.getDefault() instead
+ onLoaded { declare("tf" to EagerSession.getDefault().tf) }
+ }
+}
diff --git a/tensorflow-kotlin-parent/tensorflow-core-kotlin-jupyter/src/main/resources/META-INF/kotlin-jupyter-libraries/libraries.json b/tensorflow-kotlin-parent/tensorflow-core-kotlin-jupyter/src/main/resources/META-INF/kotlin-jupyter-libraries/libraries.json
new file mode 100644
index 00000000000..54d29d383b3
--- /dev/null
+++ b/tensorflow-kotlin-parent/tensorflow-core-kotlin-jupyter/src/main/resources/META-INF/kotlin-jupyter-libraries/libraries.json
@@ -0,0 +1,6 @@
+{
+ "definitions":[],
+ "producers": [
+ { "fqn" : "org.tensorflow.jupyter.TensorflowKotlinCoreIntegration" }
+ ]
+}
\ No newline at end of file
diff --git a/tensorflow-kotlin-parent/tensorflow-core-kotlin/pom.xml b/tensorflow-kotlin-parent/tensorflow-core-kotlin/pom.xml
new file mode 100644
index 00000000000..7ff643f2662
--- /dev/null
+++ b/tensorflow-kotlin-parent/tensorflow-core-kotlin/pom.xml
@@ -0,0 +1,163 @@
+
+
+
+ 4.0.0
+
+
+ org.tensorflow
+ tensorflow-kotlin-parent
+ 0.5.0-SNAPSHOT
+
+ tensorflow-core-kotlin
+ jar
+
+ TensorFlow Core Kotlin API Library
+ Kotlin API wrappers for the TensorFlow core Java library
+
+
+
+
+
+
+
+ org.tensorflow
+ tensorflow-core-api
+ ${project.version}
+
+
+ org.junit.jupiter
+ junit-jupiter-api
+ test
+
+
+ org.junit.jupiter
+ junit-jupiter-engine
+ test
+
+
+ org.openjdk.jmh
+ jmh-core
+ test
+
+
+ org.openjdk.jmh
+ jmh-generator-annprocess
+ test
+
+
+ org.jetbrains.kotlin
+ kotlin-test-junit5
+ ${kotlin.version}
+ test
+
+
+
+ org.tensorflow
+ tensorflow-core-platform${javacpp.platform.extension}
+ ${project.version}
+ test
+
+
+
+
+ ${project.basedir}/src/main/kotlin
+ ${project.basedir}/src/test/kotlin
+
+
+ org.codehaus.mojo
+ build-helper-maven-plugin
+ 3.0.0
+
+
+
+ add-gen-sources
+ generate-sources
+
+ add-source
+
+
+
+ ${project.basedir}/src/gen/annotations
+
+
+
+
+
+
+ org.jetbrains.kotlin
+ kotlin-maven-plugin
+ ${kotlin.version}
+
+
+
+
+ compile
+
+ compile
+
+
+
+ -Xopt-in=kotlin.contracts.ExperimentalContracts
+ -Xexplicit-api=strict
+
+
+
+
+
+ kapt
+
+ kapt
+
+
+
+ ${project.basedir}/src/main/kotlin
+ ${project.basedir}/../../tensorflow-core/tensorflow-core-api/src/gen/java
+ ${project.basedir}/../../tensorflow-core/tensorflow-core-api/src/gen/annotations
+ ${project.basedir}/../../tensorflow-core/tensorflow-core-api/src/main/java
+
+
+ org.tensorflow.processor.operator.KotlinOpsProcessor
+
+
+
+ org.tensorflow
+ tensorflow-kotlin-generator
+ ${project.version}
+
+
+
+
+
+
+
+ org.apache.maven.plugins
+ maven-surefire-plugin
+ 2.22.2
+
+
+
+
+
+
+
+
+
diff --git a/tensorflow-kotlin-parent/tensorflow-core-kotlin/src/gen/annotations/org/tensorflow/op/kotlin/AudioOps.kt b/tensorflow-kotlin-parent/tensorflow-core-kotlin/src/gen/annotations/org/tensorflow/op/kotlin/AudioOps.kt
new file mode 100644
index 00000000000..00608480fde
--- /dev/null
+++ b/tensorflow-kotlin-parent/tensorflow-core-kotlin/src/gen/annotations/org/tensorflow/op/kotlin/AudioOps.kt
@@ -0,0 +1,221 @@
+// Copyright 2020 The TensorFlow Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+// ==============================================================================
+//
+// This class has been generated, DO NOT EDIT!
+//
+package org.tensorflow.op.kotlin
+
+import kotlin.Boolean
+import kotlin.Float
+import kotlin.Long
+import org.tensorflow.Operand
+import org.tensorflow.op.Scope
+import org.tensorflow.op.audio.AudioSpectrogram
+import org.tensorflow.op.audio.DecodeWav
+import org.tensorflow.op.audio.EncodeWav
+import org.tensorflow.op.audio.Mfcc
+import org.tensorflow.types.TFloat32
+import org.tensorflow.types.TInt32
+import org.tensorflow.types.TString
+
+/**
+ * An API for building `audio` operations as [Op][org.tensorflow.op.Op]s
+ *
+ * @see org.tensorflow.op.Ops
+ */
+public class AudioOps(
+ /**
+ * Get the parent [KotlinOps] object.
+ */
+ public val ops: KotlinOps
+) {
+ public val java: org.tensorflow.op.AudioOps = ops.java.audio
+
+ /**
+ * Returns the current [scope][Scope] of this API
+ */
+ public val scope: Scope = ops.scope
+
+ /**
+ * Produces a visualization of audio data over time.
+ * Spectrograms are a standard way of representing audio information as a series of
+ * slices of frequency information, one slice for each window of time. By joining
+ * these together into a sequence, they form a distinctive fingerprint of the sound
+ * over time.
+ *
+ * This op expects to receive audio data as an input, stored as floats in the range
+ * -1 to 1, together with a window width in samples, and a stride specifying how
+ * far to move the window between slices. From this it generates a three
+ * dimensional output. The first dimension is for the channels in the input, so a
+ * stereo audio input would have two here for example. The second dimension is time,
+ * with successive frequency slices. The third dimension has an amplitude value for
+ * each frequency during that time slice.
+ *
+ * This means the layout when converted and saved as an image is rotated 90 degrees
+ * clockwise from a typical spectrogram. Time is descending down the Y axis, and
+ * the frequency decreases from left to right.
+ *
+ * Each value in the result represents the square root of the sum of the real and
+ * imaginary parts of an FFT on the current window of samples. In this way, the
+ * lowest dimension represents the power of each frequency in the current window,
+ * and adjacent windows are concatenated in the next dimension.
+ *
+ * To get a more intuitive and visual look at what this operation does, you can run
+ * tensorflow/examples/wav_to_spectrogram to read in an audio file and save out the
+ * resulting spectrogram as a PNG image.
+ *
+ * @param input Float representation of audio data.
+ * @param windowSize How wide the input window is in samples. For the highest efficiency
+ * this should be a power of two, but other values are accepted.
+ * @param stride How widely apart the center of adjacent sample windows should be.
+ * @param options carries optional attribute values
+ * @return a new instance of AudioSpectrogram
+ * @see org.tensorflow.op.AudioOps.audioSpectrogram
+ * @param magnitudeSquared Sets the magnitudeSquared option.
+ *
+ * @param magnitudeSquared Whether to return the squared magnitude or just the
+ * magnitude. Using squared magnitude can avoid extra calculations.
+ * @return this Options instance.
+ */
+ public fun audioSpectrogram(
+ input: Operand,
+ windowSize: Long,
+ stride: Long,
+ magnitudeSquared: Boolean? = null
+ ): AudioSpectrogram = java.audioSpectrogram(
+ input,
+ windowSize,
+ stride,
+ *listOfNotNull(
+ magnitudeSquared?.let{ org.tensorflow.op.audio.AudioSpectrogram.magnitudeSquared(it) }
+ ).toTypedArray()
+ )
+
+ /**
+ * Decode a 16-bit PCM WAV file to a float tensor.
+ * The -32768 to 32767 signed 16-bit values will be scaled to -1.0 to 1.0 in float.
+ *
+ * When desired_channels is set, if the input contains fewer channels than this
+ * then the last channel will be duplicated to give the requested number, else if
+ * the input has more channels than requested then the additional channels will be
+ * ignored.
+ *
+ * If desired_samples is set, then the audio will be cropped or padded with zeroes
+ * to the requested length.
+ *
+ * The first output contains a Tensor with the content of the audio samples. The
+ * lowest dimension will be the number of channels, and the second will be the
+ * number of samples. For example, a ten-sample-long stereo WAV file should give an
+ * output shape of [10, 2].
+ *
+ * @param contents The WAV-encoded audio, usually from a file.
+ * @param options carries optional attribute values
+ * @return a new instance of DecodeWav
+ * @see org.tensorflow.op.AudioOps.decodeWav
+ * @param desiredChannels Sets the desiredChannels option.
+ *
+ * @param desiredChannels Number of sample channels wanted.
+ * @return this Options instance.
+ * @param desiredSamples Sets the desiredSamples option.
+ *
+ * @param desiredSamples Length of audio requested.
+ * @return this Options instance.
+ */
+ public fun decodeWav(
+ contents: Operand,
+ desiredChannels: Long? = null,
+ desiredSamples: Long? = null
+ ): DecodeWav = java.decodeWav(
+ contents,
+ *listOfNotNull(
+ desiredChannels?.let{ org.tensorflow.op.audio.DecodeWav.desiredChannels(it) },
+ desiredSamples?.let{ org.tensorflow.op.audio.DecodeWav.desiredSamples(it) }
+ ).toTypedArray()
+ )
+
+ /**
+ * Encode audio data using the WAV file format.
+ * This operation will generate a string suitable to be saved out to create a .wav
+ * audio file. It will be encoded in the 16-bit PCM format. It takes in float
+ * values in the range -1.0f to 1.0f, and any outside that value will be clamped to
+ * that range.
+ *
+ * `audio` is a 2-D float Tensor of shape `[length, channels]`.
+ * `sample_rate` is a scalar Tensor holding the rate to use (e.g. 44100).
+ *
+ * @param audio 2-D with shape `[length, channels]`.
+ * @param sampleRate Scalar containing the sample frequency.
+ * @return a new instance of EncodeWav
+ * @see org.tensorflow.op.AudioOps.encodeWav
+ */
+ public fun encodeWav(audio: Operand, sampleRate: Operand): EncodeWav =
+ java.encodeWav(
+ audio,
+ sampleRate
+ )
+
+ /**
+ * Transforms a spectrogram into a form that's useful for speech recognition.
+ * Mel Frequency Cepstral Coefficients are a way of representing audio data that's
+ * been effective as an input feature for machine learning. They are created by
+ * taking the spectrum of a spectrogram (a 'cepstrum'), and discarding some of the
+ * higher frequencies that are less significant to the human ear. They have a long
+ * history in the speech recognition world, and
+ * https://en.wikipedia.org/wiki/Mel-frequency_cepstrum
+ * is a good resource to learn more.
+ *
+ * @param spectrogram Typically produced by the Spectrogram op, with magnitude_squared
+ * set to true.
+ * @param sampleRate How many samples per second the source audio used.
+ * @param options carries optional attribute values
+ * @return a new instance of Mfcc
+ * @see org.tensorflow.op.AudioOps.mfcc
+ * @param upperFrequencyLimit Sets the upperFrequencyLimit option.
+ *
+ * @param upperFrequencyLimit The highest frequency to use when calculating the
+ * ceptstrum.
+ * @return this Options instance.
+ * @param lowerFrequencyLimit Sets the lowerFrequencyLimit option.
+ *
+ * @param lowerFrequencyLimit The lowest frequency to use when calculating the
+ * ceptstrum.
+ * @return this Options instance.
+ * @param filterbankChannelCount Sets the filterbankChannelCount option.
+ *
+ * @param filterbankChannelCount Resolution of the Mel bank used internally.
+ * @return this Options instance.
+ * @param dctCoefficientCount Sets the dctCoefficientCount option.
+ *
+ * @param dctCoefficientCount How many output channels to produce per time slice.
+ * @return this Options instance.
+ */
+ public fun mfcc(
+ spectrogram: Operand,
+ sampleRate: Operand,
+ upperFrequencyLimit: Float? = null,
+ lowerFrequencyLimit: Float? = null,
+ filterbankChannelCount: Long? = null,
+ dctCoefficientCount: Long? = null
+ ): Mfcc = java.mfcc(
+ spectrogram,
+ sampleRate,
+ *listOfNotNull(
+ upperFrequencyLimit?.let{ org.tensorflow.op.audio.Mfcc.upperFrequencyLimit(it) },
+ lowerFrequencyLimit?.let{ org.tensorflow.op.audio.Mfcc.lowerFrequencyLimit(it) },
+ filterbankChannelCount?.let{ org.tensorflow.op.audio.Mfcc.filterbankChannelCount(it) },
+ dctCoefficientCount?.let{ org.tensorflow.op.audio.Mfcc.dctCoefficientCount(it) }
+ ).toTypedArray()
+ )
+}
diff --git a/tensorflow-kotlin-parent/tensorflow-core-kotlin/src/gen/annotations/org/tensorflow/op/kotlin/BitwiseOps.kt b/tensorflow-kotlin-parent/tensorflow-core-kotlin/src/gen/annotations/org/tensorflow/op/kotlin/BitwiseOps.kt
new file mode 100644
index 00000000000..2ad2d734c3f
--- /dev/null
+++ b/tensorflow-kotlin-parent/tensorflow-core-kotlin/src/gen/annotations/org/tensorflow/op/kotlin/BitwiseOps.kt
@@ -0,0 +1,302 @@
+// Copyright 2020 The TensorFlow Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+// ==============================================================================
+//
+// This class has been generated, DO NOT EDIT!
+//
+package org.tensorflow.op.kotlin
+
+import org.tensorflow.Operand
+import org.tensorflow.op.Scope
+import org.tensorflow.op.bitwise.BitwiseAnd
+import org.tensorflow.op.bitwise.BitwiseOr
+import org.tensorflow.op.bitwise.BitwiseXor
+import org.tensorflow.op.bitwise.Invert
+import org.tensorflow.op.bitwise.LeftShift
+import org.tensorflow.op.bitwise.RightShift
+import org.tensorflow.types.family.TNumber
+
+/**
+ * An API for building `bitwise` operations as [Op][org.tensorflow.op.Op]s
+ *
+ * @see org.tensorflow.op.Ops
+ */
+public class BitwiseOps(
+    /**
+     * Get the parent [KotlinOps] object.
+     */
+    public val ops: KotlinOps
+) {
+    public val java: org.tensorflow.op.BitwiseOps = ops.java.bitwise
+
+    /**
+     * Returns the current [scope][Scope] of this API
+     */
+    public val scope: Scope = ops.scope
+
+    /**
+     * Elementwise computes the bitwise AND of `x` and `y`.
+     * The result will have those bits set, that are set in both `x` and `y`. The
+     * computation is performed on the underlying representations of `x` and `y`.
+     *
+     * For example:
+     * ```
+     * import tensorflow as tf
+     * from tensorflow.python.ops import bitwise_ops
+     * dtype_list = [tf.int8, tf.int16, tf.int32, tf.int64,
+     *               tf.uint8, tf.uint16, tf.uint32, tf.uint64]
+     *
+     * for dtype in dtype_list:
+     *   lhs = tf.constant([0, 5, 3, 14], dtype=dtype)
+     *   rhs = tf.constant([5, 0, 7, 11], dtype=dtype)
+     *   exp = tf.constant([0, 0, 3, 10], dtype=tf.float32)
+     *
+     *   res = bitwise_ops.bitwise_and(lhs, rhs)
+     *   tf.assert_equal(tf.cast(res, tf.float32), exp) # TRUE
+     *
+     * ```
+     *
+     * @param <T> data type for `z` output
+     * @param x The x value
+     * @param y The y value
+     * @param <T> data type for `BitwiseAnd` output and operands
+     * @return a new instance of BitwiseAnd
+     * @see org.tensorflow.op.BitwiseOps.bitwiseAnd
+     */
+    public fun <T : TNumber> bitwiseAnd(x: Operand<T>, y: Operand<T>): BitwiseAnd<T> =
+        java.bitwiseAnd<T>(
+            x,
+            y
+        )
+
+    /**
+     * Elementwise computes the bitwise OR of `x` and `y`.
+     * The result will have those bits set, that are set in `x`, `y` or both. The
+     * computation is performed on the underlying representations of `x` and `y`.
+     *
+     * For example:
+     * ```
+     * import tensorflow as tf
+     * from tensorflow.python.ops import bitwise_ops
+     * dtype_list = [tf.int8, tf.int16, tf.int32, tf.int64,
+     *               tf.uint8, tf.uint16, tf.uint32, tf.uint64]
+     *
+     * for dtype in dtype_list:
+     *   lhs = tf.constant([0, 5, 3, 14], dtype=dtype)
+     *   rhs = tf.constant([5, 0, 7, 11], dtype=dtype)
+     *   exp = tf.constant([5, 5, 7, 15], dtype=tf.float32)
+     *
+     *   res = bitwise_ops.bitwise_or(lhs, rhs)
+     *   tf.assert_equal(tf.cast(res, tf.float32), exp) # TRUE
+     *
+     * ```
+     *
+     * @param <T> data type for `z` output
+     * @param x The x value
+     * @param y The y value
+     * @param <T> data type for `BitwiseOr` output and operands
+     * @return a new instance of BitwiseOr
+     * @see org.tensorflow.op.BitwiseOps.bitwiseOr
+     */
+    public fun <T : TNumber> bitwiseOr(x: Operand<T>, y: Operand<T>): BitwiseOr<T> =
+        java.bitwiseOr<T>(
+            x,
+            y
+        )
+
+    /**
+     * Elementwise computes the bitwise XOR of `x` and `y`.
+     * The result will have those bits set, that are different in `x` and `y`. The
+     * computation is performed on the underlying representations of `x` and `y`.
+     *
+     * For example:
+     * ```
+     * import tensorflow as tf
+     * from tensorflow.python.ops import bitwise_ops
+     * dtype_list = [tf.int8, tf.int16, tf.int32, tf.int64,
+     *               tf.uint8, tf.uint16, tf.uint32, tf.uint64]
+     *
+     * for dtype in dtype_list:
+     *   lhs = tf.constant([0, 5, 3, 14], dtype=dtype)
+     *   rhs = tf.constant([5, 0, 7, 11], dtype=dtype)
+     *   exp = tf.constant([5, 5, 4, 5], dtype=tf.float32)
+     *
+     *   res = bitwise_ops.bitwise_xor(lhs, rhs)
+     *   tf.assert_equal(tf.cast(res, tf.float32), exp) # TRUE
+     *
+     * ```
+     *
+     * @param <T> data type for `z` output
+     * @param x The x value
+     * @param y The y value
+     * @param <T> data type for `BitwiseXor` output and operands
+     * @return a new instance of BitwiseXor
+     * @see org.tensorflow.op.BitwiseOps.bitwiseXor
+     */
+    public fun <T : TNumber> bitwiseXor(x: Operand<T>, y: Operand<T>): BitwiseXor<T> =
+        java.bitwiseXor<T>(
+            x,
+            y
+        )
+
+    /**
+     * Invert (flip) each bit of supported types; for example, type `uint8` value 01010101 becomes
+     * 10101010.
+     * Flip each bit of supported types. For example, type `int8` (decimal 2) binary 00000010
+     * becomes (decimal -3) binary 11111101.
+     * This operation is performed on each element of the tensor argument `x`.
+     *
+     * Example:
+     * ```
+     * import tensorflow as tf
+     * from tensorflow.python.ops import bitwise_ops
+     *
+     * # flip 2 (00000010) to -3 (11111101)
+     * tf.assert_equal(-3, bitwise_ops.invert(2))
+     *
+     * dtype_list = [dtypes.int8, dtypes.int16, dtypes.int32, dtypes.int64,
+     *               dtypes.uint8, dtypes.uint16, dtypes.uint32, dtypes.uint64]
+     *
+     * inputs = [0, 5, 3, 14]
+     * for dtype in dtype_list:
+     *   # Because of issues with negative numbers, let's test this indirectly.
+     *   # 1. invert(a) and a = 0
+     *   # 2. invert(a) or a = invert(0)
+     *   input_tensor = tf.constant([0, 5, 3, 14], dtype=dtype)
+     *   not_a_and_a, not_a_or_a, not_0 = [bitwise_ops.bitwise_and(
+     *                                       input_tensor, bitwise_ops.invert(input_tensor)),
+     *                                     bitwise_ops.bitwise_or(
+     *                                       input_tensor, bitwise_ops.invert(input_tensor)),
+     *                                     bitwise_ops.invert(
+     *                                       tf.constant(0, dtype=dtype))]
+     *
+     *   expected = tf.constant([0, 0, 0, 0], dtype=tf.float32)
+     *   tf.assert_equal(tf.cast(not_a_and_a, tf.float32), expected)
+     *
+     *   expected = tf.cast([not_0] * 4, tf.float32)
+     *   tf.assert_equal(tf.cast(not_a_or_a, tf.float32), expected)
+     *
+     *   # For unsigned dtypes let's also check the result directly.
+     *   if dtype.is_unsigned:
+     *     inverted = bitwise_ops.invert(input_tensor)
+     *     expected = tf.constant([dtype.max - x for x in inputs], dtype=tf.float32)
+     *     tf.assert_equal(tf.cast(inverted, tf.float32), tf.cast(expected, tf.float32))
+     *
+     * ```
+     *
+     * @param <T> data type for `y` output
+     * @param x The x value
+     * @param <T> data type for `Invert` output and operands
+     * @return a new instance of Invert
+     * @see org.tensorflow.op.BitwiseOps.invert
+     */
+    public fun <T : TNumber> invert(x: Operand<T>): Invert<T> = java.invert<T>(
+        x
+    )
+
+    /**
+     * Elementwise computes the bitwise left-shift of `x` and `y`.
+     * If `y` is negative, or greater than or equal to the width of `x` in bits the
+     * result is implementation defined.
+     *
+     * Example:
+     * ```
+     * import tensorflow as tf
+     * from tensorflow.python.ops import bitwise_ops
+     * import numpy as np
+     * dtype_list = [tf.int8, tf.int16, tf.int32, tf.int64]
+     *
+     * for dtype in dtype_list:
+     *   lhs = tf.constant([-1, -5, -3, -14], dtype=dtype)
+     *   rhs = tf.constant([5, 0, 7, 11], dtype=dtype)
+     *
+     *   left_shift_result = bitwise_ops.left_shift(lhs, rhs)
+     *
+     *   print(left_shift_result)
+     *
+     * # This will print:
+     * # tf.Tensor([ -32   -5 -128    0], shape=(4,), dtype=int8)
+     * # tf.Tensor([   -32     -5   -384 -28672], shape=(4,), dtype=int16)
+     * # tf.Tensor([   -32     -5   -384 -28672], shape=(4,), dtype=int32)
+     * # tf.Tensor([   -32     -5   -384 -28672], shape=(4,), dtype=int64)
+     *
+     * lhs = np.array([-2, 64, 101, 32], dtype=np.int8)
+     * rhs = np.array([-1, -5, -3, -14], dtype=np.int8)
+     * bitwise_ops.left_shift(lhs, rhs)
+     * #
+     *
+     * ```
+     *
+     * @param <T> data type for `z` output
+     * @param x The x value
+     * @param y The y value
+     * @param <T> data type for `LeftShift` output and operands
+     * @return a new instance of LeftShift
+     * @see org.tensorflow.op.BitwiseOps.leftShift
+     */
+    public fun <T : TNumber> leftShift(x: Operand<T>, y: Operand<T>): LeftShift<T> =
+        java.leftShift<T>(
+            x,
+            y
+        )
+
+    /**
+     * Elementwise computes the bitwise right-shift of `x` and `y`.
+     * Performs a logical shift for unsigned integer types, and an arithmetic shift
+     * for signed integer types.
+     *
+     * If `y` is negative, or greater than or equal to than the width of `x` in bits
+     * the result is implementation defined.
+     *
+     * Example:
+     * ```
+     * import tensorflow as tf
+     * from tensorflow.python.ops import bitwise_ops
+     * import numpy as np
+     * dtype_list = [tf.int8, tf.int16, tf.int32, tf.int64]
+     *
+     * for dtype in dtype_list:
+     *   lhs = tf.constant([-1, -5, -3, -14], dtype=dtype)
+     *   rhs = tf.constant([5, 0, 7, 11], dtype=dtype)
+     *
+     *   right_shift_result = bitwise_ops.right_shift(lhs, rhs)
+     *
+     *   print(right_shift_result)
+     *
+     * # This will print:
+     * # tf.Tensor([-1 -5 -1 -1], shape=(4,), dtype=int8)
+     * # tf.Tensor([-1 -5 -1 -1], shape=(4,), dtype=int16)
+     * # tf.Tensor([-1 -5 -1 -1], shape=(4,), dtype=int32)
+     * # tf.Tensor([-1 -5 -1 -1], shape=(4,), dtype=int64)
+     *
+     * lhs = np.array([-2, 64, 101, 32], dtype=np.int8)
+     * rhs = np.array([-1, -5, -3, -14], dtype=np.int8)
+     * bitwise_ops.right_shift(lhs, rhs)
+     * #
+     *
+     * ```
+     *
+     * @param <T> data type for `z` output
+     * @param x The x value
+     * @param y The y value
+     * @param <T> data type for `RightShift` output and operands
+     * @return a new instance of RightShift
+     * @see org.tensorflow.op.BitwiseOps.rightShift
+     */
+    public fun <T : TNumber> rightShift(x: Operand<T>, y: Operand<T>): RightShift<T> =
+        java.rightShift<T>(
+            x,
+            y
+        )
+}
diff --git a/tensorflow-kotlin-parent/tensorflow-core-kotlin/src/gen/annotations/org/tensorflow/op/kotlin/DataOps.kt b/tensorflow-kotlin-parent/tensorflow-core-kotlin/src/gen/annotations/org/tensorflow/op/kotlin/DataOps.kt
new file mode 100644
index 00000000000..6c911cfe33a
--- /dev/null
+++ b/tensorflow-kotlin-parent/tensorflow-core-kotlin/src/gen/annotations/org/tensorflow/op/kotlin/DataOps.kt
@@ -0,0 +1,3341 @@
+// Copyright 2020 The TensorFlow Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+// ==============================================================================
+//
+// This class has been generated, DO NOT EDIT!
+//
+package org.tensorflow.op.kotlin
+
+import kotlin.Boolean
+import kotlin.Long
+import kotlin.String
+import org.tensorflow.ConcreteFunction
+import org.tensorflow.Operand
+import org.tensorflow.ndarray.Shape
+import org.tensorflow.op.Scope
+import org.tensorflow.op.`data`.AnonymousIterator
+import org.tensorflow.op.`data`.AssertCardinalityDataset
+import org.tensorflow.op.`data`.AssertNextDataset
+import org.tensorflow.op.`data`.AutoShardDataset
+import org.tensorflow.op.`data`.BatchDataset
+import org.tensorflow.op.`data`.BytesProducedStatsDataset
+import org.tensorflow.op.`data`.CSVDataset
+import org.tensorflow.op.`data`.CacheDataset
+import org.tensorflow.op.`data`.ChooseFastestBranchDataset
+import org.tensorflow.op.`data`.ChooseFastestDataset
+import org.tensorflow.op.`data`.ConcatenateDataset
+import org.tensorflow.op.`data`.DataServiceDatasetV2
+import org.tensorflow.op.`data`.DatasetCardinality
+import org.tensorflow.op.`data`.DatasetFromGraph
+import org.tensorflow.op.`data`.DatasetToGraph
+import org.tensorflow.op.`data`.DatasetToSingleElement
+import org.tensorflow.op.`data`.DatasetToTfRecord
+import org.tensorflow.op.`data`.DeleteIterator
+import org.tensorflow.op.`data`.DenseToSparseBatchDataset
+import org.tensorflow.op.`data`.DeserializeIterator
+import org.tensorflow.op.`data`.DirectedInterleaveDataset
+import org.tensorflow.op.`data`.FilterByLastComponentDataset
+import org.tensorflow.op.`data`.FilterDataset
+import org.tensorflow.op.`data`.FinalizeDataset
+import org.tensorflow.op.`data`.FixedLengthRecordDataset
+import org.tensorflow.op.`data`.FlatMapDataset
+import org.tensorflow.op.`data`.GeneratorDataset
+import org.tensorflow.op.`data`.GroupByReducerDataset
+import org.tensorflow.op.`data`.GroupByWindowDataset
+import org.tensorflow.op.`data`.IgnoreErrorsDataset
+import org.tensorflow.op.`data`.InitializeTableFromDataset
+import org.tensorflow.op.`data`.InterleaveDataset
+import org.tensorflow.op.`data`.Iterator
+import org.tensorflow.op.`data`.IteratorGetNext
+import org.tensorflow.op.`data`.IteratorGetNextAsOptional
+import org.tensorflow.op.`data`.IteratorGetNextSync
+import org.tensorflow.op.`data`.IteratorToStringHandle
+import org.tensorflow.op.`data`.LMDBDataset
+import org.tensorflow.op.`data`.LatencyStatsDataset
+import org.tensorflow.op.`data`.LegacyParallelInterleaveDataset
+import org.tensorflow.op.`data`.LoadDataset
+import org.tensorflow.op.`data`.MakeIterator
+import org.tensorflow.op.`data`.MapAndBatchDataset
+import org.tensorflow.op.`data`.MapDataset
+import org.tensorflow.op.`data`.MatchingFilesDataset
+import org.tensorflow.op.`data`.MaxIntraOpParallelismDataset
+import org.tensorflow.op.`data`.ModelDataset
+import org.tensorflow.op.`data`.NonSerializableDataset
+import org.tensorflow.op.`data`.OneShotIterator
+import org.tensorflow.op.`data`.OptimizeDataset
+import org.tensorflow.op.`data`.OptionalFromValue
+import org.tensorflow.op.`data`.OptionalGetValue
+import org.tensorflow.op.`data`.OptionalHasValue
+import org.tensorflow.op.`data`.OptionalNone
+import org.tensorflow.op.`data`.OptionsDataset
+import org.tensorflow.op.`data`.PaddedBatchDataset
+import org.tensorflow.op.`data`.ParallelBatchDataset
+import org.tensorflow.op.`data`.ParallelInterleaveDataset
+import org.tensorflow.op.`data`.ParallelMapDataset
+import org.tensorflow.op.`data`.ParseExampleDataset
+import org.tensorflow.op.`data`.PrefetchDataset
+import org.tensorflow.op.`data`.PrivateThreadPoolDataset
+import org.tensorflow.op.`data`.RandomDataset
+import org.tensorflow.op.`data`.RangeDataset
+import org.tensorflow.op.`data`.RebatchDatasetV2
+import org.tensorflow.op.`data`.ReduceDataset
+import org.tensorflow.op.`data`.RegisterDataset
+import org.tensorflow.op.`data`.RepeatDataset
+import org.tensorflow.op.`data`.SamplingDataset
+import org.tensorflow.op.`data`.SaveDataset
+import org.tensorflow.op.`data`.ScanDataset
+import org.tensorflow.op.`data`.SerializeIterator
+import org.tensorflow.op.`data`.SetStatsAggregatorDataset
+import org.tensorflow.op.`data`.ShardDataset
+import org.tensorflow.op.`data`.ShuffleAndRepeatDataset
+import org.tensorflow.op.`data`.ShuffleDataset
+import org.tensorflow.op.`data`.SkipDataset
+import org.tensorflow.op.`data`.SleepDataset
+import org.tensorflow.op.`data`.SlidingWindowDataset
+import org.tensorflow.op.`data`.SnapshotDataset
+import org.tensorflow.op.`data`.SparseTensorSliceDataset
+import org.tensorflow.op.`data`.SqlDataset
+import org.tensorflow.op.`data`.TakeDataset
+import org.tensorflow.op.`data`.TakeWhileDataset
+import org.tensorflow.op.`data`.TensorDataset
+import org.tensorflow.op.`data`.TensorSliceDataset
+import org.tensorflow.op.`data`.TextLineDataset
+import org.tensorflow.op.`data`.TfRecordDataset
+import org.tensorflow.op.`data`.ThreadPoolDataset
+import org.tensorflow.op.`data`.UnbatchDataset
+import org.tensorflow.op.`data`.UniqueDataset
+import org.tensorflow.op.`data`.UnwrapDatasetVariant
+import org.tensorflow.op.`data`.WindowDataset
+import org.tensorflow.op.`data`.WrapDatasetVariant
+import org.tensorflow.op.`data`.ZipDataset
+import org.tensorflow.types.TBool
+import org.tensorflow.types.TFloat32
+import org.tensorflow.types.TInt64
+import org.tensorflow.types.TString
+import org.tensorflow.types.family.TNumber
+import org.tensorflow.types.family.TType
+
+/**
+ * An API for building `data` operations as [Op][org.tensorflow.op.Op]s
+ *
+ * @see org.tensorflow.op.Ops
+ */
+public class DataOps(
+ /**
+ * Get the parent [KotlinOps] object.
+ */
+ public val ops: KotlinOps
+) {
+ public val java: org.tensorflow.op.DataOps = ops.java.data
+
+ /**
+ * Returns the current [scope][Scope] of this API
+ */
+ public val scope: Scope = ops.scope
+
+    /**
+     * A container for an iterator resource.
+     *
+     * @param outputTypes The value of the outputTypes attribute
+     * @param outputShapes The value of the outputShapes attribute
+     * @return a new instance of AnonymousIterator
+     * @see org.tensorflow.op.DataOps.anonymousIterator
+     */
+    public fun anonymousIterator(outputTypes: List<Class<out TType>>, outputShapes: List<Shape>):
+        AnonymousIterator = java.anonymousIterator(
+        outputTypes,
+        outputShapes
+    )
+
+    /**
+     * The AssertCardinalityDataset operation
+     *
+     * @param inputDataset The inputDataset value
+     * @param cardinality The cardinality value
+     * @param outputTypes The value of the outputTypes attribute
+     * @param outputShapes The value of the outputShapes attribute
+     * @return a new instance of AssertCardinalityDataset
+     * @see org.tensorflow.op.DataOps.assertCardinalityDataset
+     */
+    public fun assertCardinalityDataset(
+        inputDataset: Operand<out TType>,
+        cardinality: Operand<TInt64>,
+        outputTypes: List<Class<out TType>>,
+        outputShapes: List<Shape>
+    ): AssertCardinalityDataset = java.assertCardinalityDataset(
+        inputDataset,
+        cardinality,
+        outputTypes,
+        outputShapes
+    )
+
+    /**
+     * A transformation that asserts which transformations happen next.
+     * This transformation checks whether the camel-case names (i.e. &quot;FlatMap&quot;, not
+     * &quot;flat_map&quot;) of the transformations following this transformation match the list
+     * of names in the `transformations` argument. If there is a mismatch, the
+     * transformation raises an exception.
+     *
+     * The check occurs when iterating over the contents of the dataset, which
+     * means that the check happens _after_ any static optimizations are applied
+     * to the dataset graph.
+     *
+     * @param inputDataset A variant tensor representing the input dataset.
+     * `data.AssertNextDataset` passes through the outputs of its input dataset.
+     * @param transformations A `tf.string` vector `tf.Tensor` identifying the transformations that
+     * are
+     * expected to happen next.
+     * @param outputTypes The value of the outputTypes attribute
+     * @param outputShapes The value of the outputShapes attribute
+     * @return a new instance of AssertNextDataset
+     * @see org.tensorflow.op.DataOps.assertNextDataset
+     */
+    public fun assertNextDataset(
+        inputDataset: Operand<out TType>,
+        transformations: Operand<TString>,
+        outputTypes: List<Class<out TType>>,
+        outputShapes: List<Shape>
+    ): AssertNextDataset = java.assertNextDataset(
+        inputDataset,
+        transformations,
+        outputTypes,
+        outputShapes
+    )
+
+    /**
+     * Creates a dataset that shards the input dataset.
+     * Creates a dataset that shards the input dataset by num_workers, returning a
+     * sharded dataset for the index-th worker. This attempts to automatically shard
+     * a dataset by examining the Dataset graph and inserting a shard op before the
+     * inputs to a reader Dataset (e.g. CSVDataset, TFRecordDataset).
+     *
+     * This dataset will throw a NotFound error if we cannot shard the dataset
+     * automatically.
+     *
+     * @param inputDataset A variant tensor representing the input dataset.
+     * @param numWorkers A scalar representing the number of workers to distribute this dataset
+     * across.
+     * @param index A scalar representing the index of the current worker out of num_workers.
+     * @param outputTypes The value of the outputTypes attribute
+     * @param outputShapes The value of the outputShapes attribute
+     * @param options carries optional attribute values
+     * @return a new instance of AutoShardDataset
+     * @see org.tensorflow.op.DataOps.autoShardDataset
+     * @param autoShardPolicy Sets the autoShardPolicy option.
+     *
+     * @param autoShardPolicy the autoShardPolicy option
+     * @return this Options instance.
+     * @param numReplicas Sets the numReplicas option.
+     *
+     * @param numReplicas the numReplicas option
+     * @return this Options instance.
+     */
+    public fun autoShardDataset(
+        inputDataset: Operand<out TType>,
+        numWorkers: Operand<TInt64>,
+        index: Operand<TInt64>,
+        outputTypes: List<Class<out TType>>,
+        outputShapes: List<Shape>,
+        autoShardPolicy: Long? = null,
+        numReplicas: Long? = null
+    ): AutoShardDataset = java.autoShardDataset(
+        inputDataset,
+        numWorkers,
+        index,
+        outputTypes,
+        outputShapes,
+        *listOfNotNull(
+            autoShardPolicy?.let { org.tensorflow.op.data.AutoShardDataset.autoShardPolicy(it) },
+            numReplicas?.let { org.tensorflow.op.data.AutoShardDataset.numReplicas(it) }
+        ).toTypedArray()
+    )
+
+    /**
+     * Creates a dataset that batches `batch_size` elements from `input_dataset`.
+     *
+     * @param inputDataset The inputDataset value
+     * @param batchSize A scalar representing the number of elements to accumulate in a batch.
+     * @param dropRemainder A scalar representing whether the last batch should be dropped in case
+     * its size
+     * is smaller than desired.
+     * @param outputTypes The value of the outputTypes attribute
+     * @param outputShapes The value of the outputShapes attribute
+     * @param options carries optional attribute values
+     * @return a new instance of BatchDataset
+     * @see org.tensorflow.op.DataOps.batchDataset
+     * @param parallelCopy Sets the parallelCopy option.
+     *
+     * @param parallelCopy the parallelCopy option
+     * @return this Options instance.
+     * @param metadata Sets the metadata option.
+     *
+     * @param metadata the metadata option
+     * @return this Options instance.
+     */
+    public fun batchDataset(
+        inputDataset: Operand<out TType>,
+        batchSize: Operand<TInt64>,
+        dropRemainder: Operand<TBool>,
+        outputTypes: List<Class<out TType>>,
+        outputShapes: List<Shape>,
+        parallelCopy: Boolean? = null,
+        metadata: String? = null
+    ): BatchDataset = java.batchDataset(
+        inputDataset,
+        batchSize,
+        dropRemainder,
+        outputTypes,
+        outputShapes,
+        *listOfNotNull(
+            parallelCopy?.let { org.tensorflow.op.data.BatchDataset.parallelCopy(it) },
+            metadata?.let { org.tensorflow.op.data.BatchDataset.metadata(it) }
+        ).toTypedArray()
+    )
+
+    /**
+     * Records the bytes size of each element of `input_dataset` in a StatsAggregator.
+     *
+     * @param inputDataset The inputDataset value
+     * @param tag The tag value
+     * @param outputTypes The value of the outputTypes attribute
+     * @param outputShapes The value of the outputShapes attribute
+     * @return a new instance of BytesProducedStatsDataset
+     * @see org.tensorflow.op.DataOps.bytesProducedStatsDataset
+     */
+    public fun bytesProducedStatsDataset(
+        inputDataset: Operand<out TType>,
+        tag: Operand<TString>,
+        outputTypes: List<Class<out TType>>,
+        outputShapes: List<Shape>
+    ): BytesProducedStatsDataset = java.bytesProducedStatsDataset(
+        inputDataset,
+        tag,
+        outputTypes,
+        outputShapes
+    )
+
+    /**
+     * The CSVDatasetV2 operation
+     *
+     * @param filenames The filenames value
+     * @param compressionType The compressionType value
+     * @param bufferSize The bufferSize value
+     * @param header The header value
+     * @param fieldDelim The fieldDelim value
+     * @param useQuoteDelim The useQuoteDelim value
+     * @param naValue The naValue value
+     * @param selectCols The selectCols value
+     * @param recordDefaults The recordDefaults value
+     * @param excludeCols The excludeCols value
+     * @param outputShapes The value of the outputShapes attribute
+     * @return a new instance of CSVDataset
+     * @see org.tensorflow.op.DataOps.cSVDataset
+     */
+    public fun cSVDataset(
+        filenames: Operand<TString>,
+        compressionType: Operand<TString>,
+        bufferSize: Operand<TInt64>,
+        header: Operand<TBool>,
+        fieldDelim: Operand<TString>,
+        useQuoteDelim: Operand<TBool>,
+        naValue: Operand<TString>,
+        selectCols: Operand<TInt64>,
+        recordDefaults: Iterable<Operand<*>>,
+        excludeCols: Operand<TInt64>,
+        outputShapes: List<Shape>
+    ): CSVDataset = java.cSVDataset(
+        filenames,
+        compressionType,
+        bufferSize,
+        header,
+        fieldDelim,
+        useQuoteDelim,
+        naValue,
+        selectCols,
+        recordDefaults,
+        excludeCols,
+        outputShapes
+    )
+
+    /**
+     * The CacheDatasetV2 operation
+     *
+     * @param inputDataset The inputDataset value
+     * @param filename The filename value
+     * @param cache The cache value
+     * @param outputTypes The value of the outputTypes attribute
+     * @param outputShapes The value of the outputShapes attribute
+     * @param options carries optional attribute values
+     * @return a new instance of CacheDataset
+     * @see org.tensorflow.op.DataOps.cacheDataset
+     * @param metadata Sets the metadata option.
+     *
+     * @param metadata the metadata option
+     * @return this Options instance.
+     */
+    public fun cacheDataset(
+        inputDataset: Operand<out TType>,
+        filename: Operand<TString>,
+        cache: Operand<out TType>,
+        outputTypes: List<Class<out TType>>,
+        outputShapes: List<Shape>,
+        metadata: String? = null
+    ): CacheDataset = java.cacheDataset(
+        inputDataset,
+        filename,
+        cache,
+        outputTypes,
+        outputShapes,
+        *listOfNotNull(
+            metadata?.let { org.tensorflow.op.data.CacheDataset.metadata(it) }
+        ).toTypedArray()
+    )
+
+    /**
+     * The ChooseFastestBranchDataset operation
+     *
+     * @param inputDataset The inputDataset value
+     * @param ratioNumerator The ratioNumerator value
+     * @param ratioDenominator The ratioDenominator value
+     * @param otherArguments The otherArguments value
+     * @param numElementsPerBranch The value of the numElementsPerBranch attribute
+     * @param branches The value of the branches attribute
+     * @param otherArgumentsLengths The value of the otherArgumentsLengths attribute
+     * @param outputTypes The value of the outputTypes attribute
+     * @param outputShapes The value of the outputShapes attribute
+     * @return a new instance of ChooseFastestBranchDataset
+     * @see org.tensorflow.op.DataOps.chooseFastestBranchDataset
+     */
+    public fun chooseFastestBranchDataset(
+        inputDataset: Operand<out TType>,
+        ratioNumerator: Operand<TInt64>,
+        ratioDenominator: Operand<TInt64>,
+        otherArguments: Iterable<Operand<*>>,
+        numElementsPerBranch: Long,
+        branches: List<ConcreteFunction>,
+        otherArgumentsLengths: List<Long>,
+        outputTypes: List<Class<out TType>>,
+        outputShapes: List<Shape>
+    ): ChooseFastestBranchDataset = java.chooseFastestBranchDataset(
+        inputDataset,
+        ratioNumerator,
+        ratioDenominator,
+        otherArguments,
+        numElementsPerBranch,
+        branches,
+        otherArgumentsLengths,
+        outputTypes,
+        outputShapes
+    )
+
+    /**
+     * The ChooseFastestDataset operation
+     *
+     * @param inputDatasets The inputDatasets value
+     * @param numExperiments The value of the numExperiments attribute
+     * @param outputTypes The value of the outputTypes attribute
+     * @param outputShapes The value of the outputShapes attribute
+     * @return a new instance of ChooseFastestDataset
+     * @see org.tensorflow.op.DataOps.chooseFastestDataset
+     */
+    public fun chooseFastestDataset(
+        inputDatasets: Iterable<Operand<out TType>>,
+        numExperiments: Long,
+        outputTypes: List<Class<out TType>>,
+        outputShapes: List<Shape>
+    ): ChooseFastestDataset = java.chooseFastestDataset(
+        inputDatasets,
+        numExperiments,
+        outputTypes,
+        outputShapes
+    )
+
+    /**
+     * Creates a dataset that concatenates `input_dataset` with `another_dataset`.
+     *
+     * @param inputDataset The inputDataset value
+     * @param anotherDataset The anotherDataset value
+     * @param outputTypes The value of the outputTypes attribute
+     * @param outputShapes The value of the outputShapes attribute
+     * @param options carries optional attribute values
+     * @return a new instance of ConcatenateDataset
+     * @see org.tensorflow.op.DataOps.concatenateDataset
+     * @param metadata Sets the metadata option.
+     *
+     * @param metadata the metadata option
+     * @return this Options instance.
+     */
+    public fun concatenateDataset(
+        inputDataset: Operand<out TType>,
+        anotherDataset: Operand<out TType>,
+        outputTypes: List<Class<out TType>>,
+        outputShapes: List<Shape>,
+        metadata: String? = null
+    ): ConcatenateDataset = java.concatenateDataset(
+        inputDataset,
+        anotherDataset,
+        outputTypes,
+        outputShapes,
+        *listOfNotNull(
+            metadata?.let { org.tensorflow.op.data.ConcatenateDataset.metadata(it) }
+        ).toTypedArray()
+    )
+
+    /**
+     * Creates a dataset that reads data from the tf.data service.
+     *
+     * @param datasetId The datasetId value
+     * @param processingMode The processingMode value
+     * @param address The address value
+     * @param protocol The protocol value
+     * @param jobName The jobName value
+     * @param consumerIndex The consumerIndex value
+     * @param numConsumers The numConsumers value
+     * @param maxOutstandingRequests The maxOutstandingRequests value
+     * @param iterationCounter The iterationCounter value
+     * @param outputTypes The value of the outputTypes attribute
+     * @param outputShapes The value of the outputShapes attribute
+     * @param options carries optional attribute values
+     * @return a new instance of DataServiceDatasetV2
+     * @see org.tensorflow.op.DataOps.dataServiceDatasetV2
+     * @param taskRefreshIntervalHintMs Sets the taskRefreshIntervalHintMs option.
+     *
+     * @param taskRefreshIntervalHintMs the taskRefreshIntervalHintMs option
+     * @return this Options instance.
+     * @param dataTransferProtocol Sets the dataTransferProtocol option.
+     *
+     * @param dataTransferProtocol the dataTransferProtocol option
+     * @return this Options instance.
+     * @param targetWorkers Sets the targetWorkers option.
+     *
+     * @param targetWorkers the targetWorkers option
+     * @return this Options instance.
+     */
+    public fun dataServiceDatasetV2(
+        datasetId: Operand<TInt64>,
+        processingMode: Operand<TString>,
+        address: Operand<TString>,
+        protocol: Operand<TString>,
+        jobName: Operand<TString>,
+        consumerIndex: Operand<TInt64>,
+        numConsumers: Operand<TInt64>,
+        maxOutstandingRequests: Operand<TInt64>,
+        iterationCounter: Operand<out TType>,
+        outputTypes: List<Class<out TType>>,
+        outputShapes: List<Shape>,
+        taskRefreshIntervalHintMs: Long? = null,
+        dataTransferProtocol: String? = null,
+        targetWorkers: String? = null
+    ): DataServiceDatasetV2 = java.dataServiceDatasetV2(
+        datasetId,
+        processingMode,
+        address,
+        protocol,
+        jobName,
+        consumerIndex,
+        numConsumers,
+        maxOutstandingRequests,
+        iterationCounter,
+        outputTypes,
+        outputShapes,
+        *listOfNotNull(
+            taskRefreshIntervalHintMs?.let {
+                org.tensorflow.op.data.DataServiceDatasetV2.taskRefreshIntervalHintMs(it)
+            },
+            dataTransferProtocol?.let {
+                org.tensorflow.op.data.DataServiceDatasetV2.dataTransferProtocol(it)
+            },
+            targetWorkers?.let { org.tensorflow.op.data.DataServiceDatasetV2.targetWorkers(it) }
+        ).toTypedArray()
+    )
+
+    /**
+     * Returns the cardinality of `input_dataset`.
+     * Returns the cardinality of `input_dataset`.
+     *
+     * @param inputDataset A variant tensor representing the dataset to return cardinality for.
+     * @return a new instance of DatasetCardinality
+     * @see org.tensorflow.op.DataOps.datasetCardinality
+     */
+    public fun datasetCardinality(inputDataset: Operand<out TType>): DatasetCardinality =
+        java.datasetCardinality(
+            inputDataset
+        )
+
+    /**
+     * Creates a dataset from the given `graph_def`.
+     * Creates a dataset from the provided `graph_def`.
+     *
+     * @param graphDef The graph representation of the dataset (as serialized GraphDef).
+     * @return a new instance of DatasetFromGraph
+     * @see org.tensorflow.op.DataOps.datasetFromGraph
+     */
+    public fun datasetFromGraph(graphDef: Operand<TString>): DatasetFromGraph =
+        java.datasetFromGraph(
+            graphDef
+        )
+
+    /**
+     * Returns a serialized GraphDef representing `input_dataset`.
+     * Returns a graph representation for `input_dataset`.
+     *
+     * @param inputDataset A variant tensor representing the dataset to return the graph
+     * representation for.
+     * @param options carries optional attribute values
+     * @return a new instance of DatasetToGraph
+     * @see org.tensorflow.op.DataOps.datasetToGraph
+     * @param externalStatePolicy Sets the externalStatePolicy option.
+     *
+     * @param externalStatePolicy the externalStatePolicy option
+     * @return this Options instance.
+     * @param stripDeviceAssignment Sets the stripDeviceAssignment option.
+     *
+     * @param stripDeviceAssignment the stripDeviceAssignment option
+     * @return this Options instance.
+     */
+    public fun datasetToGraph(
+        inputDataset: Operand<out TType>,
+        externalStatePolicy: Long? = null,
+        stripDeviceAssignment: Boolean? = null
+    ): DatasetToGraph = java.datasetToGraph(
+        inputDataset,
+        *listOfNotNull(
+            externalStatePolicy?.let { org.tensorflow.op.data.DatasetToGraph.externalStatePolicy(it) },
+            stripDeviceAssignment?.let { org.tensorflow.op.data.DatasetToGraph.stripDeviceAssignment(it) }
+        ).toTypedArray()
+    )
+
+ /**
+ * Outputs the single element from the given dataset.
+ *
+ * @param dataset A handle to a dataset that contains a single element.
+ * @param outputTypes The value of the outputTypes attribute
+ * @param outputShapes The value of the outputShapes attribute
+ * @param metadata Sets the metadata option.
+ * @return a new instance of DatasetToSingleElement
+ * @see org.tensorflow.op.DataOps.datasetToSingleElement
+ */
+ public fun datasetToSingleElement(
+ dataset: Operand,
+ outputTypes: List>,
+ outputShapes: List,
+ metadata: String? = null
+ ): DatasetToSingleElement = java.datasetToSingleElement(
+ dataset,
+ outputTypes,
+ outputShapes,
+ *listOfNotNull(
+ metadata?.let{ org.tensorflow.op.data.DatasetToSingleElement.metadata(it) }
+ ).toTypedArray()
+ )
+
+ /**
+ * Writes the given dataset to the given file using the TFRecord format.
+ *
+ * @param inputDataset A variant tensor representing the dataset to write.
+ * @param filename A scalar string tensor representing the filename to use.
+ * @param compressionType A scalar string tensor containing either (i) the empty string (no
+ * compression), (ii) `"ZLIB"`, or (iii) `"GZIP"`.
+ * @return a new instance of DatasetToTfRecord
+ * @see org.tensorflow.op.DataOps.datasetToTfRecord
+ */
+ public fun datasetToTfRecord(
+ inputDataset: Operand,
+ filename: Operand,
+ compressionType: Operand
+ ): DatasetToTfRecord = java.datasetToTfRecord(
+ inputDataset,
+ filename,
+ compressionType
+ )
+
+ /**
+ * A container for an iterator resource. Deletes the iterator resource
+ * referenced by `handle`.
+ *
+ * @param handle A handle to the iterator to delete.
+ * @param deleter A variant deleter.
+ * @return a new instance of DeleteIterator
+ * @see org.tensorflow.op.DataOps.deleteIterator
+ */
+ public fun deleteIterator(handle: Operand, deleter: Operand):
+ DeleteIterator = java.deleteIterator(
+ handle,
+ deleter
+ )
+
+ /**
+ * Creates a dataset that batches input elements into a SparseTensor.
+ *
+ * @param inputDataset A handle to an input dataset. Must have a single component.
+ * @param batchSize A scalar representing the number of elements to accumulate in a batch.
+ * @param rowShape A vector representing the dense shape of each row in the produced
+ * SparseTensor. The shape may be partially specified, using `-1` to indicate
+ * that a particular dimension should use the maximum size of all batch elements.
+ * @param outputTypes The value of the outputTypes attribute
+ * @param outputShapes The value of the outputShapes attribute
+ * @return a new instance of DenseToSparseBatchDataset
+ * @see org.tensorflow.op.DataOps.denseToSparseBatchDataset
+ */
+ public fun denseToSparseBatchDataset(
+ inputDataset: Operand,
+ batchSize: Operand,
+ rowShape: Operand,
+ outputTypes: List>,
+ outputShapes: List
+ ): DenseToSparseBatchDataset = java.denseToSparseBatchDataset(
+ inputDataset,
+ batchSize,
+ rowShape,
+ outputTypes,
+ outputShapes
+ )
+
+ /**
+ * Converts the given variant tensor to an iterator and stores it in the given resource.
+ *
+ * @param resourceHandle A handle to an iterator resource.
+ * @param serialized A variant tensor storing the state of the iterator contained in the
+ * resource.
+ * @return a new instance of DeserializeIterator
+ * @see org.tensorflow.op.DataOps.deserializeIterator
+ */
+ public fun deserializeIterator(resourceHandle: Operand, serialized: Operand): DeserializeIterator = java.deserializeIterator(
+ resourceHandle,
+ serialized
+ )
+
+ /**
+ * A substitute for `InterleaveDataset` on a fixed list of `N` datasets.
+ *
+ * @param selectorInputDataset A dataset of scalar `DT_INT64` elements that determines which of
+ * the `N` data inputs should produce the next output element.
+ * @param dataInputDatasets `N` datasets with the same type that will be interleaved according
+ * to the values of `selector_input_dataset`.
+ * @param outputTypes The value of the outputTypes attribute
+ * @param outputShapes The value of the outputShapes attribute
+ * @param stopOnEmptyDataset Sets the stopOnEmptyDataset option.
+ * @return a new instance of DirectedInterleaveDataset
+ * @see org.tensorflow.op.DataOps.directedInterleaveDataset
+ */
+ public fun directedInterleaveDataset(
+ selectorInputDataset: Operand,
+ dataInputDatasets: Iterable>,
+ outputTypes: List>,
+ outputShapes: List,
+ stopOnEmptyDataset: Boolean? = null
+ ): DirectedInterleaveDataset = java.directedInterleaveDataset(
+ selectorInputDataset,
+ dataInputDatasets,
+ outputTypes,
+ outputShapes,
+ *listOfNotNull(
+ stopOnEmptyDataset?.let{
+ org.tensorflow.op.data.DirectedInterleaveDataset.stopOnEmptyDataset(it) }
+ ).toTypedArray()
+ )
+
+ /**
+ * Creates a dataset containing elements of the first component of `input_dataset` that have
+ * true in the last component.
+ *
+ * @param inputDataset The inputDataset value
+ * @param outputTypes The value of the outputTypes attribute
+ * @param outputShapes The value of the outputShapes attribute
+ * @return a new instance of FilterByLastComponentDataset
+ * @see org.tensorflow.op.DataOps.filterByLastComponentDataset
+ */
+ public fun filterByLastComponentDataset(
+ inputDataset: Operand,
+ outputTypes: List>,
+ outputShapes: List
+ ): FilterByLastComponentDataset = java.filterByLastComponentDataset(
+ inputDataset,
+ outputTypes,
+ outputShapes
+ )
+
+ /**
+ * Creates a dataset containing elements of `input_dataset` matching `predicate`.
+ * The `predicate` function must return a scalar boolean and accept the
+ * following arguments:
+ * <ul>
+ * <li>One tensor for each component of an element of `input_dataset`.</li>
+ * <li>One tensor for each value in `other_arguments`.</li>
+ * </ul>
+ *
+ * @param inputDataset The inputDataset value
+ * @param otherArguments A list of tensors, typically values that were captured when
+ * building a closure for `predicate`.
+ * @param predicate A function returning a scalar boolean.
+ * @param outputTypes The value of the outputTypes attribute
+ * @param outputShapes The value of the outputShapes attribute
+ * @param metadata Sets the metadata option.
+ * @return a new instance of FilterDataset
+ * @see org.tensorflow.op.DataOps.filterDataset
+ */
+ public fun filterDataset(
+ inputDataset: Operand,
+ otherArguments: Iterable>,
+ predicate: ConcreteFunction,
+ outputTypes: List>,
+ outputShapes: List,
+ metadata: String? = null
+ ): FilterDataset = java.filterDataset(
+ inputDataset,
+ otherArguments,
+ predicate,
+ outputTypes,
+ outputShapes,
+ *listOfNotNull(
+ metadata?.let{ org.tensorflow.op.data.FilterDataset.metadata(it) }
+ ).toTypedArray()
+ )
+
+ /**
+ * Creates a dataset by applying `tf.data.Options` to `input_dataset`.
+ *
+ * @param inputDataset A variant tensor representing the input dataset.
+ * @param outputTypes The value of the outputTypes attribute
+ * @param outputShapes The value of the outputShapes attribute
+ * @param hasCapturedRef Sets the hasCapturedRef option.
+ * @return a new instance of FinalizeDataset
+ * @see org.tensorflow.op.DataOps.finalizeDataset
+ */
+ public fun finalizeDataset(
+ inputDataset: Operand,
+ outputTypes: List>,
+ outputShapes: List,
+ hasCapturedRef: Boolean? = null
+ ): FinalizeDataset = java.finalizeDataset(
+ inputDataset,
+ outputTypes,
+ outputShapes,
+ *listOfNotNull(
+ hasCapturedRef?.let{ org.tensorflow.op.data.FinalizeDataset.hasCapturedRef(it) }
+ ).toTypedArray()
+ )
+
+ /**
+ * The FixedLengthRecordDatasetV2 operation
+ *
+ * @param filenames The filenames value
+ * @param headerBytes The headerBytes value
+ * @param recordBytes The recordBytes value
+ * @param footerBytes The footerBytes value
+ * @param bufferSize The bufferSize value
+ * @param compressionType The compressionType value
+ * @param metadata Sets the metadata option.
+ * @return a new instance of FixedLengthRecordDataset
+ * @see org.tensorflow.op.DataOps.fixedLengthRecordDataset
+ */
+ public fun fixedLengthRecordDataset(
+ filenames: Operand,
+ headerBytes: Operand,
+ recordBytes: Operand,
+ footerBytes: Operand,
+ bufferSize: Operand,
+ compressionType: Operand,
+ metadata: String? = null
+ ): FixedLengthRecordDataset = java.fixedLengthRecordDataset(
+ filenames,
+ headerBytes,
+ recordBytes,
+ footerBytes,
+ bufferSize,
+ compressionType,
+ *listOfNotNull(
+ metadata?.let{ org.tensorflow.op.data.FixedLengthRecordDataset.metadata(it) }
+ ).toTypedArray()
+ )
+
+ /**
+ * Creates a dataset that applies `f` to the outputs of `input_dataset`.
+ * Unlike MapDataset, the `f` in FlatMapDataset is expected to return a
+ * Dataset variant, and FlatMapDataset will flatten successive results
+ * into a single Dataset.
+ *
+ * @param inputDataset The inputDataset value
+ * @param otherArguments The otherArguments value
+ * @param f A function mapping elements of `input_dataset`, concatenated with
+ * `other_arguments`, to a Dataset variant that contains elements matching
+ * `output_types` and `output_shapes`.
+ * @param outputTypes The value of the outputTypes attribute
+ * @param outputShapes The value of the outputShapes attribute
+ * @param metadata Sets the metadata option.
+ * @return a new instance of FlatMapDataset
+ * @see org.tensorflow.op.DataOps.flatMapDataset
+ */
+ public fun flatMapDataset(
+ inputDataset: Operand,
+ otherArguments: Iterable>,
+ f: ConcreteFunction,
+ outputTypes: List>,
+ outputShapes: List,
+ metadata: String? = null
+ ): FlatMapDataset = java.flatMapDataset(
+ inputDataset,
+ otherArguments,
+ f,
+ outputTypes,
+ outputShapes,
+ *listOfNotNull(
+ metadata?.let{ org.tensorflow.op.data.FlatMapDataset.metadata(it) }
+ ).toTypedArray()
+ )
+
+ /**
+ * Creates a dataset that invokes a function to generate elements.
+ *
+ * @param initFuncOtherArgs The initFuncOtherArgs value
+ * @param nextFuncOtherArgs The nextFuncOtherArgs value
+ * @param finalizeFuncOtherArgs The finalizeFuncOtherArgs value
+ * @param initFunc The value of the initFunc attribute
+ * @param nextFunc The value of the nextFunc attribute
+ * @param finalizeFunc The value of the finalizeFunc attribute
+ * @param outputTypes The value of the outputTypes attribute
+ * @param outputShapes The value of the outputShapes attribute
+ * @param metadata Sets the metadata option.
+ * @return a new instance of GeneratorDataset
+ * @see org.tensorflow.op.DataOps.generatorDataset
+ */
+ public fun generatorDataset(
+ initFuncOtherArgs: Iterable>,
+ nextFuncOtherArgs: Iterable>,
+ finalizeFuncOtherArgs: Iterable>,
+ initFunc: ConcreteFunction,
+ nextFunc: ConcreteFunction,
+ finalizeFunc: ConcreteFunction,
+ outputTypes: List>,
+ outputShapes: List,
+ metadata: String? = null
+ ): GeneratorDataset = java.generatorDataset(
+ initFuncOtherArgs,
+ nextFuncOtherArgs,
+ finalizeFuncOtherArgs,
+ initFunc,
+ nextFunc,
+ finalizeFunc,
+ outputTypes,
+ outputShapes,
+ *listOfNotNull(
+ metadata?.let{ org.tensorflow.op.data.GeneratorDataset.metadata(it) }
+ ).toTypedArray()
+ )
+
+ /**
+ * Creates a dataset that computes a group-by on `input_dataset`.
+ *
+ * @param inputDataset A variant tensor representing the input dataset.
+ * @param keyFuncOtherArguments A list of tensors, typically values that were captured when
+ * building a closure for `key_func`.
+ * @param initFuncOtherArguments A list of tensors, typically values that were captured when
+ * building a closure for `init_func`.
+ * @param reduceFuncOtherArguments A list of tensors, typically values that were captured when
+ * building a closure for `reduce_func`.
+ * @param finalizeFuncOtherArguments A list of tensors, typically values that were captured when
+ * building a closure for `finalize_func`.
+ * @param keyFunc A function mapping an element of `input_dataset`, concatenated
+ * with `key_func_other_arguments` to a scalar value of type DT_INT64.
+ * @param initFunc A function mapping a key of type DT_INT64, concatenated with
+ * `init_func_other_arguments` to the initial reducer state.
+ * @param reduceFunc A function mapping the current reducer state and an element of
+ * `input_dataset`,
+ * concatenated with `reduce_func_other_arguments` to a new reducer state.
+ * @param finalizeFunc A function mapping the final reducer state to an output element.
+ * @param outputTypes The value of the outputTypes attribute
+ * @param outputShapes The value of the outputShapes attribute
+ * @return a new instance of GroupByReducerDataset
+ * @see org.tensorflow.op.DataOps.groupByReducerDataset
+ */
+ public fun groupByReducerDataset(
+ inputDataset: Operand,
+ keyFuncOtherArguments: Iterable>,
+ initFuncOtherArguments: Iterable