Skip to content

Commit 83b7ba2

Browse files
committed
Restory history in Java/JNI source files
2 parents fa38988 + f7bcc75 commit 83b7ba2

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

60 files changed

+10664
-0
lines changed
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,87 @@
1+
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
2+
3+
Licensed under the Apache License, Version 2.0 (the "License");
4+
you may not use this file except in compliance with the License.
5+
You may obtain a copy of the License at
6+
7+
http://www.apache.org/licenses/LICENSE-2.0
8+
9+
Unless required by applicable law or agreed to in writing, software
10+
distributed under the License is distributed on an "AS IS" BASIS,
11+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
See the License for the specific language governing permissions and
13+
limitations under the License.
14+
==============================================================================*/
15+
16+
package org.tensorflow;
17+
18+
/**
19+
* Base class for {@link Operation} implementations.
20+
*
21+
* <p>As opposed to {@link Operation} itself, this class is package private and therefore its usage
22+
* is limited to internal purposes only.
23+
*/
24+
abstract class AbstractOperation implements Operation {
25+
26+
@Override
27+
public Output<?>[] outputList(int idx, int length) {
28+
Output<?>[] outputs = new Output<?>[length];
29+
for (int i = 0; i < length; ++i) {
30+
outputs[i] = output(idx + i);
31+
}
32+
return outputs;
33+
}
34+
35+
@Override
36+
@SuppressWarnings({"rawtypes", "unchecked"})
37+
public <T> Output<T> output(int idx) {
38+
return new Output(this, idx);
39+
}
40+
41+
@Override
42+
public String toString() {
43+
return String.format("<%s '%s'>", type(), name());
44+
}
45+
46+
/**
47+
* Returns the native handle of the {@code outputIdx}th output of this operation.
48+
*
49+
* <p>The nature of the returned value varies depending on current the execution environment.
50+
*
51+
* <ul>
52+
* <li>In eager mode, the value is a handle to the tensor returned at this output.
53+
* <li>In graph mode, the value is a handle to the operation itself, which should be paired with
54+
* the index of the output when calling the native layer.
55+
* </ul>
56+
*
57+
* @param outputIdx index of the output in this operation
58+
* @return a native handle, see method description for more details
59+
*/
60+
abstract long getUnsafeNativeHandle(int outputIdx);
61+
62+
/**
63+
* Returns the shape of the tensor of the {@code outputIdx}th output of this operation.
64+
*
65+
* @param outputIdx index of the output of this operation
66+
* @return output tensor shape
67+
*/
68+
abstract long[] shape(int outputIdx);
69+
70+
/**
71+
* Returns the datatype of the tensor of the {@code outputIdx}th output of this operation.
72+
*
73+
* @param outputIdx index of the output of this operation
74+
* @return output tensor datatype
75+
*/
76+
abstract DataType dtype(int outputIdx);
77+
78+
/**
79+
* Returns the tensor of the {@code outputIdx}th output of this operation.
80+
*
81+
* <p>This is only supported in an eager execution environment.
82+
*
83+
* @param outputIdx index of the output of this operation
84+
* @return output tensor
85+
*/
86+
abstract Tensor<?> tensor(int outputIdx);
87+
}
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,116 @@
1+
/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
2+
3+
Licensed under the Apache License, Version 2.0 (the "License");
4+
you may not use this file except in compliance with the License.
5+
You may obtain a copy of the License at
6+
7+
http://www.apache.org/licenses/LICENSE-2.0
8+
9+
Unless required by applicable law or agreed to in writing, software
10+
distributed under the License is distributed on an "AS IS" BASIS,
11+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
See the License for the specific language governing permissions and
13+
limitations under the License.
14+
==============================================================================*/
15+
16+
package org.tensorflow;
17+
18+
import java.util.HashMap;
19+
import java.util.Map;
20+
21+
import org.tensorflow.types.UInt8;
22+
23+
/** Represents the type of elements in a {@link Tensor} as an enum. */
24+
public enum DataType {
25+
/** 32-bit single precision floating point. */
26+
FLOAT(1, 4),
27+
28+
/** 64-bit double precision floating point. */
29+
DOUBLE(2, 8),
30+
31+
/** 32-bit signed integer. */
32+
INT32(3, 4),
33+
34+
/** 8-bit unsigned integer. */
35+
UINT8(4, 1),
36+
37+
/**
38+
* A sequence of bytes.
39+
*
40+
* <p>TensorFlow uses the STRING type for an arbitrary sequence of bytes.
41+
*/
42+
STRING(7, -1),
43+
44+
/** 64-bit signed integer. */
45+
INT64(9, 8),
46+
47+
/** Boolean. */
48+
BOOL(10, 1);
49+
50+
private final int value;
51+
52+
private final int byteSize;
53+
54+
/**
55+
* @param value must match the corresponding TF_* value in the TensorFlow C API.
56+
* @param byteSize size of an element of this type, in bytes, -1 if unknown
57+
*/
58+
DataType(int value, int byteSize) {
59+
this.value = value;
60+
this.byteSize = byteSize;
61+
}
62+
63+
/**
64+
* Returns the size of an element of this type, in bytes, or -1 if element size is variable.
65+
*/
66+
public int byteSize() {
67+
return byteSize;
68+
}
69+
70+
/** Corresponding value of the TF_DataType enum in the TensorFlow C API. */
71+
int c() {
72+
return value;
73+
}
74+
75+
// Cached to avoid copying it
76+
private static final DataType[] values = values();
77+
78+
static DataType fromC(int c) {
79+
for (DataType t : values) {
80+
if (t.value == c) {
81+
return t;
82+
}
83+
}
84+
throw new IllegalArgumentException(
85+
"DataType " + c + " is not recognized in Java (version " + TensorFlow.version() + ")");
86+
}
87+
88+
/**
89+
* Returns the DataType of a Tensor whose elements have the type specified by class {@code c}.
90+
*
91+
* @param c The class describing the TensorFlow type of interest.
92+
* @return The {@code DataType} enum corresponding to {@code c}.
93+
* @throws IllegalArgumentException if objects of {@code c} do not correspond to a TensorFlow
94+
* datatype.
95+
*/
96+
public static DataType fromClass(Class<?> c) {
97+
DataType dtype = typeCodes.get(c);
98+
if (dtype == null) {
99+
throw new IllegalArgumentException(
100+
c.getName() + " objects cannot be used as elements in a TensorFlow Tensor");
101+
}
102+
return dtype;
103+
}
104+
105+
private static final Map<Class<?>, DataType> typeCodes = new HashMap<>();
106+
107+
static {
108+
typeCodes.put(Float.class, DataType.FLOAT);
109+
typeCodes.put(Double.class, DataType.DOUBLE);
110+
typeCodes.put(Integer.class, DataType.INT32);
111+
typeCodes.put(UInt8.class, DataType.UINT8);
112+
typeCodes.put(Long.class, DataType.INT64);
113+
typeCodes.put(Boolean.class, DataType.BOOL);
114+
typeCodes.put(String.class, DataType.STRING);
115+
}
116+
}
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,172 @@
1+
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
2+
3+
Licensed under the Apache License, Version 2.0 (the "License");
4+
you may not use this file except in compliance with the License.
5+
You may obtain a copy of the License at
6+
7+
http://www.apache.org/licenses/LICENSE-2.0
8+
9+
Unless required by applicable law or agreed to in writing, software
10+
distributed under the License is distributed on an "AS IS" BASIS,
11+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
See the License for the specific language governing permissions and
13+
limitations under the License.
14+
==============================================================================*/
15+
16+
package org.tensorflow;
17+
18+
import java.util.concurrent.atomic.AtomicReferenceArray;
19+
20+
/**
21+
* Implementation of an {@link Operation} executed eagerly.
22+
*
23+
* <p>EagerOperation instances are valid only as long as the {@link EagerSession} they are a part of
24+
* is valid. Thus, if {@link EagerSession#close()} has been invoked, then methods on the
25+
* EagerOperation instance may fail with an {@code IllegalStateException}.
26+
*
27+
* <p>EagerOperation instances are thread-safe.
28+
*/
29+
class EagerOperation extends AbstractOperation {
30+
31+
EagerOperation(
32+
EagerSession session,
33+
long opNativeHandle,
34+
long[] outputNativeHandles,
35+
String type,
36+
String name) {
37+
this.session = session;
38+
this.type = type;
39+
this.name = name;
40+
this.nativeRef = new NativeReference(session, this, opNativeHandle, outputNativeHandles);
41+
this.outputTensors = new AtomicReferenceArray<Tensor<?>>(outputNativeHandles.length);
42+
}
43+
44+
@Override
45+
public String name() {
46+
return name;
47+
}
48+
49+
@Override
50+
public String type() {
51+
return type;
52+
}
53+
54+
@Override
55+
public int numOutputs() {
56+
return nativeRef.outputHandles.length;
57+
}
58+
59+
@Override
60+
public int outputListLength(final String name) {
61+
return outputListLength(nativeRef.opHandle, name);
62+
}
63+
64+
@Override
65+
public int inputListLength(final String name) {
66+
return inputListLength(nativeRef.opHandle, name);
67+
}
68+
69+
@Override
70+
public long getUnsafeNativeHandle(int outputIndex) {
71+
return nativeRef.outputHandles[outputIndex];
72+
}
73+
74+
@Override
75+
public long[] shape(int outputIndex) {
76+
// If the tensor of this output has already been resolved, return its shape.
77+
// Otherwise, retrieve the tensor shape from the native library.
78+
Tensor<?> tensor = outputTensors.get(outputIndex);
79+
if (tensor != null) {
80+
return tensor.shape();
81+
}
82+
long outputNativeHandle = getUnsafeNativeHandle(outputIndex);
83+
long[] shape = new long[numDims(outputNativeHandle)];
84+
for (int i = 0; i < shape.length; ++i) {
85+
shape[i] = dim(outputNativeHandle, i);
86+
}
87+
return shape;
88+
}
89+
90+
@Override
91+
public DataType dtype(int outputIndex) {
92+
// If the tensor of this output has already been resolved, return its datatype.
93+
// Otherwise, retrieve the tensor datatype from the native library.
94+
Tensor<?> tensor = outputTensors.get(outputIndex);
95+
if (tensor != null) {
96+
return tensor.dataType();
97+
}
98+
long outputNativeHandle = getUnsafeNativeHandle(outputIndex);
99+
return DataType.fromC(dataType(outputNativeHandle));
100+
}
101+
102+
@Override
103+
public Tensor<?> tensor(int outputIndex) {
104+
Tensor<?> tensor = outputTensors.get(outputIndex);
105+
if (tensor == null) {
106+
tensor = resolveTensor(outputIndex);
107+
}
108+
return tensor;
109+
}
110+
111+
private final EagerSession session;
112+
private final NativeReference nativeRef;
113+
private final String type;
114+
private final String name;
115+
private final AtomicReferenceArray<Tensor<?>> outputTensors;
116+
117+
private Tensor<?> resolveTensor(int outputIndex) {
118+
// Take an optimistic approach, where we attempt to resolve the output tensor without locking.
119+
// If another thread has resolved it meanwhile, release our copy and reuse the existing one
120+
// instead.
121+
long tensorNativeHandle = resolveTensorHandle(getUnsafeNativeHandle(outputIndex));
122+
Tensor<?> tensor = Tensor.fromHandle(tensorNativeHandle, session);
123+
if (!outputTensors.compareAndSet(outputIndex, null, tensor)) {
124+
tensor.close();
125+
tensor = outputTensors.get(outputIndex);
126+
}
127+
return tensor;
128+
}
129+
130+
private static class NativeReference extends EagerSession.NativeReference {
131+
132+
NativeReference(
133+
EagerSession session, EagerOperation operation, long opHandle, long[] outputHandles) {
134+
super(session, operation);
135+
this.opHandle = opHandle;
136+
this.outputHandles = outputHandles;
137+
}
138+
139+
@Override
140+
void delete() {
141+
if (opHandle != 0L) {
142+
for (int i = 0; i < outputHandles.length; ++i) {
143+
if (outputHandles[i] != 0L) {
144+
EagerOperation.deleteTensorHandle(outputHandles[i]);
145+
outputHandles[i] = 0L;
146+
}
147+
}
148+
EagerOperation.delete(opHandle);
149+
opHandle = 0L;
150+
}
151+
}
152+
153+
private long opHandle;
154+
private final long[] outputHandles;
155+
}
156+
157+
private static native void delete(long handle);
158+
159+
private static native void deleteTensorHandle(long handle);
160+
161+
private static native long resolveTensorHandle(long handle);
162+
163+
private static native int outputListLength(long handle, String name);
164+
165+
private static native int inputListLength(long handle, String name);
166+
167+
private static native int dataType(long handle);
168+
169+
private static native int numDims(long handle);
170+
171+
private static native long dim(long handle, int index);
172+
}

0 commit comments

Comments
 (0)