Skip to content

Commit 8083080

Browse files
authored
Accept passing maximum_iterations as a Tensor (#1751)
Also add `@typechecked` to dynamic_decode for consistency.
1 parent 4979b46 commit 8083080

File tree

2 files changed

+7
-10
lines changed

2 files changed

+7
-10
lines changed

tensorflow_addons/seq2seq/decoder.py

+4-9
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@
1717
import abc
1818

1919
import tensorflow as tf
20-
from tensorflow_addons.utils.types import Number
20+
from tensorflow_addons.utils.types import TensorLike
2121
from typeguard import typechecked
2222
from typing import Any, Optional, Tuple, Union
2323

@@ -144,7 +144,7 @@ def __init__(
144144
self,
145145
output_time_major: bool = False,
146146
impute_finished: bool = False,
147-
maximum_iterations: Optional[Number] = None,
147+
maximum_iterations: Optional[TensorLike] = None,
148148
parallel_iterations: int = 32,
149149
swap_memory: bool = False,
150150
**kwargs
@@ -256,11 +256,12 @@ def tracks_own_finished(self):
256256
# TODO(scottzhu): Add build/get_config/from_config and other layer methods.
257257

258258

259+
@typechecked
259260
def dynamic_decode(
260261
decoder: Union[Decoder, BaseDecoder],
261262
output_time_major: bool = False,
262263
impute_finished: bool = False,
263-
maximum_iterations: Optional[Number] = None,
264+
maximum_iterations: Optional[TensorLike] = None,
264265
parallel_iterations: int = 32,
265266
swap_memory: bool = False,
266267
training: Optional[bool] = None,
@@ -299,14 +300,8 @@ def dynamic_decode(
299300
`(final_outputs, final_state, final_sequence_lengths)`.
300301
301302
Raises:
302-
TypeError: if `decoder` is not an instance of `Decoder`.
303303
ValueError: if `maximum_iterations` is provided but is not a scalar.
304304
"""
305-
if not isinstance(decoder, (Decoder, BaseDecoder)):
306-
raise TypeError(
307-
"Expected decoder to be type Decoder, but saw: %s" % type(decoder)
308-
)
309-
310305
with tf.compat.v1.variable_scope(scope, "decoder") as varscope:
311306
# Determine context types.
312307
ctxt = tf.compat.v1.get_default_graph()._get_control_flow_context()

tensorflow_addons/seq2seq/tests/decoder_test.py

+3-1
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,9 @@
2323

2424

2525
@pytest.mark.usefixtures("maybe_run_functions_eagerly")
26-
@pytest.mark.parametrize("maximum_iterations", [None, 0, 1])
26+
@pytest.mark.parametrize(
27+
"maximum_iterations", [None, 0, 1, tf.constant(1, dtype=tf.int32)]
28+
)
2729
@pytest.mark.parametrize("time_major", [True, False])
2830
def test_dynamic_decode_rnn(time_major, maximum_iterations):
2931

0 commit comments

Comments
 (0)