Skip to content

Commit d42cdcc

Browse files
jonpsyWindQAQ
andauthored
Switch custom ops (#2397)
* Add methods for switching custom kernels * Motivate usage in codebase * -custom_kernel_disabled * - doc fix * fix readme fix code doc resource_loader.py fix code doc options.py * is_custom_disabled * - is_custom_kernel_disabled() * - rm disable from gelu.py * - rm pass from options.py * - change docs of options.py disable & enable * - options.py is_custom_kernel docfix * - doc fix resource_loader.py * - fix README.md doc * - previous_py_ops_value test_utils.py * - format README.md * - rm TF_ADDONS_PY_OPS in distort_image_ops.py * Update README.md Co-authored-by: Tzu-Wei Sung <[email protected]> * Update README.md Co-authored-by: Tzu-Wei Sung <[email protected]> * Update tensorflow_addons/options.py Co-authored-by: Tzu-Wei Sung <[email protected]> * Update tensorflow_addons/options.py Co-authored-by: Tzu-Wei Sung <[email protected]> * Update tensorflow_addons/utils/resource_loader.py Co-authored-by: Tzu-Wei Sung <[email protected]> * Update tensorflow_addons/utils/test_utils.py Co-authored-by: Tzu-Wei Sung <[email protected]> * Update tensorflow_addons/options.py Co-authored-by: Tzu-Wei Sung <[email protected]> * Fix typo and format * Expose options to public * Hide some internal uses * Wording * Wording Co-authored-by: Tzu-Wei Sung <[email protected]>
1 parent dd2090c commit d42cdcc

File tree

9 files changed

+55
-25
lines changed

9 files changed

+55
-25
lines changed

README.md

+4-4
Original file line numberDiff line numberDiff line change
@@ -194,9 +194,9 @@ The order of priority on Linux is:
194194
2) C++ implementation
195195
3) Pure TensorFlow + Python implementation (works on CPU and GPU)
196196

197-
If you want to change the default priority, "C++ and CUDA" VS "pure TensorFlow Python",
198-
you can set the variable `TF_ADDONS_PY_OPS` either from the command line or in
199-
your code.
197+
If you want to change the default priority, "C++ and CUDA" VS "pure TensorFlow Python",
198+
you can set the environment variable `TF_ADDONS_PY_OPS=1` from the command line or
199+
run `tfa.options.disable_custom_kernel()` in your code.
200200

201201
For example, if you are on Linux and you have compatibility problems with the compiled ops,
202202
you can give priority to the Python implementations:
@@ -210,7 +210,7 @@ or in your code:
210210

211211
```python
212212
import tensorflow_addons as tfa
213-
tfa.options.TF_ADDONS_PY_OPS = True
213+
tfa.options.disable_custom_kernel()
214214
```
215215

216216
This variable defaults to `True` on Windows and macOS, and `False` on Linux.

tensorflow_addons/__init__.py

+1
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,7 @@
2828
from tensorflow_addons import rnn
2929
from tensorflow_addons import seq2seq
3030
from tensorflow_addons import text
31+
from tensorflow_addons import options
3132
from tensorflow_addons.register import register_all
3233
from tensorflow_addons.utils import types
3334

tensorflow_addons/image/distort_image_ops.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -183,7 +183,7 @@ def adjust_hsv_in_yiq(
183183
)
184184
scale_value = tf.cast(scale_value, dtype=image.dtype, name="scale_value")
185185

