Skip to content

Commit 535ae42

Browse files
committed
Add one_hot
1 parent 0979f26 commit 535ae42

File tree

5 files changed

+201
-5
lines changed

5 files changed

+201
-5
lines changed

docs/api-reference.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515
expand_dims
1616
isclose
1717
kron
18+
one_hot
1819
nunique
1920
pad
2021
setdiff1d

src/array_api_extra/__init__.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
"""Extra array functions built on top of the array API standard."""
22

3-
from ._delegation import isclose, pad
3+
from ._delegation import isclose, one_hot, pad
44
from ._lib._at import at
55
from ._lib._funcs import (
66
apply_where,
@@ -32,6 +32,7 @@
3232
"kron",
3333
"lazy_apply",
3434
"nunique",
35+
"one_hot",
3536
"pad",
3637
"setdiff1d",
3738
"sinc",

src/array_api_extra/_delegation.py

Lines changed: 88 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -9,15 +9,17 @@
99
array_namespace,
1010
is_cupy_namespace,
1111
is_dask_namespace,
12+
is_jax_array,
1213
is_jax_namespace,
1314
is_numpy_namespace,
1415
is_pydata_sparse_namespace,
16+
is_torch_array,
1517
is_torch_namespace,
1618
)
1719
from ._lib._utils._helpers import asarrays
18-
from ._lib._utils._typing import Array
20+
from ._lib._utils._typing import Array, DType
1921

20-
__all__ = ["isclose", "pad"]
22+
__all__ = ["isclose", "one_hot", "pad"]
2123

2224

