diff --git a/chirp/models/hubert.py b/chirp/models/hubert.py index 1e1dcfbf..6fd8b1ca 100644 --- a/chirp/models/hubert.py +++ b/chirp/models/hubert.py @@ -77,7 +77,7 @@ def compute_mask_indices( num_mask = mask_prob * sz / jnp.array(mask_length, float) + rounding_offset num_mask = jnp.full(bsz, num_mask).astype(int) max_masks = sz - mask_length + 1 - num_mask = jnp.clip(num_mask, a_min=min_masks, a_max=max_masks) + num_mask = jnp.clip(num_mask, min=min_masks, max=max_masks) # First, sample a set of start indices for the max possible number of masks. # Do this sampling separately for each batch sample, to allow `replace`=False.