Skip to content

Commit fa0e4a4

Browse files
committed
Add one_hot
1 parent c1adc04 commit fa0e4a4

File tree

5 files changed

+194
-3
lines changed

5 files changed

+194
-3
lines changed

docs/api-reference.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717
isclose
1818
kron
1919
nunique
20+
one_hot
2021
pad
2122
setdiff1d
2223
sinc

src/array_api_extra/__init__.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
"""Extra array functions built on top of the array API standard."""
22

3-
from ._delegation import isclose, pad
3+
from ._delegation import isclose, one_hot, pad
44
from ._lib._at import at
55
from ._lib._funcs import (
66
apply_where,
@@ -34,6 +34,7 @@
3434
"kron",
3535
"lazy_apply",
3636
"nunique",
37+
"one_hot",
3738
"pad",
3839
"setdiff1d",
3940
"sinc",

src/array_api_extra/_delegation.py

Lines changed: 80 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -14,10 +14,11 @@
1414
is_pydata_sparse_namespace,
1515
is_torch_namespace,
1616
)
17+
from ._lib._utils._compat import device as get_device
1718
from ._lib._utils._helpers import asarrays
18-
from ._lib._utils._typing import Array
19+
from ._lib._utils._typing import Array, DType
1920

20-
__all__ = ["isclose", "pad"]
21+
__all__ = ["isclose", "one_hot", "pad"]
2122

2223

