|
17 | 17 | import abc
|
18 | 18 |
|
19 | 19 | import tensorflow as tf
|
20 |
| -from tensorflow_addons.utils.types import Number |
| 20 | +from tensorflow_addons.utils.types import TensorLike |
21 | 21 | from typeguard import typechecked
|
22 | 22 | from typing import Any, Optional, Tuple, Union
|
23 | 23 |
|
@@ -144,7 +144,7 @@ def __init__(
|
144 | 144 | self,
|
145 | 145 | output_time_major: bool = False,
|
146 | 146 | impute_finished: bool = False,
|
147 |
| - maximum_iterations: Optional[Number] = None, |
| 147 | + maximum_iterations: Optional[TensorLike] = None, |
148 | 148 | parallel_iterations: int = 32,
|
149 | 149 | swap_memory: bool = False,
|
150 | 150 | **kwargs
|
@@ -256,11 +256,12 @@ def tracks_own_finished(self):
|
256 | 256 | # TODO(scottzhu): Add build/get_config/from_config and other layer methods.
|
257 | 257 |
|
258 | 258 |
|
| 259 | +@typechecked |
259 | 260 | def dynamic_decode(
|
260 | 261 | decoder: Union[Decoder, BaseDecoder],
|
261 | 262 | output_time_major: bool = False,
|
262 | 263 | impute_finished: bool = False,
|
263 |
| - maximum_iterations: Optional[Number] = None, |
| 264 | + maximum_iterations: Optional[TensorLike] = None, |
264 | 265 | parallel_iterations: int = 32,
|
265 | 266 | swap_memory: bool = False,
|
266 | 267 | training: Optional[bool] = None,
|
@@ -299,14 +300,8 @@ def dynamic_decode(
|
299 | 300 | `(final_outputs, final_state, final_sequence_lengths)`.
|
300 | 301 |
|
301 | 302 | Raises:
|
302 |
| - TypeError: if `decoder` is not an instance of `Decoder`. |
303 | 303 | ValueError: if `maximum_iterations` is provided but is not a scalar.
|
304 | 304 | """
|
305 |
| - if not isinstance(decoder, (Decoder, BaseDecoder)): |
306 |
| - raise TypeError( |
307 |
| - "Expected decoder to be type Decoder, but saw: %s" % type(decoder) |
308 |
| - ) |
309 |
| - |
310 | 305 | with tf.compat.v1.variable_scope(scope, "decoder") as varscope:
|
311 | 306 | # Determine context types.
|
312 | 307 | ctxt = tf.compat.v1.get_default_graph()._get_control_flow_context()
|
|
0 commit comments