Skip to content

Commit e0eec4a

Browse files
authored
Graph custom gradient support (#292)
1 parent 55547dd commit e0eec4a

File tree

1,367 files changed

+17408
-2986
lines changed

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

1,367 files changed

+17408
-2986
lines changed

tensorflow-core/tensorflow-core-api/WORKSPACE

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,8 @@ http_archive(
1212
# ":tensorflow-macosx.patch",
1313
# ":tensorflow-windows.patch", # https://github.com/tensorflow/tensorflow/issues/25213
1414
":tensorflow-proto.patch",
15+
":custom-grad-helpers.patch",
16+
":custom-grad-symbols.patch",
1517
],
1618
patch_tool = "patch",
1719
patch_args = ["-p1"],
Lines changed: 57 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,57 @@
1+
diff --git a/tensorflow/c/c_api.cc b/tensorflow/c/c_api.cc
2+
index f3bf7b98a1e6b..c9194c36c116b 100644
3+
--- a/tensorflow/c/c_api.cc
4+
+++ b/tensorflow/c/c_api.cc
5+
@@ -782,9 +782,9 @@ void TF_GraphGetTensorShape(TF_Graph* graph, TF_Output output, int64_t* dims,
6+
7+
extern "C" {
8+
9+
-static TF_OperationDescription* TF_NewOperationLocked(TF_Graph* graph,
10+
- const char* op_type,
11+
- const char* oper_name)
12+
+TF_OperationDescription* TF_NewOperationLocked(TF_Graph* graph,
13+
+ const char* op_type,
14+
+ const char* oper_name)
15+
TF_EXCLUSIVE_LOCKS_REQUIRED(graph->mu) {
16+
return new TF_OperationDescription(graph, op_type, oper_name);
17+
}
18+
@@ -1041,8 +1041,8 @@ void TF_SetAttrValueProto(TF_OperationDescription* desc, const char* attr_name,
19+
status->status = Status::OK();
20+
}
21+
22+
-static TF_Operation* TF_FinishOperationLocked(TF_OperationDescription* desc,
23+
- TF_Status* status)
24+
+TF_Operation* TF_FinishOperationLocked(TF_OperationDescription* desc,
25+
+ TF_Status* status)
26+
TF_EXCLUSIVE_LOCKS_REQUIRED(desc->graph->mu) {
27+
Node* ret = nullptr;
28+
29+
diff --git a/tensorflow/c/c_api.h b/tensorflow/c/c_api.h
30+
index 705cf85e0512f..fb746dd4af94f 100644
31+
--- a/tensorflow/c/c_api.h
32+
+++ b/tensorflow/c/c_api.h
33+
@@ -255,6 +255,12 @@ TF_CAPI_EXPORT extern void TF_GraphGetTensorShape(TF_Graph* graph,
34+
int64_t* dims, int num_dims,
35+
TF_Status* status);
36+
37+
+// TF_NewOperation, but without locking the graph.
38+
+// Should prefer TF_NewOperation when possible.
39+
+TF_CAPI_EXPORT extern TF_OperationDescription* TF_NewOperationLocked(TF_Graph* graph,
40+
+ const char* op_type,
41+
+ const char* oper_name);
42+
+
43+
// Operation will only be added to *graph when TF_FinishOperation() is
44+
// called (assuming TF_FinishOperation() does not return an error).
45+
// *graph must not be deleted until after TF_FinishOperation() is
46+
@@ -406,6 +412,11 @@ TF_CAPI_EXPORT extern void TF_SetAttrValueProto(TF_OperationDescription* desc,
47+
size_t proto_len,
48+
TF_Status* status);
49+
50+
+// TF_FinishOperation, but without locking the graph.
51+
+// TF_FinishOperation should be preferred when possible.
52+
+TF_CAPI_EXPORT extern TF_Operation* TF_FinishOperationLocked(TF_OperationDescription* desc,
53+
+ TF_Status* status);
54+
+
55+
// If this function succeeds:
56+
// * *status is set to an OK value,
57+
// * a TF_Operation is added to the graph,
Lines changed: 151 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,151 @@
1+
Index: tensorflow/tools/def_file_filter/BUILD
2+
IDEA additional info:
3+
Subsystem: com.intellij.openapi.diff.impl.patch.CharsetEP
4+
<+>UTF-8
5+
===================================================================
6+
diff --git a/tensorflow/tools/def_file_filter/BUILD b/tensorflow/tools/def_file_filter/BUILD
7+
--- a/tensorflow/tools/def_file_filter/BUILD (revision 5e5cc35b4c0f629a1e092b540fdf2b63367aa5ad)
8+
+++ b/tensorflow/tools/def_file_filter/BUILD (date 1629063191558)
9+
@@ -12,3 +12,8 @@
10+
name = "symbols_pybind",
11+
srcs = ["symbols_pybind.txt"],
12+
)
13+
+
14+
+filegroup(
15+
+ name = "symbols_java",
16+
+ srcs = ["symbols_java.txt"],
17+
+)
18+
Index: tensorflow/BUILD
19+
IDEA additional info:
20+
Subsystem: com.intellij.openapi.diff.impl.patch.CharsetEP
21+
<+>UTF-8
22+
===================================================================
23+
diff --git a/tensorflow/BUILD b/tensorflow/BUILD
24+
--- a/tensorflow/BUILD (revision 5e5cc35b4c0f629a1e092b540fdf2b63367aa5ad)
25+
+++ b/tensorflow/BUILD (date 1629063361078)
26+
@@ -1069,13 +1069,20 @@
27+
# the dynamic libraries of custom ops can find it at runtime.
28+
genrule(
29+
name = "tensorflow_filtered_def_file",
30+
- srcs = [":tensorflow_def_file"],
31+
+ srcs = [
32+
+ ":tensorflow_def_file",
33+
+ ":java_symbol_target_libs_file",
34+
+ ":win_lib_files_for_java_exported_symbols",
35+
+ "//tensorflow/tools/def_file_filter:symbols_java",
36+
+ ],
37+
outs = ["tensorflow_filtered_def_file.def"],
38+
cmd = select({
39+
"//tensorflow:windows": """
40+
$(location @local_config_def_file_filter//:def_file_filter) \\
41+
--input $(location :tensorflow_def_file) \\
42+
- --output $@
43+
+ --output $@ \\
44+
+ --symbols $(location //tensorflow/tools/def_file_filter:symbols_java) \\
45+
+ --lib_paths_file $(location :java_symbol_target_libs_file)
46+
""",
47+
"//conditions:default": "touch $@", # Just a placeholder for Unix platforms
48+
}),
49+
@@ -1083,6 +1090,34 @@
50+
visibility = ["//visibility:public"],
51+
)
52+
53+
+# Write to a file a list of all cc_library targets that we need for exporting symbols on Windows.
54+
+genrule(
55+
+ name = "java_symbol_target_libs_file",
56+
+ srcs = [":win_lib_files_for_java_exported_symbols"],
57+
+ outs = ["java_symbol_target_libs_file.txt"],
58+
+ cmd = select({
59+
+ "//tensorflow:windows": """
60+
+ for SRC in $(SRCS); do
61+
+ echo $$SRC | sed 's/third_party\\///g' >> $@
62+
+ done
63+
+ """,
64+
+ "//conditions:default": "touch $@", # Just a placeholder for Unix platforms
65+
+ }),
66+
+ visibility = ["//visibility:public"],
67+
+)
68+
+
69+
+filegroup(
70+
+ name = "win_lib_files_for_java_exported_symbols",
71+
+ srcs = [
72+
+ "//tensorflow/cc:scope",
73+
+ "//tensorflow/cc:grad_op_registry",
74+
+ "//tensorflow/c:tf_status_helper",
75+
+ "//tensorflow/cc:ops"
76+
+ ],
77+
+ visibility = ["//visibility:private"],
78+
+)
79+
+
80+
+
81+
# The interface library (tensorflow.dll.if.lib) for linking tensorflow DLL library (tensorflow.dll) on Windows.
82+
# To learn more about import library (called interface library in Bazel):
83+
# https://docs.microsoft.com/en-us/cpp/build/linking-an-executable-to-a-dll?view=vs-2017#linking-implicitly
84+
Index: tensorflow/tools/def_file_filter/BUILD.tpl
85+
IDEA additional info:
86+
Subsystem: com.intellij.openapi.diff.impl.patch.CharsetEP
87+
<+>UTF-8
88+
===================================================================
89+
diff --git a/tensorflow/tools/def_file_filter/BUILD.tpl b/tensorflow/tools/def_file_filter/BUILD.tpl
90+
--- a/tensorflow/tools/def_file_filter/BUILD.tpl (revision 5e5cc35b4c0f629a1e092b540fdf2b63367aa5ad)
91+
+++ b/tensorflow/tools/def_file_filter/BUILD.tpl (date 1629063191583)
92+
@@ -18,3 +18,8 @@
93+
name = "symbols_pybind",
94+
srcs = ["symbols_pybind.txt"],
95+
)
96+
+
97+
+filegroup(
98+
+ name = "symbols_java",
99+
+ srcs = ["symbols_java.txt"],
100+
+)
101+
Index: tensorflow/tools/def_file_filter/symbols_java.txt
102+
IDEA additional info:
103+
Subsystem: com.intellij.openapi.diff.impl.patch.CharsetEP
104+
<+>UTF-8
105+
===================================================================
106+
diff --git a/tensorflow/tools/def_file_filter/symbols_java.txt b/tensorflow/tools/def_file_filter/symbols_java.txt
107+
new file mode 100644
108+
--- /dev/null (date 1629063607794)
109+
+++ b/tensorflow/tools/def_file_filter/symbols_java.txt (date 1629063607794)
110+
@@ -0,0 +1,26 @@
111+
+[//tensorflow/cc:scope] # scope
112+
+tensorflow::Scope::graph
113+
+tensorflow::Scope::ok
114+
+tensorflow::Scope::UpdateBuilder
115+
+tensorflow::Scope::GetUniqueNameForOp
116+
+tensorflow::Scope::ExitOnError
117+
+tensorflow::Scope::WithDevice
118+
+tensorflow::Scope::WithNoControlDependencies
119+
+tensorflow::Scope::WithControlDependencies
120+
+tensorflow::Scope::NewSubScope
121+
+tensorflow::Scope::NewRootScope
122+
+tensorflow::Scope::operator=
123+
+tensorflow::Scope::~Scope
124+
+tensorflow::Scope::Scope
125+
+
126+
+[//tensorflow/cc:ops]
127+
+tensorflow::Operation::Operation
128+
+
129+
+[//tensorflow/cc:grad_op_registry] # custom gradients for graph
130+
+tensorflow::ops::GradOpRegistry::Global
131+
+tensorflow::ops::GradOpRegistry::Lookup
132+
+tensorflow::ops::GradOpRegistry::Register
133+
+
134+
+[//tensorflow/c:tf_status_helper] # status helpers
135+
+tensorflow::Set_TF_Status_from_Status
136+
+tensorflow::StatusFromTF_Status
137+
===================================================================
138+
diff --git a/tensorflow/tools/def_file_filter/def_file_filter.py.tpl b/tensorflow/tools/def_file_filter/def_file_filter.py.tpl
139+
--- a/tensorflow/tools/def_file_filter/def_file_filter.py.tpl (revision 919f693420e35d00c8d0a42100837ae3718f7927)
140+
+++ b/tensorflow/tools/def_file_filter/def_file_filter.py.tpl (date 1632048268359)
141+
@@ -143,8 +143,8 @@
142+
re_filter_comp = re.compile(r"{}".format(re_filter))
143+
144+
# Filter out symbol from the split line (`sym_split` in the for loop below).
145+
- sym_line_filter = r".*\s+\| (.*) \(.*"
146+
- sym_line_filter_anomaly = r".*\s+\| (.*)"
147+
+ sym_line_filter = r".*\s+\| (.*?) \(.*"
148+
+ sym_line_filter_anomaly = r".*\s+\| (.*?)"
149+
150+
for sym_line in sym_split:
151+
if re_filter_comp.search(sym_line):

tensorflow-core/tensorflow-core-api/pom.xml

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -143,6 +143,19 @@
143143
</execution>
144144
</executions>
145145
</plugin>
146+
<plugin>
147+
<artifactId>maven-resources-plugin</artifactId>
148+
<version>3.1.0</version>
149+
<executions>
150+
<execution>
151+
<id>javacpp-parser</id>
152+
<phase>generate-sources</phase>
153+
<goals>
154+
<goal>resources</goal>
155+
</goals>
156+
</execution>
157+
</executions>
158+
</plugin>
146159
<plugin>
147160
<artifactId>maven-compiler-plugin</artifactId>
148161
<version>3.8.0</version>
@@ -212,6 +225,11 @@
212225
<includePaths>
213226
<includePath>${project.basedir}/</includePath>
214227
<includePath>${project.basedir}/bazel-${project.artifactId}/external/org_tensorflow/</includePath>
228+
<includePath>${project.basedir}/bazel-bin/external/org_tensorflow/</includePath>
229+
<includePath>${project.basedir}/bazel-${project.artifactId}/external/com_google_absl/</includePath>
230+
<includePath>${project.basedir}/bazel-${project.artifactId}/external/eigen_archive/</includePath>
231+
<includePath>${project.basedir}/bazel-${project.artifactId}/external/com_google_protobuf/src/</includePath>
232+
<includePath>${project.basedir}/target/classes/org/tensorflow/internal/c_api/include/</includePath>
215233
</includePaths>
216234
<linkPaths>
217235
<linkPath>${project.basedir}/bazel-bin/external/llvm_openmp/</linkPath>

tensorflow-core/tensorflow-core-api/src/gen/annotations/org/tensorflow/op/Ops.java

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -367,10 +367,10 @@ public final class Ops {
367367

368368
public final SparseOps sparse;
369369

370-
public final BitwiseOps bitwise;
371-
372370
public final TpuOps tpu;
373371

372+
public final BitwiseOps bitwise;
373+
374374
public final MathOps math;
375375

376376
public final AudioOps audio;
@@ -383,7 +383,7 @@ public final class Ops {
383383

384384
private final Scope scope;
385385

386-
private Ops(Scope scope) {
386+
Ops(Scope scope) {
387387
this.scope = scope;
388388
nn = new NnOps(this);
389389
summary = new SummaryOps(this);
@@ -398,8 +398,8 @@ private Ops(Scope scope) {
398398
random = new RandomOps(this);
399399
strings = new StringsOps(this);
400400
sparse = new SparseOps(this);
401-
bitwise = new BitwiseOps(this);
402401
tpu = new TpuOps(this);
402+
bitwise = new BitwiseOps(this);
403403
math = new MathOps(this);
404404
audio = new AudioOps(this);
405405
signal = new SignalOps(this);
Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,25 @@
1+
// Targeted by JavaCPP version 1.5.6: DO NOT EDIT THIS FILE
2+
3+
package org.tensorflow.internal.c_api;
4+
5+
import java.nio.*;
6+
import org.bytedeco.javacpp.*;
7+
import org.bytedeco.javacpp.annotation.*;
8+
9+
import static org.tensorflow.internal.c_api.global.tensorflow.*;
10+
11+
12+
/** GradFunc is the signature for all gradient functions in GradOpRegistry.
13+
* Implementations should add operations to compute the gradient outputs of
14+
* 'op' (returned in 'grad_outputs') using 'scope' and 'grad_inputs'. */
15+
@Properties(inherit = org.tensorflow.internal.c_api.presets.tensorflow.class)
16+
public class GradFunc extends FunctionPointer {
17+
static { Loader.load(); }
18+
/** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */
19+
public GradFunc(Pointer p) { super(p); }
20+
protected GradFunc() { allocate(); }
21+
private native void allocate();
22+
public native @ByVal NativeStatus call(@Const @ByRef TF_Scope scope, @Const @ByRef NativeOperation op,
23+
@Const @ByRef NativeOutputVector grad_inputs,
24+
NativeOutputVector grad_outputs);
25+
}
Lines changed: 48 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,48 @@
1+
// Targeted by JavaCPP version 1.5.6: DO NOT EDIT THIS FILE
2+
3+
package org.tensorflow.internal.c_api;
4+
5+
import java.nio.*;
6+
import org.bytedeco.javacpp.*;
7+
import org.bytedeco.javacpp.annotation.*;
8+
9+
import static org.tensorflow.internal.c_api.global.tensorflow.*;
10+
11+
12+
/** GradOpRegistry maintains a static registry of gradient functions.
13+
* Gradient functions are indexed in the registry by the forward op name (i.e.
14+
* "MatMul" -> MatMulGrad func). */
15+
@Namespace("tensorflow::ops") @Properties(inherit = org.tensorflow.internal.c_api.presets.tensorflow.class)
16+
public class GradOpRegistry extends Pointer {
17+
static { Loader.load(); }
18+
/** Default native constructor. */
19+
public GradOpRegistry() { super((Pointer)null); allocate(); }
20+
/** Native array allocator. Access with {@link Pointer#position(long)}. */
21+
public GradOpRegistry(long size) { super((Pointer)null); allocateArray(size); }
22+
/** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */
23+
public GradOpRegistry(Pointer p) { super(p); }
24+
private native void allocate();
25+
private native void allocateArray(long size);
26+
@Override public GradOpRegistry position(long position) {
27+
return (GradOpRegistry)super.position(position);
28+
}
29+
@Override public GradOpRegistry getPointer(long i) {
30+
return new GradOpRegistry((Pointer)this).offsetAddress(i);
31+
}
32+
33+
/** Registers 'func' as the gradient function for 'op'.
34+
* Returns true if registration was successful, check fails otherwise. */
35+
public native @Cast("bool") boolean Register(@StdString BytePointer op, GradFunc func);
36+
public native @Cast("bool") boolean Register(@StdString String op, GradFunc func);
37+
38+
/** Sets 'func' to the gradient function for 'op' and returns Status OK if
39+
* the gradient function for 'op' exists in the registry.
40+
* Note that 'func' can be null for ops that have registered no-gradient with
41+
* the registry.
42+
* Returns error status otherwise. */
43+
public native @ByVal NativeStatus Lookup(@StdString BytePointer op, @ByPtrPtr GradFunc func);
44+
public native @ByVal NativeStatus Lookup(@StdString String op, @ByPtrPtr GradFunc func);
45+
46+
/** Returns a pointer to the global gradient function registry. */
47+
public static native GradOpRegistry Global();
48+
}
Lines changed: 41 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,41 @@
1+
// Targeted by JavaCPP version 1.5.6: DO NOT EDIT THIS FILE
2+
3+
package org.tensorflow.internal.c_api;
4+
5+
import java.nio.*;
6+
import org.bytedeco.javacpp.*;
7+
import org.bytedeco.javacpp.annotation.*;
8+
9+
import static org.tensorflow.internal.c_api.global.tensorflow.*;
10+
11+
@Name("std::unordered_map<tensorflow::string,tensorflow::Node*>") @Properties(inherit = org.tensorflow.internal.c_api.presets.tensorflow.class)
12+
public class NameMap extends Pointer {
13+
static { Loader.load(); }
14+
/** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */
15+
public NameMap(Pointer p) { super(p); }
16+
public NameMap() { allocate(); }
17+
private native void allocate();
18+
public native @Name("operator =") @ByRef NameMap put(@ByRef NameMap x);
19+
20+
public boolean empty() { return size() == 0; }
21+
public native long size();
22+
23+
@Index public native Node get(@StdString BytePointer i);
24+
public native NameMap put(@StdString BytePointer i, Node value);
25+
26+
public native void erase(@ByVal Iterator pos);
27+
public native @ByVal Iterator begin();
28+
public native @ByVal Iterator end();
29+
@NoOffset @Name("iterator") public static class Iterator extends Pointer {
30+
public Iterator(Pointer p) { super(p); }
31+
public Iterator() { }
32+
33+
public native @Name("operator ++") @ByRef Iterator increment();
34+
public native @Name("operator ==") boolean equals(@ByRef Iterator it);
35+
public native @Name("operator *().first") @MemberGetter @StdString BytePointer first();
36+
public native @Name("operator *().second") @MemberGetter @Const Node second();
37+
}
38+
39+
public native long erase(@StdString BytePointer key);
40+
}
41+

0 commit comments

Comments
 (0)