186-
if not options.TF_ADDONS_PY_OPS:
186+
if not options.is_custom_kernel_disabled():
187187
warnings.warn(
188188
"C++/CUDA kernel of `adjust_hsv_in_yiq` will be removed in Addons `0.13`.",
189189
DeprecationWarning,

tensorflow_addons/options.py

+34-12
Original file line numberDiff line numberDiff line change
@@ -4,15 +4,14 @@
44
import traceback
55

66
try:
7-
TF_ADDONS_PY_OPS = bool(int(os.environ["TF_ADDONS_PY_OPS"]))
7+
_TF_ADDONS_PY_OPS = bool(int(os.environ["TF_ADDONS_PY_OPS"]))
88
except KeyError:
99
if platform.system() == "Linux":
10-
TF_ADDONS_PY_OPS = False
10+
_TF_ADDONS_PY_OPS = False
1111
else:
12-
TF_ADDONS_PY_OPS = True
12+
_TF_ADDONS_PY_OPS = True
1313

14-
15-
FALLBACK_WARNING_TEMPLATE = """{}
14+
_FALLBACK_WARNING_TEMPLATE = """{}
1615
1716
The {} C++/CUDA custom op could not be loaded.
1817
For this reason, Addons will fallback to an implementation written
@@ -26,24 +25,47 @@
2625
If you want this warning to disappear, either make sure the TensorFlow installed
2726
is compatible with this version of Addons, or tell TensorFlow Addons to
2827
prefer using Python implementations and not custom C++/CUDA ones. You can do that
29-
by changing the TF_ADDONS_PY_OPS flag
30-
either with the environment variable:
28+
by setting the enviornment variable `TF_ADDONS_PY_OPS=1`:
3129
```bash
3230
TF_ADDONS_PY_OPS=1 python my_script.py
3331
```
34-
or in your code, after your imports:
32+
or run `tfa.options.disable_custom_kernel()` in your code, after your imports:
3533
```python
3634
import tensorflow_addons as tfa
3735
import ...
3836
import ...
3937
40-
tfa.options.TF_ADDONS_PY_OPS = True
38+
tfa.options.disable_custom_kernel()
4139
```
4240
"""
4341

4442

4543
def warn_fallback(op_name):
46-
warning_msg = FALLBACK_WARNING_TEMPLATE.format(traceback.format_exc(), op_name)
44+
warning_msg = _FALLBACK_WARNING_TEMPLATE.format(traceback.format_exc(), op_name)
4745
warnings.warn(warning_msg, RuntimeWarning)
48-
global TF_ADDONS_PY_OPS
49-
TF_ADDONS_PY_OPS = True
46+
disable_custom_kernel()
47+
48+
49+
def enable_custom_kernel():
50+
"""Prefer custom C++/CUDA kernel to pure python operations.
51+
52+
Enable using custom C++/CUDA kernel instead of pure python operations.
53+
It has the same effect as setting environment variable `TF_ADDONS_PY_OPS=0`.
54+
"""
55+
global _TF_ADDONS_PY_OPS
56+
_TF_ADDONS_PY_OPS = False
57+
58+
59+
def disable_custom_kernel():
60+
"""Prefer pure python operations to custom C++/CUDA kernel.
61+
62+
Disable using custom C++/CUDA kernel instead of pure python operations.
63+
It has the same effect as setting environment variable `TF_ADDONS_PY_OPS=1`.
64+
"""
65+
global _TF_ADDONS_PY_OPS
66+
_TF_ADDONS_PY_OPS = True
67+
68+
69+
def is_custom_kernel_disabled():
70+
"""Return whether custom C++/CUDA kernel is disabled."""
71+
return _TF_ADDONS_PY_OPS

tensorflow_addons/seq2seq/beam_search_decoder.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -234,7 +234,7 @@ def gather_tree(
234234
Raises:
235235
InvalidArgumentError: if `parent_ids` contains an invalid index.
236236
"""
237-
if not options.TF_ADDONS_PY_OPS:
237+
if not options.is_custom_kernel_disabled():
238238
try:
239239
return _beam_search_so.ops.addons_gather_tree(
240240
step_ids, parent_ids, max_sequence_lengths, end_token

tensorflow_addons/seq2seq/tests/beam_search_ops_test.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -83,7 +83,7 @@ def test_bad_parent_values_on_gpu():
8383
expected_result = _transpose_batch_time(
8484
[[[2, -1, 2], [6, 5, 6], [7, 8, 9], [10, 10, 10]]]
8585
)
86-
if options.TF_ADDONS_PY_OPS:
86+
if options.is_custom_kernel_disabled():
8787
# The Python version has the same behavior on CPU and GPU.
8888
with pytest.raises(tf.errors.InvalidArgumentError, match="parent id"):
8989
_ = gather_tree(

tensorflow_addons/utils/resource_loader.py

+2-1
Original file line numberDiff line numberDiff line change
@@ -86,7 +86,8 @@ def display_warning_if_incompatible(self):
8686
"on Github. This is a known limitation."
8787
"\n\n"
8888
"It might help you to fallback to pure Python "
89-
"ops with TF_ADDONS_PY_OPS . To do that, see "
89+
"ops by setting environment variable `TF_ADDONS_PY_OPS=1` or using `tfa.options.disable_custom_kernel()` in your code. "
90+
"To do that, see "
9091
"https://github.com/tensorflow/addons#gpucpu-custom-ops "
9192
"\n\n"
9293
"You can also change the TensorFlow version installed on your system. "

tensorflow_addons/utils/test_utils.py

+7-4
Original file line numberDiff line numberDiff line change
@@ -95,14 +95,17 @@ def only_run_functions_eagerly(request):
9595

9696
@pytest.fixture(scope="function", params=["custom_ops", "py_ops"])
9797
def run_custom_and_py_ops(request):
98-
previous_py_ops_value = options.TF_ADDONS_PY_OPS
98+
previous_is_custom_kernel_disabled = options.is_custom_kernel_disabled()
9999
if request.param == "custom_ops":
100-
options.TF_ADDONS_PY_OPS = False
100+
options.enable_custom_kernel()
101101
elif request.param == "py_ops":
102-
options.TF_ADDONS_PY_OPS = True
102+
options.disable_custom_kernel()
103103

104104
def _restore_py_ops_value():
105-
options.TF_ADDONS_PY_OPS = previous_py_ops_value
105+
if previous_is_custom_kernel_disabled:
106+
options.disable_custom_kernel()
107+
else:
108+
options.enable_custom_kernel()
106109

107110
request.addfinalizer(_restore_py_ops_value)
108111

tools/docs/build_docs.py

+4-1
Original file line numberDiff line numberDiff line change
@@ -73,7 +73,10 @@ def main(argv):
7373
root_title=PROJECT_FULL_NAME,
7474
py_modules=[(PROJECT_SHORT_NAME, tfa)],
7575
code_url_prefix=code_url_prefix,
76-
private_map={"tfa": ["__version__", "utils", "version"]},
76+
private_map={
77+
"tfa": ["__version__", "utils", "version"],
78+
"tfa.options": ["warn_fallback"],
79+
},
7780
# These callbacks usually clean up a lot of aliases caused by internal imports.
7881
callbacks=[
7982
public_api.local_definitions_filter,

0 commit comments

Comments
 (0)