Skip to content

Commit 1d35c17

Browse files
karllessardShajan
andauthored
Create, save, load and run models using a new functional API (tensorflow#112)
* Create, save and load models using functional API Python models that contain tf.function is inconvenient to be consumed by Java clients. This proposal provides an API to (a) Invoke a tf.function, given the signature name (b) Retrieve the node name in the graph corresponding to a tf.function Co-authored-by: Shajan Dasan <[email protected]> Save models as functions (tensorflow#103) * Draft: Java API to use tf.function available on SavedModel. (tensorflow#89) Python models that contain tf.function is inconvenient to be consumed by Java clients. This proposal provides an API to (a) Invoke a tf.function, given the signature name (b) Retrieve the node name in the graph corresponding to a tf.function Co-authored-by: Shajan Dasan <[email protected]> * Change API for creating concrete functions and exporting them to a saved model Co-authored-by: Karl Lessard <[email protected]> Rename signature name to key Print function signature when converting to String Add method that returns the signature of all functions in a saved model Add unit tests for python created SavedModel with tf.function * Add validations on signatures and saved models * Convert text file to Python * Add copyright on Python sample Co-authored-by: Shajan Dasan <[email protected]>
1 parent 2843138 commit 1d35c17

File tree

14 files changed

+1366
-8
lines changed

14 files changed

+1366
-8
lines changed
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,291 @@
1+
/*
2+
* Copyright 2020 The TensorFlow Authors. All rights reserved.
3+
*
4+
* Licensed under the Apache License, Version 2.0 (the "License");
5+
* you may not use this file except in compliance with the License.
6+
* You may obtain a copy of the License at
7+
*
8+
* http://www.apache.org/licenses/LICENSE-2.0
9+
*
10+
* Unless required by applicable law or agreed to in writing, software
11+
* distributed under the License is distributed on an "AS IS" BASIS,
12+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
* See the License for the specific language governing permissions and
14+
* limitations under the License.
15+
*/
16+
package org.tensorflow;
17+
18+
import java.io.IOException;
19+
import java.util.List;
20+
import java.util.ListIterator;
21+
import java.util.HashMap;
22+
import java.util.Map;
23+
import java.util.function.Function;
24+
import org.tensorflow.op.Ops;
25+
import org.tensorflow.proto.framework.SignatureDef;
26+
import org.tensorflow.proto.framework.TensorInfo;
27+
28+
/**
29+
* A graph that can be invoked as a single function, with an input and output signature.
30+
*
31+
* <p>A function can also invoke a
32+
* <a href="https://www.tensorflow.org/api_docs/python/tf/function">tf.function</a>
33+
* defined in a {@link SavedModelBundle}.
34+
*
35+
* <pre>{@code
36+
* ConcreteFunction myFunction = savedModelBundle.function("myFunctionSignatureName");
37+
* Map<String, Tensor<?>> outputTensorMap = myFunction.call(inputTensorMap);
38+
* }</pre>
39+
*/
40+
public class ConcreteFunction implements AutoCloseable {
41+
42+
/**
43+
* Creates a function by building a new graph.
44+
*
45+
* <p/>The {@code functionBuilder} must initialize the function graph from the provided
46+
* {@link Ops} instance and return a valid signature that will be used to feed the input tensors
47+
* and fetch the output tensors on execution.
48+
*
49+
* <p/>The function will be the owner of the new graph and its resulting session. Therefore,
50+
* the function must be enclosed properly with a try-with-resources block to guarantee that
51+
* all native resources will be freed once the function is discarded. For example:
52+
*
53+
* <pre>{@code
54+
* public class MyModel {
55+
*
56+
* public static Signature addTwo(Ops tf) {
57+
* Placeholder<TFloat32> input = tf.placeholder(TFloat32.DTYPE);
58+
* Add<TFloat32> output = tf.math.add(input, tf.constant(2.0f));
59+
* return Signature.builder("addTwo").input("x", input).output("y", output).build();
60+
* }
61+
*
62+
* public static void main(String args[]) {
63+
* try (ConcreteFunction function = ConcreteFunction.create(MyModel::addTwo);
64+
* Tensor<TFloat32> x = TFloat32.scalarOf(2.0f)) {
65+
* assertEquals(4.0f, function.call(x).expect(TFloat32.DTYPE).data().getFloat());
66+
* }
67+
* }
68+
* }
69+
* }</pre>
70+
*
71+
* @param functionBuilder function builder
72+
* @return the new function
73+
*/
74+
public static ConcreteFunction create(Function<Ops, Signature> functionBuilder) {
75+
Graph graph = new Graph();
76+
try {
77+
Ops tf = Ops.create(graph);
78+
Signature signature = functionBuilder.apply(tf);
79+
return new ConcreteFunction(signature, graph, new Session(graph), Ownership.GRAPH_AND_SESSION);
80+
} catch (Exception e) {
81+
graph.close();
82+
throw e;
83+
}
84+
}
85+
86+
/**
87+
* Create a function from a signature and an existing graph.
88+
*
89+
* <p/>The function will keep the ownership of the session used to run the graph but not
90+
* the graph itself, meaning that the lifetime of the latter can extend beyond the scope
91+
* of the function. For example:
92+
*
93+
* <pre>{@code
94+
* try (Graph g = new Graph()) {
95+
* Placeholder<TFloat32> input = tf.placeholder(TFloat32.DTYPE);
96+
* Add<TFloat32> output = tf.math.add(input, tf.constant(2.0f));
97+
* Signature signature = Signature.builder().input("x", input).output("y", output).build();
98+
*
99+
* try (ConcreteFunction f = ConcreteFunction.create(signature, g);
100+
* Tensor<TFloat32> x = TFloat32.scalarOf(2.0f)) {
101+
* assertEquals(4.0f, function.call(x).expect(TFloat32.DTYPE).data().getFloat());
102+
* }
103+
* // Graph g is still valid at this point
104+
* }
105+
* }</pre>
106+
*
107+
* @param signature signature of the function to create
108+
* @param graph a valid and initialized graph
109+
* @return a new function
110+
*/
111+
public static ConcreteFunction create(Signature signature, Graph graph) {
112+
return new ConcreteFunction(signature, graph, new Session(graph), Ownership.SESSION_ONLY);
113+
}
114+
115+
/**
116+
* Create a function from a signature and a valid graph session.
117+
*
118+
* <p/>The function will not own the session nor its graph, meaning that their lifetime
119+
* can extend beyond the scope of the function. Therefore the function does not need to be
120+
* closed after its usage. For example:
121+
*
122+
* <pre>{@code
123+
* try (Graph g = new Graph()) {
124+
* Placeholder<TFloat32> input = tf.placeholder(TFloat32.DTYPE);
125+
* Add<TFloat32> output = tf.math.add(input, tf.constant(2.0f));
126+
* Signature signature = Signature.builder().input("x", input).output("y", output).build();
127+
*
128+
* try (Session s = new Session(g)) {
129+
* // Auto-closing the function just as an example but this is not required since it has
130+
* // no effect
131+
* try (ConcreteFunction f = ConcreteFunction.create(signature, s);
132+
* Tensor<TFloat32> t = TFloat32.scalarOf(2.0f)) {
133+
* assertEquals(4.0f, function.call(x).expect(TFloat32.DTYPE).data().getFloat());
134+
* }
135+
* // Session s is still valid at this point
136+
* }
137+
* // Graph g is still valid at this point
138+
* }
139+
* }</pre>
140+
*
141+
* @param signature signature of the function to create
142+
* @param graph a valid session to an initialized graph
143+
* @return a new function
144+
*/
145+
public static ConcreteFunction create(Signature signature, Session session) {
146+
return new ConcreteFunction(signature, session.graph(), session, Ownership.NONE);
147+
}
148+
149+
/**
150+
* Returns the signature of this function
151+
*/
152+
public Signature signature() {
153+
return signature;
154+
}
155+
156+
/**
157+
* Invokes a function.
158+
*
159+
* <p>Caller is responsible for closing all Tensors.
160+
*
161+
* @param tensor input tensor
162+
* @return output tensor
163+
*/
164+
public Map<String, Tensor<?>> call(Map<String, Tensor<?>> arguments)
165+
throws IllegalArgumentException {
166+
167+
final SignatureDef signatureDef = signature.asSignatureDef();
168+
final Session.Runner runner = session.runner();
169+
170+
signatureDef.getInputsMap().forEach((argName, t) -> {
171+
Tensor<?> tensor = arguments.get(argName);
172+
if (tensor == null) {
173+
throw new IllegalArgumentException(String.format("Missing argument [%s]", argName));
174+
}
175+
runner.feed(t.getName(), tensor);
176+
});
177+
178+
Map<String, TensorInfo> outputToNode = signatureDef.getOutputsMap();
179+
outputToNode.values().forEach(t -> runner.fetch(t.getName()));
180+
181+
List<Tensor<?>> resultTensors = runner.run();
182+
try {
183+
ListIterator<Tensor<?>> resultTensorIter = resultTensors.listIterator();
184+
Map<String, Tensor<?>> returnMap = new HashMap<String, Tensor<?>>();
185+
186+
// Use the output names as present in the signature definition
187+
for (String nodeName: outputToNode.keySet()) {
188+
returnMap.put(nodeName, resultTensorIter.next());
189+
}
190+
return returnMap;
191+
192+
} catch (Exception e) {
193+
// Release tensors before throwing exception
194+
for (Tensor<?> t : resultTensors) {
195+
t.close();
196+
}
197+
throw e;
198+
}
199+
}
200+
201+
/**
202+
* Invokes a function with a single input and output.
203+
*
204+
* <p>Caller is responsible for closing all Tensors.
205+
*
206+
* @param tensor input tensor
207+
* @return output tensor
208+
* @throws IllegalArgumentException if there are multiple input or output parameters defined
209+
* in the function
210+
*/
211+
public Tensor<?> call(Tensor<?> tensor) throws IllegalArgumentException {
212+
final SignatureDef signatureDef = signature.asSignatureDef();
213+
214+
if (signatureDef.getInputsCount() != 1) {
215+
throw new IllegalArgumentException(
216+
String.format("Function [%s] requires multiple inputs", signatureDef.getMethodName()));
217+
}
218+
String inputNodeName = signatureDef.getInputsMap().values().iterator().next().getName();
219+
220+
if (signatureDef.getOutputsCount() != 1) {
221+
throw new IllegalArgumentException(
222+
String.format("Function [%s] has multiple outputs", signatureDef.getMethodName()));
223+
}
224+
String outputNodeName = signatureDef.getOutputsMap().values().iterator().next().getName();
225+
226+
return session.runner().feed(inputNodeName, tensor).fetch(outputNodeName).run().get(0);
227+
}
228+
229+
/**
230+
* Export this function as a saved model.
231+
*
232+
* <p>This method is convenient shortcut equivalent to
233+
* {@code SavedModel.exporter(exportDir).withFunction(this).export()}
234+
*
235+
* @throws IOException if saved model or variable state cannot be written on disk
236+
*/
237+
public void save(String exportDir) throws IOException {
238+
SavedModelBundle.exporter(exportDir).withFunction(this).export();
239+
}
240+
241+
/**
242+
* Returns the session used to execute the graph when calling this function
243+
*
244+
* <p>In general, a user does not need to handle directly the session of a function and rely
245+
* on {@link #call(Map)} to execute the graph instead. But in some cases, direct access to
246+
* the session might be necessary, as it allows more running options.
247+
*
248+
* @return the function session
249+
*/
250+
public Session session() {
251+
return session;
252+
}
253+
254+
/**
255+
* Returns the graph of this function
256+
*/
257+
public Graph graph() {
258+
return graph;
259+
}
260+
261+
@Override
262+
public void close() {
263+
if (ownership != Ownership.NONE) {
264+
session.close();
265+
if (ownership == Ownership.GRAPH_AND_SESSION) {
266+
graph.close();
267+
}
268+
}
269+
}
270+
271+
@Override
272+
public String toString() {
273+
return signature.toString();
274+
}
275+
276+
private enum Ownership {
277+
GRAPH_AND_SESSION, SESSION_ONLY, NONE;
278+
}
279+
280+
private final Graph graph;
281+
private final Session session;
282+
private final Signature signature;
283+
private final Ownership ownership;
284+
285+
ConcreteFunction(Signature signature, Graph graph, Session session, Ownership ownership) {
286+
this.graph = graph;
287+
this.session = session;
288+
this.signature = signature;
289+
this.ownership = ownership;
290+
}
291+
}

0 commit comments

Comments
 (0)