2325
def isclose(
@@ -112,6 +114,90 @@ def isclose(
112114
return _funcs.isclose(a, b, rtol=rtol, atol=atol, equal_nan=equal_nan, xp=xp)
113115

114116

def one_hot(
    x: Array,
    /,
    num_classes: int,
    *,
    dtype: DType | None = None,
    axis: int = -1,
    xp: ModuleType | None = None,
) -> Array:
    """
    One-hot encode the given indices.

    Each index in the input ``x`` is encoded as a vector of zeros of length
    ``num_classes`` with the element at the given index set to one.

    Parameters
    ----------
    x : array
        An array with integral dtype having shape ``batch_dims``.
    num_classes : int
        Number of classes in the one-hot dimension.
    dtype : DType, optional
        The dtype of the return value. Defaults to the default float dtype (usually
        float64).
    axis : int, optional
        Position in the expanded axes where the new axis is placed. Default: -1.
    xp : array_namespace, optional
        The standard-compatible namespace for `x`. Default: infer.

    Returns
    -------
    array
        An array having the same shape as `x` except for a new axis at the position
        given by `axis` having size `num_classes`.

        If ``x < 0`` or ``x >= num_classes``, then the result is undefined, may raise
        an exception, or may even cause a bad state. `x` is not checked.

    Examples
    --------
    >>> import numpy as np
    >>> import array_api_extra as xpx
    >>> xpx.one_hot(np.asarray([1, 2, 0]), 3)
    array([[0., 1., 0.],
           [0., 0., 1.],
           [1., 0., 0.]])
    """
    # Validate inputs.
    if xp is None:
        xp = array_namespace(x)
    if not xp.isdtype(x.dtype, "integral"):
        msg = "x must have an integral dtype."
        raise TypeError(msg)
    if dtype is None:
        dtype = xp.empty(()).dtype  # Default float dtype
    # Delegate where possible.
    if is_jax_namespace(xp):
        assert is_jax_array(x)
        from jax.nn import one_hot as jax_one_hot

        # jax.nn.one_hot handles dtype and axis natively; return directly.
        return jax_one_hot(x, num_classes, dtype=dtype, axis=axis)
    if is_torch_namespace(xp):
        assert is_torch_array(x)
        from torch.nn.functional import one_hot as torch_one_hot

        x = xp.astype(x, xp.int64)  # PyTorch only supports int64 here.
        try:
            out = torch_one_hot(x, num_classes)
        except RuntimeError as e:
            # Normalize PyTorch's out-of-range error to the documented IndexError.
            raise IndexError from e
        out = xp.astype(out, dtype)
    else:
        # Generic implementation for all other backends.
        out = _funcs.one_hot(
            x,
            num_classes,
            dtype=dtype,
            xp=xp,
            supports_fancy_indexing=is_numpy_namespace(xp),
            supports_array_indexing=is_dask_namespace(xp),
        )

    # Torch and the generic path always produce the class axis last;
    # move it into place when a different position was requested.
    if axis != -1:
        out = xp.moveaxis(out, -1, axis)
    return out

115201
def pad(
116202
x: Array,
117203
pad_width: int | tuple[int, int] | Sequence[tuple[int, int]],

src/array_api_extra/_lib/_funcs.py

Lines changed: 36 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -8,15 +8,19 @@
88

99
from ._at import at
1010
from ._utils import _compat, _helpers
11-
from ._utils._compat import array_namespace, is_dask_namespace, is_jax_array
11+
from ._utils._compat import (
12+
array_namespace,
13+
is_dask_namespace,
14+
is_jax_array,
15+
)
1216
from ._utils._helpers import (
1317
asarrays,
1418
capabilities,
1519
eager_shape,
1620
meta_namespace,
1721
ndindex,
1822
)
19-
from ._utils._typing import Array
23+
from ._utils._typing import Array, DType
2024

2125
__all__ = [
2226
"apply_where",
@@ -375,6 +379,36 @@ def cov(m: Array, /, *, xp: ModuleType | None = None) -> Array:
375379
return xp.squeeze(c, axis=axes)
376380

377381

def one_hot(
    x: Array,
    /,
    num_classes: int,
    *,
    supports_fancy_indexing: bool = False,
    supports_array_indexing: bool = False,
    dtype: DType,
    xp: ModuleType,
) -> Array:  # numpydoc ignore=PR01,RT01
    """See docstring in `array_api_extra._delegation.py`."""
    x_size = x.size
    if x_size is None:  # pragma: no cover
        # Lazy backends may not know the size; we need it to build the table.
        msg = "x must have a concrete size."
        raise TypeError(msg)
    # Build a flat (n, num_classes) table of zeros and set one element per row.
    # Use the validated x_size, not x.size, so the None check above is decisive.
    out = xp.zeros((x_size, num_classes), dtype=dtype)
    x_flattened = xp.reshape(x, (-1,))
    if supports_fancy_indexing:
        # Set all rows in one shot via integer-array (fancy) indexing.
        out = at(out)[xp.arange(x_size), x_flattened].set(1)
    else:
        # Fall back to one scalar update per row.
        for i in range(x_size):
            x_i = x_flattened[i]
            if not supports_array_indexing:
                # Backends that reject 0-d array indices need a Python int.
                x_i = int(x_i)
            out = at(out)[i, x_i].set(1)
    if x.ndim != 1:
        # Restore the batch dimensions, appending the class axis last.
        out = xp.reshape(out, (*x.shape, num_classes))
    return out

378412
def create_diagonal(
379413
x: Array, /, *, offset: int = 0, xp: ModuleType | None = None
380414
) -> Array:

tests/test_funcs.py

Lines changed: 74 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@
2121
isclose,
2222
kron,
2323
nunique,
24+
one_hot,
2425
pad,
2526
setdiff1d,
2627
sinc,
@@ -44,6 +45,7 @@
4445
lazy_xp_function(expand_dims)
4546
lazy_xp_function(kron)
4647
lazy_xp_function(nunique)
48+
lazy_xp_function(one_hot)
4749
lazy_xp_function(pad)
4850
# FIXME calls in1d which calls xp.unique_values without size
4951
lazy_xp_function(setdiff1d, jax_jit=False)
@@ -448,6 +450,78 @@ def test_xp(self, xp: ModuleType):
448450
)
449451

450452

@pytest.mark.skip_xp_backend(
    Backend.SPARSE, reason="read-only backend without .at support"
)
@pytest.mark.skip_xp_backend(
    Backend.DASK, reason="backend does not yet support indexed assignment"
)
class TestOneHot:
    @pytest.mark.parametrize("n_dim", range(4))
    @pytest.mark.parametrize("num_classes", [1, 3, 10])
    def test_dims_and_classes(self, xp: ModuleType, n_dim: int, num_classes: int):
        # Exercise 0-D through 3-D batch shapes against every class count.
        shape = tuple(range(2, 2 + n_dim))
        rng = np.random.default_rng(2347823)
        indices_np = rng.integers(num_classes, size=shape)
        indices = xp.asarray(indices_np)
        encoded = one_hot(indices, num_classes)
        assert encoded.shape == (*indices.shape, num_classes)
        # Element-by-element: the one-hot entry is 1 exactly at the index value.
        for *batch_idx, cls in ndindex(*shape, num_classes):
            pos = tuple(batch_idx)
            assert float(encoded[(*pos, cls)]) == (int(indices[pos]) == cls)

    def test_basic(self, xp: ModuleType):
        cases = [
            ([0, 1, 2], [[1.0, 0.0, 0.0], [0.0, 1.0, 0.0], [0.0, 0.0, 1.0]]),
            ([1, 2, 0], [[0.0, 1.0, 0.0], [0.0, 0.0, 1.0], [1.0, 0.0, 0.0]]),
        ]
        for indices, table in cases:
            xp_assert_equal(one_hot(xp.asarray(indices), 3), xp.asarray(table))

    @pytest.mark.skip_xp_backend(
        Backend.TORCH_GPU, reason="Puts Pytorch into a bad state."
    )
    def test_out_of_bound(self, xp: ModuleType):
        # Undefined behavior. Either return zero, or raise.
        try:
            actual = one_hot(xp.asarray([-1, 3]), 3)
        except IndexError:
            return
        xp_assert_equal(actual, xp.asarray([[0.0, 0.0, 0.0], [0.0, 0.0, 0.0]]))

    @pytest.mark.parametrize(
        "int_dtype",
        ["int8", "int16", "int32", "int64", "uint8", "uint16", "uint32", "uint64"],
    )
    def test_int_types(self, xp: ModuleType, int_dtype: str):
        # Every integral input dtype must be accepted.
        indices = xp.asarray([0, 1, 2], dtype=getattr(xp, int_dtype))
        identity = xp.asarray([[1.0, 0.0, 0.0], [0.0, 1.0, 0.0], [0.0, 0.0, 1.0]])
        xp_assert_equal(one_hot(indices, 3), identity)

    def test_custom_dtype(self, xp: ModuleType):
        result = one_hot(xp.asarray([0, 1, 2], dtype=xp.int32), 3, dtype=xp.bool)
        bool_identity = xp.asarray(
            [[True, False, False], [False, True, False], [False, False, True]]
        )
        xp_assert_equal(result, bool_identity)

    def test_axis(self, xp: ModuleType):
        # axis=0 and axis=-2 are equivalent for 2-D output: class axis first.
        transposed = xp.asarray(
            [[0.0, 1.0, 0.0], [0.0, 0.0, 1.0], [1.0, 0.0, 0.0]]
        ).T
        for ax in (0, -2):
            xp_assert_equal(one_hot(xp.asarray([1, 2, 0]), 3, axis=ax), transposed)

    def test_non_integer(self, xp: ModuleType):
        # Float indices are rejected up front.
        with pytest.raises(TypeError):
            _ = one_hot(xp.asarray([1.0]), 3)

451525
@pytest.mark.skip_xp_backend(
452526
Backend.SPARSE, reason="read-only backend without .at support"
453527
)

0 commit comments

Comments
 (0)