Skip to content

Commit

Permalink
jax.numpy.clip: update use of deprecated arguments.
Browse files Browse the repository at this point in the history
- a is now positional-only
- a_min is now min
- a_max is now max

The old argument names have been deprecated since JAX v0.4.27.

PiperOrigin-RevId: 715321661
  • Loading branch information
Jake VanderPlas authored and copybara-github committed Jan 14, 2025
1 parent c05cf3a commit ac6e047
Showing 1 changed file with 1 addition and 1 deletion.
2 changes: 1 addition & 1 deletion chirp/models/hubert.py
Original file line number Diff line number Diff line change
Expand Up @@ -77,7 +77,7 @@ def compute_mask_indices(
num_mask = mask_prob * sz / jnp.array(mask_length, float) + rounding_offset
num_mask = jnp.full(bsz, num_mask).astype(int)
max_masks = sz - mask_length + 1
num_mask = jnp.clip(num_mask, a_min=min_masks, a_max=max_masks)
num_mask = jnp.clip(num_mask, min=min_masks, max=max_masks)

# First, sample a set of start indices for the max possible number of masks.
# Do this sampling separately for each batch sample, to allow `replace`=False.
Expand Down

0 comments on commit ac6e047

Please sign in to comment.