Skip to content

Commit

Permalink
Design 2->4
Browse files Browse the repository at this point in the history
  • Loading branch information
crusaderky committed Jan 24, 2025
1 parent 6e3c824 commit 2900169
Showing 1 changed file with 20 additions and 11 deletions.
31 changes: 20 additions & 11 deletions src/array_api_extra/_lib/_funcs.py
Original file line number Diff line number Diff line change
Expand Up @@ -544,6 +544,7 @@ def setdiff1d(
/,
*,
assume_unique: bool = False,
size: int | None = None,
fill_value: object | None = None,
xp: ModuleType | None = None,
) -> Array:
Expand All @@ -561,11 +562,16 @@ def setdiff1d(
assume_unique : bool
If ``True``, the input arrays are both assumed to be unique, which
can speed up the calculation. Default is ``False``.
fill_value : object, optional
Pad the output array with this value.
size : int, optional
The size of the output array. This is exclusively used inside the JAX JIT, and
only for as long as JAX does not support arrays of unknown size inside it. In
all other cases, it is disregarded.
Returned elements will be clipped if they are more than size, and padded with
`fill_value` if they are less. Default: raise if inside ``jax.jit``.
This is exclusively used for JAX arrays when running inside ``jax.jit``,
where all array shapes need to be known in advance.
fill_value : object, optional
Pad the output array with this value. This is exclusively used for JAX arrays
when running inside ``jax.jit``. Default: 0.
xp : array_namespace, optional
The standard-compatible namespace for `x1` and `x2`. Default: infer.
Expand Down Expand Up @@ -630,7 +636,7 @@ def _dask_impl(x1: Array, x2: Array) -> Array: # numpydoc ignore=PR01,RT01
return x1 if assume_unique else xp.unique_values(x1)

Check warning on line 636 in src/array_api_extra/_lib/_funcs.py

View check run for this annotation

Codecov / codecov/patch

src/array_api_extra/_lib/_funcs.py#L636

Added line #L636 was not covered by tests

def _jax_jit_impl(
x1: Array, x2: Array, fill_value: object | None
x1: Array, x2: Array, size: int | None, fill_value: object | None
) -> Array: # numpydoc ignore=PR01,RT01
"""
JAX implementation inside jax.jit.
Expand All @@ -639,9 +645,9 @@ def _jax_jit_impl(
and not being able to filter by a boolean mask.
Returns array the same size as x1, padded with fill_value.
"""
# unique_values inside jax.jit is not supported unless it's got a fixed size
mask = _x1_not_in_x2(x1, x2)

if size is None:
msg = "`size` is mandatory when running inside `jax.jit`."
raise ValueError(msg)
if fill_value is None:
fill_value = xp.zeros((), dtype=x1.dtype)

Check warning on line 652 in src/array_api_extra/_lib/_funcs.py

View check run for this annotation

Codecov / codecov/patch

src/array_api_extra/_lib/_funcs.py#L648-L652

Added lines #L648 - L652 were not covered by tests
else:
Expand All @@ -650,9 +656,12 @@ def _jax_jit_impl(
msg = "`fill_value` must be a scalar."
raise ValueError(msg)

Check warning on line 657 in src/array_api_extra/_lib/_funcs.py

View check run for this annotation

Codecov / codecov/patch

src/array_api_extra/_lib/_funcs.py#L654-L657

Added lines #L654 - L657 were not covered by tests

# unique_values inside jax.jit is not supported unless it's got a fixed size
mask = _x1_not_in_x2(x1, x2)
x1 = xp.where(mask, x1, fill_value)

Check warning on line 661 in src/array_api_extra/_lib/_funcs.py

View check run for this annotation

Codecov / codecov/patch

src/array_api_extra/_lib/_funcs.py#L660-L661

Added lines #L660 - L661 were not covered by tests
# Note: jnp.unique_values sorts
return xp.unique_values(x1, size=x1.size, fill_value=fill_value)
# Move fill_value to the right
x1 = xp.take(x1, xp.argsort(~mask, stable=True))
x1 = xp.unique_values(x1, size=x1.size, fill_value=fill_value)

Check warning on line 664 in src/array_api_extra/_lib/_funcs.py

View check run for this annotation

Codecov / codecov/patch

src/array_api_extra/_lib/_funcs.py#L663-L664

Added lines #L663 - L664 were not covered by tests

if is_dask_namespace(xp):
return _dask_impl(x1, x2)

Check warning on line 667 in src/array_api_extra/_lib/_funcs.py

View check run for this annotation

Codecov / codecov/patch

src/array_api_extra/_lib/_funcs.py#L667

Added line #L667 was not covered by tests
Expand All @@ -666,7 +675,7 @@ def _jax_jit_impl(
jax.errors.ConcretizationTypeError,
jax.errors.NonConcreteBooleanIndexError,
):
return _jax_jit_impl(x1, x2, fill_value) # inside jax.jit
return _jax_jit_impl(x1, x2, size, fill_value) # inside jax.jit

Check warning on line 678 in src/array_api_extra/_lib/_funcs.py

View check run for this annotation

Codecov / codecov/patch

src/array_api_extra/_lib/_funcs.py#L678

Added line #L678 was not covered by tests

return _generic_impl(x1, x2)

Expand Down

0 comments on commit 2900169

Please sign in to comment.