Skip to content

Kotlin API #165

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 61 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
61 commits
Select commit Hold shift + click to select a range
82913de
Merge in device spec changes
rnett Dec 1, 2020
25da497
Initial kotlin generation (still using Options), some helpers
rnett Dec 2, 2020
4c2ef19
Update stdlib, more helpers
rnett Dec 2, 2020
2638f47
default parameters
rnett Dec 2, 2020
14d301c
Add and use ktlint
rnett Dec 3, 2020
02494e5
use spaces
rnett Dec 3, 2020
1ae143c
disable filename rule
rnett Dec 3, 2020
07fcb24
change disable to single file
rnett Dec 3, 2020
e3871e5
Helper methods for withDevice, a combined with method, and tf(DeviceS…
rnett Dec 4, 2020
881237b
Add license
rnett Dec 6, 2020
aab7311
WIP Session/Runner API. Java API needs updates
rnett Dec 7, 2020
9334c95
Javadoc generation
rnett Dec 7, 2020
7babe57
Example, clean up session api
rnett Dec 7, 2020
cbd611d
make the test public
rnett Dec 7, 2020
c42a092
add full test dependencies
rnett Dec 7, 2020
5b5a1bb
Add ones op
rnett Dec 7, 2020
535793b
requireShape helper methods
rnett Dec 10, 2020
12c6238
fix lint
rnett Dec 10, 2020
1686cec
Rename withSession to useSession to reflect closing semantics
rnett Dec 11, 2020
7de79d3
Target JVM 1.8 for Kotlin
rnett Dec 11, 2020
4d311ed
Update Kotlin version
rnett Dec 27, 2020
ea9e171
Codegen for reified type parameters
rnett Dec 28, 2020
e948269
disable auto-format for now (ktlint bug)
rnett Dec 28, 2020
3777965
Data type helpers
rnett Dec 30, 2020
8629630
Cleanup poms
rnett Jan 14, 2021
137277c
Concrete function helpers, redo codegen
rnett Jan 14, 2021
234d3ef
formatting
rnett Jan 14, 2021
26b4a69
Shape property
rnett Jan 24, 2021
9ee2f48
New codegen, support for Java 11 builds
rnett Jan 26, 2021
63e8e25
Start of extension helpers
rnett Feb 5, 2021
8271b70
Update to new master
rnett Feb 11, 2021
3c5467a
Rebase updates
rnett Mar 19, 2021
cca2857
Add tests
rnett Mar 20, 2021
a5a5f46
Add section in CONTRIBUTING
rnett Mar 20, 2021
01257f8
Add readme w/ link to contributing instructions
rnett Mar 20, 2021
d0a0183
Add section to readme
rnett Mar 20, 2021
fee69b2
Restructure Kotlin projects
rnett Apr 23, 2021
1677979
Fix name
rnett Apr 23, 2021
9b87d40
Rename Shape.size(int) to get, add toListOrNull
rnett Apr 23, 2021
f8c236b
Rebase
rnett Apr 23, 2021
4fed8ac
Update javadoc generation
rnett Apr 23, 2021
5113983
Update to Kotlin 1.5.0
rnett May 15, 2021
b7c6660
Rebase fixes
rnett Jun 20, 2021
353d137
Update formatting to use ktfmt and spotless
rnett Jun 20, 2021
c6eef4f
No formatting on generation, update explicit API settings
rnett Jun 20, 2021
103d70e
Initial framework wrappers
rnett Jun 20, 2021
3152bf0
Add WithOps, use for KotlinOps
rnett Jun 20, 2021
775860e
Better shape assertions
rnett Jun 20, 2021
066b27f
Formatting, Jupyter integration
rnett Jun 21, 2021
45ba2dd
Fix formatting
rnett Jun 21, 2021
419afc3
Fix test
rnett Jun 21, 2021
64a8dda
Fix formatting
rnett Jun 21, 2021
ef033fb
Fix generation regression
rnett Jun 21, 2021
f017a7c
Fix format
rnett Jun 21, 2021
482b43c
Don't load extra snapshot repo. If we're on snapshots we'll have had…
rnett Jul 1, 2021
b7a85ef
Use extension instead of platform
rnett Jul 1, 2021
693e108
Update to Kotlin 1.6
rnett Dec 22, 2021
42c61b1
Fix formatting
rnett Dec 22, 2021
12f56df
Update generation
rnett Dec 22, 2021
27fcac5
Update version
rnett Dec 22, 2021
43cd7e3
Update generation
rnett Dec 22, 2021
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
22 changes: 18 additions & 4 deletions CONTRIBUTING.md
Original file line number Diff line number Diff line change
Expand Up @@ -170,10 +170,11 @@ For dependencies, we can use anything compliant with [this list](https://opensou

### Code generation

Code generation for `Ops` and related classes is done during `tensorflow-core-api`'s `compile` phase, using the annotation processor in
`tensorflow-core-generator`. If you change or add any operator classes (annotated with `org.tensorflow.op.annotation.Operator`), endpoint methods (
annotated with `org.tensorflow.op.annotation.Endpoint`), or change the annotation processor, be sure to re-run a
`mvn install` in `tensorflow-core-api` (`-Pdev` is fine for this, it just needs to run the annotation processor).
Code generation for `Ops` and related classes is done during `tensorflow-core-api` and `tensorflow-core-kotlin`'s `compile` phase,
using the annotation processors in `tensorflow-core-generator` and `tensorflow-kotlin-generator`, respectively. If you change or add any
operator classes (annotated with `org.tensorflow.op.annotation.Operator`), endpoint methods (annotated with `org.tensorflow.op.annotation.Endpoint`),
or change the annotation processor, be sure to re-run a `mvn compile` in `tensorflow-core-api` **and** `tensorflow-core-kotlin`
(`-Pdev` is fine for this, it just needs to run the annotation processor).

### Working with Bazel generation

Expand All @@ -189,6 +190,19 @@ bazel-out/k8-opt/bin/external/org_tensorflow/tensorflow/libtensorflow_cc.so --ou

(called in `tensorflow-core-api`).

### Kotlin API

The Kotlin api should be kept to a thin wrapper of the Java API, using extension functions and codegen wherever possible.
We do not want to get into a situation where we are maintaining two separate but related APIs.

The codegen (`tensorflow-core-kotlin-generator`) is an annotation processor that reads the `@Operator` classes from the `tensorflow-core-api` Java sources.
If you add operators or re-generate them from the native library, be sure to re-run a `mvn install` in `tensorflow-core-kotlin-api`.

#### Formatting

[ktfmt](https://github.com/facebookincubator/ktfmt) is used to format the Kotlin files. This is
checked and done via maven in the same way as Java formatting. To do the formatting via IntelliJ see
ktfmt's repo.

## Adding Gradients

Expand Down
13 changes: 12 additions & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -21,8 +21,13 @@ The following describes the layout of the repository and its different artifacts
* `tensorflow-core`
* All artifacts that build up the core language bindings of TensorFlow for Java
* Intended audience: projects that provide their own APIs or frameworks on top of
TensorFlow and just want a thin layer to access the TensorFlow runtime from the JVM
TensorFlow and just want a thin layer to access the TensorFlow runtime from the JVM

* `tensorflow-core-kotlin`
* Kotlin API bindings for `tensorflow-core`. These are thin wrappers around the core APIs
to make them more idiomatic for use in Kotlin, such as using parameters with default values
operation builders instead of an `Options` vararg.

* `tensorflow-framework`
* Primary API for building and training neural networks with TensorFlow
* Intended audience: neural network developers
Expand Down Expand Up @@ -112,6 +117,12 @@ the platforms you are targeting. For this purpose the `-platform` artifacts incl
the conventions established on this page:
* [Reducing the Number of Dependencies](https://github.com/bytedeco/javacpp-presets/wiki/Reducing-the-Number-of-Dependencies)

### Kotlin API

Since the Kotlin API is just a wrapper of the Java API, it uses the Java platform artifacts instead of providing its own.
To use, follow the instructions above for the Java API, but add `tensorflow-core-kotlin-api`,
replacing `tensorflow-core-api` if you have explicitly included it.

### Snapshots

Snapshots of TensorFlow Java artifacts are automatically distributed after each update in the code. To use them, you need
Expand Down
1 change: 1 addition & 0 deletions pom.xml
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@

<modules>
<module>tensorflow-core</module>
<module>tensorflow-kotlin-parent</module>
<module>tensorflow-framework</module>
</modules>

Expand Down
2 changes: 1 addition & 1 deletion tensorflow-core/tensorflow-core-api/pom.xml
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@
<javacpp.parser.skip>${native.build.skip}</javacpp.parser.skip>
<javacpp.compiler.skip>${native.build.skip}</javacpp.compiler.skip>
<java.module.name>org.tensorflow.core.api</java.module.name>
<ndarray.version>0.3.3</ndarray.version>
<ndarray.version>0.4.0-SNAPSHOT</ndarray.version>
<truth.version>1.0.1</truth.version>
</properties>

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -340,7 +340,7 @@
* }
* }</pre>
*/
public final class Ops {
public final class Ops implements WithOps {
public final NnOps nn;

public final SummaryOps summary;
Expand Down Expand Up @@ -371,10 +371,10 @@ public final class Ops {

public final TpuOps tpu;

public final AudioOps audio;

public final MathOps math;

public final AudioOps audio;

public final SignalOps signal;

public final TrainOps train;
Expand All @@ -400,8 +400,8 @@ public final class Ops {
sparse = new SparseOps(this);
bitwise = new BitwiseOps(this);
tpu = new TpuOps(this);
audio = new AudioOps(this);
math = new MathOps(this);
audio = new AudioOps(this);
signal = new SignalOps(this);
train = new TrainOps(this);
quantization = new QuantizationOps(this);
Expand Down Expand Up @@ -8068,11 +8068,15 @@ public <T extends TType> ZerosLike<T> zerosLike(Operand<T> x) {
return ZerosLike.create(scope, x);
}

@Override
public Ops tf() {
return this;
}

/**
* Returns an API that builds operations with the provided name prefix.
*
* @see {@link Scope#withSubScope(String)}
* {@inheritDoc}
*/
@Override
public Ops withSubScope(String childScopeName) {
return new Ops(scope.withSubScope(childScopeName));
}
Expand Down Expand Up @@ -8109,28 +8113,25 @@ public <T extends Operand> T liftToInitScope(T op) {
}

/**
* Returns an API that uses the provided name for an op.
*
* @see {@link Scope#withName(String)}
* {@inheritDoc}
*/
@Override
public Ops withName(String opName) {
return new Ops(scope.withName(opName));
}

/**
* Returns an API that places the created operations on the device(s) matching the provided spec.
*
* @see {@link Scope#withDevice(DeviceSpec)}
* {@inheritDoc}
*/
@Override
public Ops withDevice(DeviceSpec deviceSpec) {
return new Ops(scope.withDevice(deviceSpec));
}

/**
* Returns an API that adds operations to the graph with the provided control dependencies.
*
* @see {@link Scope#withControlDependencies(Iterable<Op<?>>)}
* {@inheritDoc}
*/
@Override
public Ops withControlDependencies(Iterable<Op> controls) {
return new Ops(scope.withControlDependencies(controls));
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,10 +16,12 @@
package org.tensorflow;

import org.tensorflow.op.Op;
import org.tensorflow.op.Ops;
import org.tensorflow.op.Scope;
import org.tensorflow.op.WithOps;

/** Defines an environment for creating and executing TensorFlow {@link Operation}s. */
public interface ExecutionEnvironment {
public interface ExecutionEnvironment extends WithOps {

enum Types {
GRAPH,
Expand Down Expand Up @@ -126,4 +128,9 @@ default ExecutionEnvironment initEnv() {
* <p><b>Should generally only be used internally.</b>
*/
boolean isInitOp(Operation op);

@Override
default Ops tf() {
return Ops.create(this);
}
}
Original file line number Diff line number Diff line change
@@ -1,18 +1,18 @@
/* Copyright 2020-2021 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
=======================================================================
*/
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
=======================================================================
*/
package org.tensorflow;

import java.util.Collections;
Expand Down Expand Up @@ -161,7 +161,7 @@ private static TensorInfo toTensorInfo(Output<?> operand) {
Shape shape = operand.shape();
TensorShapeProto.Builder tensorShapeBuilder = TensorShapeProto.newBuilder();
for (int i = 0; i < shape.numDimensions(); ++i) {
tensorShapeBuilder.addDim(Dim.newBuilder().setSize(shape.size(i)));
tensorShapeBuilder.addDim(Dim.newBuilder().setSize(shape.get(i)));
}
return TensorInfo.newBuilder()
.setDtype(operand.dataType())
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,73 @@
/*
Copyright 2021 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
=======================================================================

*/
package org.tensorflow.op;

import java.util.Arrays;
import org.tensorflow.DeviceSpec;

/** A context that provides a TensorFlow op builder. */
public interface WithOps {

/** Get the op builder for this context. */
Ops tf();

/**
* Returns an API that builds operations with the provided name prefix.
*
* @see Scope#withSubScope(String)
*/
default WithOps withSubScope(String childScopeName) {
return tf().withSubScope(childScopeName);
}

/**
* Returns an API that uses the provided name for an op.
*
* @see Scope#withName(String)
*/
default WithOps withName(String opName) {
return tf().withName(opName);
}

/**
* Returns an API that places the created operations on the device(s) matching the provided spec.
*
* @see Scope#withDevice(DeviceSpec)
*/
default WithOps withDevice(DeviceSpec deviceSpec) {
return tf().withDevice(deviceSpec);
}

/**
* Returns an API that adds operations to the graph with the provided control dependencies.
*
* @see Scope#withControlDependencies(Iterable)
*/
default WithOps withControlDependencies(Iterable<Op> controls) {
return tf().withControlDependencies(controls);
}

/**
* Returns an API that adds operations to the graph with the provided control dependencies.
*
* @see Scope#withControlDependencies(Iterable)
*/
default WithOps withControlDependencies(Op... controls) {
return withControlDependencies(Arrays.asList(controls));
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -55,8 +55,8 @@ public void outputDataTypeAndShape() {
.setAttr("value", t)
.build();
assertEquals(DataType.DT_INT32, op.dtype(0));
assertEquals(2, op.shape(0).size(0));
assertEquals(3, op.shape(0).size(1));
assertEquals(2, op.shape(0).get(0));
assertEquals(3, op.shape(0).get(1));
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -144,8 +144,8 @@ public void setAttrShape() {
.build()
.output(0);
assertEquals(2, n.shape().numDimensions());
assertEquals(-1, n.shape().size(0));
assertEquals(784, n.shape().size(1));
assertEquals(-1, n.shape().get(0));
assertEquals(784, n.shape().get(1));
assertEquals(DataType.DT_FLOAT, n.dataType());
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -200,7 +200,7 @@ public void exportFunctionWithVariables() throws IOException {
assertNotNull(inputInfo);
assertEquals(xyShape.numDimensions(), inputInfo.getTensorShape().getDimCount());
for (int i = 0; i < xyShape.numDimensions(); ++i) {
assertEquals(xyShape.size(i), inputInfo.getTensorShape().getDim(i).getSize());
assertEquals(xyShape.get(i), inputInfo.getTensorShape().getDim(i).getSize());
}

TensorInfo outputInfo = signatureDef.getOutputsMap().get("reducedSum");
Expand Down
Loading