From ee04a6253af915aa4bf587b8a86d6dc8d8c09534 Mon Sep 17 00:00:00 2001 From: Xingyou Song Date: Sun, 7 Jan 2024 11:24:57 -0800 Subject: [PATCH] 1. OSS sequence_utils.py 2. OSS checkify.py 3. Create a more general logit restriction class. PiperOrigin-RevId: 596406264 --- optformer/inference/decoding.py | 66 +++++ optformer/inference/sequence_utils.py | 283 +++++++++++++++++++++ optformer/inference/sequence_utils_test.py | 225 ++++++++++++++++ optformer/validation/checkify.py | 54 ++++ optformer/validation/checkify_test.py | 36 +++ 5 files changed, 664 insertions(+) create mode 100644 optformer/inference/decoding.py create mode 100644 optformer/inference/sequence_utils.py create mode 100644 optformer/inference/sequence_utils_test.py create mode 100644 optformer/validation/checkify.py create mode 100644 optformer/validation/checkify_test.py diff --git a/optformer/inference/decoding.py b/optformer/inference/decoding.py new file mode 100644 index 0000000..c4b38a3 --- /dev/null +++ b/optformer/inference/decoding.py @@ -0,0 +1,66 @@ +# Copyright 2022 Google LLC. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Useful decoding-related classes.""" + +import abc +from typing import Optional + +import attrs +import jax.numpy as jnp +from jaxtyping import Array, Float, Int # pylint: disable=g-multiple-import,g-importing-member +from optformer.inference import sequence_utils as seq_utils +from t5x import decoding + + +@attrs.define +class IndexLogitRestrictor(decoding.LogitCallbackFn): + """Restricts logit values depending only on the index.""" + + def __call__( + self, + logits: Float[Array, "BS E"], + state: decoding.SamplingLoopState, + shift: Optional[Int[Array, "B"]] = None, + ) -> Float[Array, "BS E"]: + """Uses shifted current index to obtain logit mask index. + + Args: + logits: Decoder logits used for sampling vocabulary indices at a specific + time-slice. NOTE: `E >= V` assumed, where `E` is last-axis size (size of + embedding table). + state: State of the sampling loop. Most shapes of form [B*S, ...]. + shift: Shift on current index to determine mask index. If `None`, + `mask_index` is defaulted to `state.step`, equivalent to when `shift` is + the start of decoding block (usually the case). + + Returns: + Restricted logits on unmasked tokens. + """ + if shift is None: + mask_index = state.step # Scalar + else: + cur_index = jnp.reshape(state.cur_index, (shift.shape[0], -1)) # [B, S] + mask_index = jnp.reshape(cur_index - shift, (-1,)) # [B*S] + + # Will be broadcasted along the final axis of logits. + curr_mask: Float[Array, "BS V"] = self.logit_mask(mask_index) + # Pad w/ (E-V) zeros to deal w/ extra embeddings. + curr_mask: Float[Array, "BS E"] = seq_utils.rpad(curr_mask, logits) + + return (1.0 - curr_mask) * decoding.NEG_INF + curr_mask * logits + + @abc.abstractmethod + def logit_mask(self, index: jnp.ndarray) -> Float[Array, "BS V"]: + """Returns logit mask at index.""" diff --git a/optformer/inference/sequence_utils.py b/optformer/inference/sequence_utils.py new file mode 100644 index 0000000..7f00c64 --- /dev/null +++ b/optformer/inference/sequence_utils.py @@ -0,0 +1,283 @@ +# Copyright 2022 Google LLC. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Utilities for manipulating sequences. + +Unless otherwise stated, sequence-manipulation functions below should be assumed +to parallelize over the final axis (interpreted as a sequence). +""" + +from typing import Dict, Sequence, Union +import jax +from jax.experimental import checkify +import jax.numpy as jnp +from optformer.validation import checkify as _checkify + + +def count_not_from( + seq: jnp.ndarray, not_from: Sequence[int] = (0, 1) +) -> jnp.ndarray: + """Counts the number of elements which are NOT part of `not_from`. + + Useful for collecting the initial index of a token sequence. + + Args: + seq: Token (int) sequence to be filter-counted. Shape [..., L]. + not_from: Token IDs to ignore. Defaulted to (BOS, EOS) token IDs. + + Returns: + Filtered count. Last axis is reduce summed. Shape [...]. + """ + + where_cond = False + for ignore_int in not_from: + where_cond = (seq == ignore_int) | where_cond + return jnp.where(where_cond, 0, 1).sum(axis=-1) + + +def reduce_eq(ind: jnp.ndarray) -> jnp.ndarray: + """Validates if the last axis contains all equal values and reduces. + + e.g. [[1, 1, 1, 1], [2, 2, 2, 2]] -> [1, 2] + + Useful for reducing index tensors which have repeated values. + + Args: + ind: Possible (int) sequence to be reduced. Shape [..., S]. + + Returns: + Reduced indices of shape [...] if the final axis has repeated values. + Otherwise raises checkify error. + """ + if _checkify.enabled(): + all_same = jnp.all(ind == jnp.expand_dims(ind[..., 0], -1), axis=-1) + checkify.check( + jnp.all(ind == jnp.expand_dims(ind[..., 0], -1)), + msg=( + '`seq` must have repeated values. Offending sequence: ' + f'{ind[jnp.argmin(all_same)]}' + ), + ) + return ind[..., 0] + + +def shift_right( + seq: jnp.ndarray, insert_left: Sequence[int] = (0,) +) -> jnp.ndarray: + """Shifts sequence to the right, and inserts new tokens on the left. + + Useful for taking the output of model `[x,y,z,0,0,...]` and turning it + back into a proper input `[0,x,y,z,0,...]`. + + Args: + seq: Token (int) sequence to be filter-counted. Shape [..., L]. + insert_left: Token IDs to insert on the left. Defaulted to BOS token. + + Returns: + Shifted sequence. Shape [..., L]. + """ + + shifted_seq = jnp.roll(seq, shift=len(insert_left), axis=-1) + return shifted_seq.at[..., 0 : len(insert_left)].set(insert_left) + + +def broadcast_batch( + batch: Dict[str, jnp.ndarray], sizes: Sequence[int] +) -> Dict[str, jnp.ndarray]: + """Broadcasts all arrays in a batch. + + Args: + batch: Dictionary of sequences. Shape [...]. + sizes: New axes introduced. + + Returns: + Batch with newly broadcasted elements. Shape `sizes + [...]`. + """ + + return {k: jax.lax.broadcast(v, sizes) for k, v in batch.items()} + + +def find(seq: jnp.ndarray, elem: int, *, not_found: int = -1) -> jnp.ndarray: + """Finds first occurrence index of `elem` in a sequence. + + Args: + seq: Token (int) sequence. Shape [..., L]. + elem: Element value to find. + not_found: Value to return if elem is not found. Defaulted to -1, to output + a "special value" token and no-op common use-cases zero-ing all elements + to the right of the index. + + Returns: + First token index whose value is `elem`, else `not_found`. Shape [...]. + """ + + bool_arr = jnp.where(seq == elem, 1, 0) + maybe_index = jnp.argmax(bool_arr, axis=-1) + + not_found_cond = jnp.sum(bool_arr, axis=-1) == 0 + return jnp.where(not_found_cond, not_found, maybe_index) + + +def rfind(seq: jnp.ndarray, elem: int, *, not_found: int = -1) -> jnp.ndarray: + """Same format as `find`, but for last occurrence. + + Useful for finding the last location of a special token (e.g. separator + token). + + Args: + seq: See `find` + elem: See `find` + not_found: See `find` + + Returns: + Last token index whose value is `elem`, else `not_found`. Shape [...]. + """ + + bool_arr = jnp.where(seq == elem, 1, 0) + flipped_bool_arr = jnp.flip(bool_arr, axis=-1) + maybe_index = bool_arr.shape[-1] - 1 - jnp.argmax(flipped_bool_arr, axis=-1) + + not_found_cond = jnp.sum(bool_arr, axis=-1) == 0 + return jnp.where(not_found_cond, not_found, maybe_index) + + +def append_to_output( + seq: jnp.ndarray, elems: Sequence[int], *, bos: int = 0 +) -> jnp.ndarray: + """Appends elems to a decoding output sequence. + + Exact location starts from first occurrence of the BOS token. Useful for + appending values to decoder output sequences. + + Args: + seq: Token (int) sequence. Shape [..., L]. + elems: Elements to append. + bos: BOS token ID to determine initial index to append. + + Returns: + Sequence w/ elements appended, or no-op if appending will overwrite non-bos + tokens. Shape [..., L]. + """ + # TODO: Raise error if `seq` doesn't look lke a decode output. + # TODO: Implement. + raise NotImplementedError + + +def dynamic_slice_broadcast( + operand: jax.Array, slice_indices: jax.Array, slice_size: int +) -> jax.Array: + """Broadcasting version of jax.lax.dynamic_slice_in_dim.""" + fn = jax.lax.dynamic_slice_in_dim + for i in range(operand.ndim - slice_indices.ndim - 1): + fn = jax.vmap(fn, in_axes=(i, None, None), out_axes=i) + for i in range(slice_indices.ndim): + fn = jax.vmap( + fn, + in_axes=(i + operand.ndim - slice_indices.ndim - 1, i, None), + out_axes=i + operand.ndim - slice_indices.ndim - 1, + ) + return fn(operand, slice_indices, slice_size) + + +def rpad(seq: jnp.ndarray, target: jnp.ndarray) -> jnp.ndarray: + """Right-pads sequence with 0's to match w/ target sequence. + + Args: + seq: Token (int) sequence. Shape [..., L]. + target: Token (int) sequence to match on the inner dimension. Shape [..., + L'] where the outer dimensions can be different from seq's. + + Returns: + Padded sequence. Shape [..., L']. + """ + noop_paddings = [(0, 0) for _ in range(len(seq.shape) - 1)] + paddings = noop_paddings + [(0, target.shape[-1] - seq.shape[-1])] + return jnp.pad(seq, paddings, 'constant') + + +def slice_update( + seq: jnp.ndarray, start: Union[int, jnp.ndarray], elems: Sequence[int] +) -> jnp.ndarray: + """Jittable version of `seq[..., start:start+len(elems)].set(elems)`.""" + # TODO: Finish case when `start` is non-scalar. + for i, elem in enumerate(elems): + seq = seq.at[..., start + i].set(elem) + return seq + + +def value_mask(seq: jnp.ndarray, masked_values: Sequence[int]) -> jnp.ndarray: + """Computes value-matched mask from sequence. + + Ex: If `masked_values` are * and |, then: + seq: [*, |, x, y] + mask: [0, 0, 1, 1] + + Args: + seq: Token (int) sequence. Shape [..., L]. + masked_values: Values to mask out. + + Returns: + Mask of shape [..., L]. + """ + + mask = jnp.full(seq.shape, True, dtype=bool) + + for v in masked_values: + mask = jnp.logical_and(mask, jnp.not_equal(seq, v)) + + return mask + + +# pyformat: disable +def between_mask(seq: jnp.ndarray, left: int, right: int) -> jnp.ndarray: + """Computes the mask for a sequence given delimiters. + + Ex: If left/right delimiters are '*' and '|', then + seq: [*, w, x, y, |, *, z, |] + mask: [0, 1, 1, 1, 0, 0, 1, 0] + + Args: + seq: Token (int) sequence. Shape [..., L]. + left: Left delimiter. + right: Right delimiter. + + Returns: + Mask of shape [..., L]. + """ + + left_match = jnp.equal(seq, left) + right_match = jnp.equal(seq, right) + + if _checkify.enabled(): + # Check if count(left) == count(right) + left_count = jnp.sum(left_match, axis=-1) + right_count = jnp.sum(right_match, axis=-1) + eq_count = jnp.all(left_count == right_count) + checkify.check(eq_count, '`seq` has imbalanced delimiters.') + + # If our example tensor is [x, *, y, |], then example outputs are commented: + left_cs = jnp.cumsum(left_match, axis=-1) # [0, 1, 1, 1] + right_cs = jnp.cumsum(right_match, axis=-1) # [0, 0, 0, 1] + left_cs_slice = left_cs[..., :-1] # [0, 1, 1] + zeros = jnp.zeros(shape=list(left_cs_slice.shape[:-1]) + [1], dtype=jnp.int32) + shifted_left_cs = jnp.concatenate((zeros, left_cs_slice), axis=-1) # [0, 0, 1, 1] # pylint: disable=line-too-long + mask = shifted_left_cs - right_cs # [0, 0, 1, 0] + + if _checkify.enabled(): + # Check if there are no -1's (from wrong right -> left orderings). + all_ones_and_zeros = jnp.all((mask == 0) | (mask == 1)) + checkify.check(all_ones_and_zeros, '`seq` has imbalanced delimiters.') + + return mask.astype(jnp.bool_) +# pyformat: enable diff --git a/optformer/inference/sequence_utils_test.py b/optformer/inference/sequence_utils_test.py new file mode 100644 index 0000000..241ca94 --- /dev/null +++ b/optformer/inference/sequence_utils_test.py @@ -0,0 +1,225 @@ +# Copyright 2022 Google LLC. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import jax +from jax.experimental import checkify +import jax.numpy as jnp +import numpy as np +from optformer.inference import sequence_utils +from optformer.validation import checkify as _checkify +from absl.testing import absltest +from absl.testing import parameterized + + +class SequenceUtilsTest(parameterized.TestCase): + + def test_count_not_from(self): + x = jnp.array([[0, 4, 5, 1, 0, 0], [0, 5, 7, 8, 1, 0]]) + expected = [2, 3] + + jit_count_not_from = jax.jit(sequence_utils.count_not_from) + out = jit_count_not_from(x, not_from=(0, 1)) + np.testing.assert_array_equal(expected, out) + + def test_shift_right(self): + x = jnp.array([[1, 2, 3, 0, 0], [0, 0, 3, 4, 5]]) + expected = [[42, 100, 1, 2, 3], [42, 100, 0, 0, 3]] + + jit_shift_right = jax.jit(sequence_utils.shift_right) + out = jit_shift_right(x, insert_left=(42, 100)) + np.testing.assert_array_equal(expected, out) + + def test_find(self): + x = jnp.array([[0, 0, 42, 42, 42], [42, 42, 0, 0, 0], [0, 0, 0, 0, 0]]) + expected = [2, 0, -1] + + jit_find = jax.jit(sequence_utils.find) + out = jit_find(x, elem=42, not_found=-1) + np.testing.assert_array_equal(expected, out) + + def test_rfind(self): + x = jnp.array([[42, 42, 42, 0, 0], [0, 0, 0, 42, 42], [0, 0, 0, 0, 0]]) + expected = [2, 4, -1] + + jit_rfind = jax.jit(sequence_utils.rfind) + out = jit_rfind(x, elem=42, not_found=-1) + np.testing.assert_array_equal(expected, out) + + @absltest.skip("This might require dynamic_update?") + def test_append_to_output(self): + x = jnp.array([[1, 2, 0, 0], [1, 2, 3, 0], [0, 0, 0, 0]]) + expected = jnp.array([[1, 2, 42, 43], [1, 2, 3, 0], [42, 43, 0, 0]]) + + jit_append_to_output = jax.jit(sequence_utils.append_to_output) + out = jit_append_to_output(x, elems=[42, 43]) + np.testing.assert_array_equal(expected, out) + + def test_rpad(self): + x = jnp.array([[1, 2], [3, 4]]) + target = jnp.ones([3, 3]) + expected = jnp.array([[1, 2, 0], [3, 4, 0]]) + + jit_rpad = jax.jit(sequence_utils.rpad) + out = jit_rpad(x, target) + np.testing.assert_array_equal(expected, out) + + def test_value_mask(self): + x = jnp.array([[1, 2, 3], [4, 5, 6]]) + + jit_value_mask = jax.jit(sequence_utils.value_mask) + out = jit_value_mask(x, masked_values=[2, 3]) + expected = jnp.array([[1, 0, 0], [1, 1, 1]]) + np.testing.assert_array_equal(expected, out) + + def test_slice_update(self): + x = jnp.array([[1, 2, 3], [4, 5, 6]]) + jit_slice_update = jax.jit(sequence_utils.slice_update) + + out = jit_slice_update(x, start=0, elems=[42, 43]) + expected = jnp.array([[42, 43, 3], [42, 43, 6]]) + np.testing.assert_array_equal(expected, out) + + +class ReduceEqTest(parameterized.TestCase): + """Tests for `reduce_eq`.""" + + def setUp(self): + super().setUp() + self._reduce_eq = _checkify.check_and_jit(sequence_utils.reduce_eq) + _checkify.enable_checks(True) + + def tearDown(self): + super().tearDown() + _checkify.enable_checks(False) + + def test_raise_error(self): + x = jnp.array([1, 1, 1, 1, 1, 2]) + with self.assertRaises(checkify.JaxRuntimeError): + _ = self._reduce_eq(x) + + def test_good_array(self): + x = jnp.array([[1, 1, 1], [2, 2, 2], [3, 3, 3]]) + out = self._reduce_eq(x) + expected = jnp.array([1, 2, 3]) + np.testing.assert_array_equal(expected, out) + + +class BetweenMaskTest(absltest.TestCase): + + def test_good_input(self): + x = jnp.array([[-1, 42, 42, 1, -1, 42, 1], [42, 42, 42, 42, 42, 42, 42]]) + jit_between_mask = _checkify.check_and_jit(sequence_utils.between_mask) + + out = jit_between_mask(x, left=-1, right=1) + expected = jnp.array([[0, 1, 1, 0, 0, 1, 0], [0, 0, 0, 0, 0, 0, 0]]) + np.testing.assert_array_equal(expected, out) + + def test_checkify_imbalanced_delimiters(self): + _checkify.enable_checks(True) + jit_between_mask = _checkify.check_and_jit(sequence_utils.between_mask) + x = jnp.array([[-1, 42, 42, 1, -1, 42, 0], [42, 42, 42, 42, 42, 42, 42]]) + with self.assertRaises(checkify.JaxRuntimeError): + _ = jit_between_mask(x, left=-1, right=1) + _checkify.enable_checks(False) + + def test_checkify_misordered_delimiters(self): + _checkify.enable_checks(True) + jit_between_mask = _checkify.check_and_jit(sequence_utils.between_mask) + x = jnp.array([[-1, 42, 42, 1, 1, 42, -1], [42, 42, 42, 42, 42, 42, 42]]) + with self.assertRaises(checkify.JaxRuntimeError): + _ = jit_between_mask(x, left=-1, right=1) + _checkify.enable_checks(False) + + +def _create_readable_array(shape) -> jax.Array: + """Returns an array where array[i,j,k] = ijk.""" + x = np.zeros(shape) + for e, k in enumerate(reversed(shape)): + if k >= 10: + raise ValueError("Cannot handle digits >= 10.") + for i in range(k): + index = [slice(None)] * len(x.shape) + index[x.ndim - e - 1] = slice(i, i + 1) + x[tuple(index)] += (i) * 10**e + + return jnp.asarray(x, dtype=jnp.int32) + + +class DynamicSliceBroadcastTest(parameterized.TestCase): + + def test_basic_1d(self): + actual = sequence_utils.dynamic_slice_broadcast( + jnp.array([0, 1, 2, 3]), jnp.array(1), 3 + ) + np.testing.assert_array_equal([1, 2, 3], actual) + + def test_x3d_index1d(self): + x = _create_readable_array([2, 3, 5]) + fn = jax.jit( + sequence_utils.dynamic_slice_broadcast, + static_argnums=[2], + ) + actual = fn(x, jnp.array([0, 1, 1]), 2) + expected = jnp.array( + [ + # 0, 1, 1 -> we should see [0, 1], [1, 2], [1, 2] + # as the last digit. + [[0, 1], [1, 2], [1, 2]], + [[0, 1], [1, 2], [1, 2]], + ], + ) + self.assertSequenceEqual(actual.shape, (2, 3, 2)) + np.testing.assert_array_equal(expected, actual % 10) + + def test_x4d_index1d(self): + x = _create_readable_array([5, 2, 3, 4]) + fn = jax.jit( + sequence_utils.dynamic_slice_broadcast, + static_argnums=[2], + ) + actual = fn(x, jnp.array([0, 1, 1]), 2) + self.assertSequenceEqual(actual.shape, (5, 2, 3, 2)) + expected = jnp.array( + [ + # 0, 1, 1 -> we should see [0, 1], [1, 2], [1, 2] as + # the last digit. + [[[0, 1], [1, 2], [1, 2]], [[0, 1], [1, 2], [1, 2]]], + [[[0, 1], [1, 2], [1, 2]], [[0, 1], [1, 2], [1, 2]]], + [[[0, 1], [1, 2], [1, 2]], [[0, 1], [1, 2], [1, 2]]], + [[[0, 1], [1, 2], [1, 2]], [[0, 1], [1, 2], [1, 2]]], + [[[0, 1], [1, 2], [1, 2]], [[0, 1], [1, 2], [1, 2]]], + ], + ) + np.testing.assert_array_equal(expected, actual % 10) + + def test_x4d_index2d(self): + x = _create_readable_array([5, 2, 3, 4]) + fn = jax.jit( + sequence_utils.dynamic_slice_broadcast, + static_argnums=[2], + ) + actual = fn(x, jnp.array([[0, 1, 1], [1, 2, 1]]), 2) + self.assertSequenceEqual(actual.shape, (5, 2, 3, 2)) + expected = jnp.array([ + [[[0, 1], [1, 2], [1, 2]], [[1, 2], [2, 3], [1, 2]]], + [[[0, 1], [1, 2], [1, 2]], [[1, 2], [2, 3], [1, 2]]], + [[[0, 1], [1, 2], [1, 2]], [[1, 2], [2, 3], [1, 2]]], + [[[0, 1], [1, 2], [1, 2]], [[1, 2], [2, 3], [1, 2]]], + [[[0, 1], [1, 2], [1, 2]], [[1, 2], [2, 3], [1, 2]]], + ]) + np.testing.assert_array_equal(expected, actual % 10) + + +if __name__ == "__main__": + absltest.main() diff --git a/optformer/validation/checkify.py b/optformer/validation/checkify.py new file mode 100644 index 0000000..d90b357 --- /dev/null +++ b/optformer/validation/checkify.py @@ -0,0 +1,54 @@ +# Copyright 2022 Google LLC. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Checkify module.""" + +from typing import Callable, TypeVar +import jax +from jax.experimental import checkify + + +_ENABLED: bool = False # global variable + + +def enable_checks(enable: bool) -> None: + global _ENABLED + _ENABLED = enable + + +def enabled() -> bool: + return _ENABLED + + +_R = TypeVar('_R') + + +def _jittable_inner( + fn: Callable[..., _R], *args, **kwargs +) -> tuple[checkify.Error, _R]: + """This function is used to avoid retracing.""" + return checkify.checkify(fn)(*args, **kwargs) + + +def check_and_jit(fn: Callable[..., _R]) -> Callable[..., _R]: + """Throws checkify errors while preserving the function signature.""" + + def inner(*args, **kwargs) -> _R: + err, result = jax.jit(_jittable_inner, static_argnums=[0])( + fn, *args, **kwargs + ) + err.throw() + return result + + return inner diff --git a/optformer/validation/checkify_test.py b/optformer/validation/checkify_test.py new file mode 100644 index 0000000..64ac2bd --- /dev/null +++ b/optformer/validation/checkify_test.py @@ -0,0 +1,36 @@ +# Copyright 2022 Google LLC. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from optformer.validation import checkify as _checkify +from absl.testing import absltest + + +class CheckifyTest(absltest.TestCase): + + def setUp(self): + super().setUp() + self.trace_count = 0 + + def test_traces_only_once(self): + def f(x): + self.trace_count += 1 + return x + 1 + + _checkify.check_and_jit(f)(0.1) + _checkify.check_and_jit(f)(0.1) + self.assertEqual(1, self.trace_count) + + +if __name__ == "__main__": + absltest.main()