2324
def isclose(
@@ -112,6 +113,83 @@ def isclose(
112113
return _funcs.isclose(a, b, rtol=rtol, atol=atol, equal_nan=equal_nan, xp=xp)
113114

114115

116+
def one_hot(
117+
x: Array,
118+
/,
119+
num_classes: int,
120+
*,
121+
dtype: DType | None = None,
122+
axis: int = -1,
123+
xp: ModuleType | None = None,
124+
) -> Array:
125+
"""
126+
One-hot encode the given indices.
127+
128+
Each index in the input `x` is encoded as a vector of zeros of length `num_classes`
129+
with the element at the given index set to one.
130+
131+
Parameters
132+
----------
133+
x : array
134+
An array with integral dtype whose values are between `0` and `num_classes - 1`.
135+
num_classes : int
136+
Number of classes in the one-hot dimension.
137+
dtype : DType, optional
138+
The dtype of the return value. Defaults to the default float dtype (usually
139+
float64).
140+
axis : int, optional
141+
Position in the expanded axes where the new axis is placed. Default: -1.
142+
xp : array_namespace, optional
143+
The standard-compatible namespace for `x`. Default: infer.
144+
145+
Returns
146+
-------
147+
array
148+
An array having the same shape as `x` except for a new axis at the position
149+
given by `axis` having size `num_classes`. If `axis` is unspecified, it
150+
defaults to -1, which appends a new axis.
151+
152+
If ``x < 0`` or ``x >= num_classes``, then the result is undefined, may raise
153+
an exception, or may even cause a bad state. `x` is not checked.
154+
155+
Examples
156+
--------
157+
>>> import array_api_extra as xpx
158+
>>> import array_api_strict as xp
159+
>>> xpx.one_hot(xp.asarray([1, 2, 0]), 3)
160+
Array([[0., 1., 0.],
161+
[0., 0., 1.],
162+
[1., 0., 0.]], dtype=array_api_strict.float64)
163+
"""
164+
# Validate inputs.
165+
if xp is None:
166+
xp = array_namespace(x)
167+
if not xp.isdtype(x.dtype, "integral"):
168+
msg = "x must have an integral dtype."
169+
raise TypeError(msg)
170+
if dtype is None:
171+
dtype = _funcs.default_dtype(xp, device=get_device(x))
172+
# Delegate where possible.
173+
if is_jax_namespace(xp):
174+
from jax.nn import one_hot as jax_one_hot
175+
176+
return jax_one_hot(x, num_classes, dtype=dtype, axis=axis)
177+
if is_torch_namespace(xp):
178+
from torch.nn.functional import one_hot as torch_one_hot
179+
180+
x = xp.astype(x, xp.int64) # PyTorch only supports int64 here.
181+
try:
182+
out = torch_one_hot(x, num_classes)
183+
except RuntimeError as e:
184+
raise IndexError from e
185+
else:
186+
out = _funcs.one_hot(x, num_classes, xp=xp)
187+
out = xp.astype(out, dtype, copy=False)
188+
if axis != -1:
189+
out = xp.moveaxis(out, -1, axis)
190+
return out
191+
192+
115193
def pad(
116194
x: Array,
117195
pad_width: int | tuple[int, int] | Sequence[tuple[int, int]],

src/array_api_extra/_lib/_funcs.py

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -375,6 +375,23 @@ def cov(m: Array, /, *, xp: ModuleType | None = None) -> Array:
375375
return xp.squeeze(c, axis=axes)
376376

377377

378+
def one_hot(
379+
x: Array,
380+
/,
381+
num_classes: int,
382+
*,
383+
xp: ModuleType,
384+
) -> Array: # numpydoc ignore=PR01,RT01
385+
"""See docstring in `array_api_extra._delegation.py`."""
386+
# TODO: Benchmark whether this is faster on the NumPy backend:
387+
# if is_numpy_array(x):
388+
# out = xp.zeros((x.size, num_classes), dtype=dtype)
389+
# out[xp.arange(x.size), xp.reshape(x, (-1,))] = 1
390+
# return xp.reshape(out, (*x.shape, num_classes))
391+
range_num_classes = xp.arange(num_classes, dtype=x.dtype, device=_compat.device(x))
392+
return x[..., xp.newaxis] == range_num_classes
393+
394+
378395
def create_diagonal(
379396
x: Array, /, *, offset: int = 0, xp: ModuleType | None = None
380397
) -> Array:

tests/test_funcs.py

Lines changed: 94 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@
2222
isclose,
2323
kron,
2424
nunique,
25+
one_hot,
2526
pad,
2627
setdiff1d,
2728
sinc,
@@ -45,6 +46,7 @@
4546
lazy_xp_function(expand_dims)
4647
lazy_xp_function(kron)
4748
lazy_xp_function(nunique)
49+
lazy_xp_function(one_hot)
4850
lazy_xp_function(pad)
4951
# FIXME calls in1d which calls xp.unique_values without size
5052
lazy_xp_function(setdiff1d, jax_jit=False)
@@ -449,6 +451,98 @@ def test_xp(self, xp: ModuleType):
449451
)
450452

451453

454+
@pytest.mark.skip_xp_backend(Backend.SPARSE, reason="backend doesn't have arange")
455+
class TestOneHot:
456+
@pytest.mark.parametrize("n_dim", range(4))
457+
@pytest.mark.parametrize("num_classes", [1, 3, 10])
458+
def test_dims_and_classes(self, xp: ModuleType, n_dim: int, num_classes: int):
459+
shape = tuple(range(2, 2 + n_dim))
460+
rng = np.random.default_rng(2347823)
461+
np_x = rng.integers(num_classes, size=shape)
462+
x = xp.asarray(np_x)
463+
y = one_hot(x, num_classes)
464+
assert y.shape == (*x.shape, num_classes)
465+
for *i_list, j in ndindex(*shape, num_classes):
466+
i = tuple(i_list)
467+
assert float(y[(*i, j)]) == (int(x[i]) == j)
468+
469+
def test_basic(self, xp: ModuleType):
470+
actual = one_hot(xp.asarray([0, 1, 2]), 3)
471+
expected = xp.asarray([[1.0, 0.0, 0.0], [0.0, 1.0, 0.0], [0.0, 0.0, 1.0]])
472+
xp_assert_equal(actual, expected)
473+
474+
actual = one_hot(xp.asarray([1, 2, 0]), 3)
475+
expected = xp.asarray([[0.0, 1.0, 0.0], [0.0, 0.0, 1.0], [1.0, 0.0, 0.0]])
476+
xp_assert_equal(actual, expected)
477+
478+
def test_2d(self, xp: ModuleType):
479+
actual = one_hot(xp.asarray([[2, 1, 0], [1, 0, 2]]), 3, axis=1)
480+
expected = xp.asarray(
481+
[
482+
[[0.0, 0.0, 1.0], [0.0, 1.0, 0.0], [1.0, 0.0, 0.0]],
483+
[[0.0, 1.0, 0.0], [1.0, 0.0, 0.0], [0.0, 0.0, 1.0]],
484+
]
485+
)
486+
xp_assert_equal(actual, expected)
487+
488+
@pytest.mark.skip_xp_backend(
489+
Backend.ARRAY_API_STRICTEST, reason="backend doesn't support Boolean indexing"
490+
)
491+
def test_abstract_size(self, xp: ModuleType):
492+
x = xp.arange(5)
493+
x = x[x > 2]
494+
actual = one_hot(x, 5)
495+
expected = xp.asarray([[0.0, 0.0, 0.0, 1.0, 0.0], [0.0, 0.0, 0.0, 0.0, 1.0]])
496+
xp_assert_equal(actual, expected)
497+
498+
@pytest.mark.skip_xp_backend(
499+
Backend.TORCH_GPU, reason="Puts Pytorch into a bad state."
500+
)
501+
def test_out_of_bound(self, xp: ModuleType):
502+
# Undefined behavior. Either return zero, or raise.
503+
try:
504+
actual = one_hot(xp.asarray([-1, 3]), 3)
505+
except IndexError:
506+
return
507+
expected = xp.asarray([[0.0, 0.0, 0.0], [0.0, 0.0, 0.0]])
508+
xp_assert_equal(actual, expected)
509+
510+
@pytest.mark.parametrize(
511+
"int_dtype",
512+
["int8", "int16", "int32", "int64", "uint8", "uint16", "uint32", "uint64"],
513+
)
514+
def test_int_types(self, xp: ModuleType, int_dtype: str):
515+
dtype = getattr(xp, int_dtype)
516+
x = xp.asarray([0, 1, 2], dtype=dtype)
517+
actual = one_hot(x, 3)
518+
expected = xp.asarray([[1.0, 0.0, 0.0], [0.0, 1.0, 0.0], [0.0, 0.0, 1.0]])
519+
xp_assert_equal(actual, expected)
520+
521+
def test_custom_dtype(self, xp: ModuleType):
522+
actual = one_hot(xp.asarray([0, 1, 2], dtype=xp.int32), 3, dtype=xp.bool)
523+
expected = xp.asarray(
524+
[[True, False, False], [False, True, False], [False, False, True]]
525+
)
526+
xp_assert_equal(actual, expected)
527+
528+
def test_axis(self, xp: ModuleType):
529+
expected = xp.asarray([[0.0, 1.0, 0.0], [0.0, 0.0, 1.0], [1.0, 0.0, 0.0]]).T
530+
actual = one_hot(xp.asarray([1, 2, 0]), 3, axis=0)
531+
xp_assert_equal(actual, expected)
532+
533+
actual = one_hot(xp.asarray([1, 2, 0]), 3, axis=-2)
534+
xp_assert_equal(actual, expected)
535+
536+
def test_non_integer(self, xp: ModuleType):
537+
with pytest.raises(TypeError):
538+
_ = one_hot(xp.asarray([1.0]), 3)
539+
540+
def test_device(self, xp: ModuleType, device: Device):
541+
x = xp.asarray([0, 1, 2], device=device)
542+
y = one_hot(x, 3)
543+
assert get_device(y) == device
544+
545+
452546
@pytest.mark.skip_xp_backend(
453547
Backend.SPARSE, reason="read-only backend without .at support"
454548
)

0 commit comments

Comments
 (0)