
Commit 60d3b3b

Add one_hot
Parent commit: 0fc862c

File tree

5 files changed: +176 -3 lines changed

docs/api-reference.md
src/array_api_extra/__init__.py
src/array_api_extra/_delegation.py
src/array_api_extra/_lib/_funcs.py
tests/test_funcs.py

docs/api-reference.md
Lines changed: 1 addition & 0 deletions

@@ -15,6 +15,7 @@
 expand_dims
 isclose
 kron
+one_hot
 nunique
 pad
 setdiff1d

src/array_api_extra/__init__.py
Lines changed: 2 additions & 1 deletion

@@ -1,6 +1,6 @@
 """Extra array functions built on top of the array API standard."""

-from ._delegation import isclose, pad
+from ._delegation import isclose, one_hot, pad
 from ._lib._at import at
 from ._lib._funcs import (
     apply_where,
@@ -32,6 +32,7 @@
     "kron",
     "lazy_apply",
     "nunique",
+    "one_hot",
     "pad",
     "setdiff1d",
     "sinc",

src/array_api_extra/_delegation.py
Lines changed: 82 additions & 2 deletions

@@ -14,10 +14,11 @@
     is_pydata_sparse_namespace,
     is_torch_namespace,
 )
+from ._lib._utils._compat import device as get_device
 from ._lib._utils._helpers import asarrays
-from ._lib._utils._typing import Array
+from ._lib._utils._typing import Array, DType

-__all__ = ["isclose", "pad"]
+__all__ = ["isclose", "one_hot", "pad"]


 def isclose(
@@ -112,6 +113,85 @@ def isclose(
     return _funcs.isclose(a, b, rtol=rtol, atol=atol, equal_nan=equal_nan, xp=xp)


+def one_hot(
+    x: Array,
+    /,
+    num_classes: int,
+    *,
+    dtype: DType | None = None,
+    axis: int = -1,
+    xp: ModuleType | None = None,
+) -> Array:
+    """
+    One-hot encode the given indices.
+
+    Each index in the input `x` is encoded as a vector of zeros of length `num_classes`
+    with the element at the given index set to one.
+
+    Parameters
+    ----------
+    x : array
+        An array with integral dtype whose values are between `0` and `num_classes - 1`.
+    num_classes : int
+        Number of classes in the one-hot dimension.
+    dtype : DType, optional
+        The dtype of the return value. Defaults to the default float dtype (usually
+        float64).
+    axis : int, optional
+        Position in the expanded axes where the new axis is placed.
+    xp : array_namespace, optional
+        The standard-compatible namespace for `x`. Default: infer.
+
+    Returns
+    -------
+    array
+        An array having the same shape as `x` except for a new axis at the position
+        given by `axis` having size `num_classes`. If `axis` is unspecified, it
+        defaults to -1, which appends a new axis.
+
+        If ``x < 0`` or ``x >= num_classes``, then the result is undefined, may raise
+        an exception, or may even cause a bad state. `x` is not checked.
+
+    Examples
+    --------
+    >>> import array_api_extra as xpx
+    >>> import array_api_strict as xp
+    >>> xpx.one_hot(xp.asarray([1, 2, 0]), 3)
+    Array([[0., 1., 0.],
+           [0., 0., 1.],
+           [1., 0., 0.]], dtype=array_api_strict.float64)
+    """
+    # Validate inputs.
+    if xp is None:
+        xp = array_namespace(x)
+    if not xp.isdtype(x.dtype, "integral"):
+        msg = "x must have an integral dtype."
+        raise TypeError(msg)
+    if dtype is None:
+        dtype = xp.__array_namespace_info__().default_dtypes(device=get_device(x))[
+            "real floating"
+        ]
+    # Delegate where possible.
+    if is_jax_namespace(xp):
+        from jax.nn import one_hot as jax_one_hot
+
+        return jax_one_hot(x, num_classes, dtype=dtype, axis=axis)
+    if is_torch_namespace(xp):
+        from torch.nn.functional import one_hot as torch_one_hot
+
+        x = xp.astype(x, xp.int64)  # PyTorch only supports int64 here.
+        try:
+            out = torch_one_hot(x, num_classes)
+        except RuntimeError as e:
+            raise IndexError from e
+    else:
+        out = _funcs.one_hot(x, num_classes, xp=xp)
+    out = xp.astype(out, dtype, copy=False)
+    if axis != -1:
+        out = xp.moveaxis(out, -1, axis)
+    return out
+
+
 def pad(
     x: Array,
     pad_width: int | tuple[int, int] | Sequence[tuple[int, int]],
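
For context, a minimal usage sketch of the new public entry point. It is not part of the commit and assumes NumPy 2.x as the backend plus a version of array_api_extra that includes this change:

    import numpy as np
    import array_api_extra as xpx

    labels = np.asarray([2, 0, 1, 1])

    # Default behaviour: float output, class axis appended last -> shape (4, 3).
    encoded = xpx.one_hot(labels, 3)
    print(encoded.shape, encoded.dtype)

    # Custom dtype and axis placement: boolean output, class axis first -> shape (3, 4).
    encoded_bool = xpx.one_hot(labels, 3, dtype=np.bool_, axis=0)
    print(encoded_bool.shape, encoded_bool.dtype)

With JAX or PyTorch arrays the call is delegated to jax.nn.one_hot or torch.nn.functional.one_hot respectively; every other backend falls through to the generic implementation added in _lib/_funcs.py below.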

src/array_api_extra/_lib/_funcs.py
Lines changed: 17 additions & 0 deletions

@@ -375,6 +375,23 @@ def cov(m: Array, /, *, xp: ModuleType | None = None) -> Array:
     return xp.squeeze(c, axis=axes)


+def one_hot(
+    x: Array,
+    /,
+    num_classes: int,
+    *,
+    xp: ModuleType,
+) -> Array:  # numpydoc ignore=PR01,RT01
+    """See docstring in `array_api_extra._delegation.py`."""
+    # TODO: Benchmark whether this is faster on the NumPy backend:
+    # if is_numpy_array(x):
+    #     out = xp.zeros((x.size, num_classes), dtype=dtype)
+    #     out[xp.arange(x.size), xp.reshape(x, (-1,))] = 1
+    #     return xp.reshape(out, (*x.shape, num_classes))
+    range_num_classes = xp.arange(num_classes, dtype=x.dtype, device=_compat.device(x))
+    return x[..., xp.newaxis] == range_num_classes
+
+
 def create_diagonal(
     x: Array, /, *, offset: int = 0, xp: ModuleType | None = None
 ) -> Array:
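
The generic implementation hinges on a single broadcasting comparison: appending a singleton axis to x and comparing it against arange(num_classes). A standalone NumPy sketch of the same trick, for illustration only (not part of the commit):

    import numpy as np

    x = np.asarray([[1, 2], [0, 1]])  # class indices, shape (2, 2), values in [0, 3)
    num_classes = 3

    # x[..., np.newaxis] has shape (2, 2, 1); comparing it against arange(3)
    # broadcasts to shape (2, 2, 3) and is True exactly where the index matches.
    mask = x[..., np.newaxis] == np.arange(num_classes, dtype=x.dtype)
    assert mask.shape == (2, 2, 3)
    assert mask.dtype == np.bool_

    # The delegation layer casts this boolean mask to the requested dtype afterwards.
    print(mask.astype(np.float64))

Returning a boolean mask keeps the core backend-agnostic; dtype conversion and axis placement are handled once in _delegation.py rather than per backend.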

tests/test_funcs.py
Lines changed: 74 additions & 0 deletions

@@ -21,6 +21,7 @@
     isclose,
     kron,
     nunique,
+    one_hot,
     pad,
     setdiff1d,
     sinc,
@@ -44,6 +45,7 @@
 lazy_xp_function(expand_dims)
 lazy_xp_function(kron)
 lazy_xp_function(nunique)
+lazy_xp_function(one_hot)
 lazy_xp_function(pad)
 # FIXME calls in1d which calls xp.unique_values without size
 lazy_xp_function(setdiff1d, jax_jit=False)
@@ -448,6 +450,78 @@ def test_xp(self, xp: ModuleType):
         )


+@pytest.mark.skip_xp_backend(Backend.SPARSE, reason="backend doesn't have arange")
+class TestOneHot:
+    @pytest.mark.parametrize("n_dim", range(4))
+    @pytest.mark.parametrize("num_classes", [1, 3, 10])
+    def test_dims_and_classes(self, xp: ModuleType, n_dim: int, num_classes: int):
+        shape = tuple(range(2, 2 + n_dim))
+        rng = np.random.default_rng(2347823)
+        np_x = rng.integers(num_classes, size=shape)
+        x = xp.asarray(np_x)
+        y = one_hot(x, num_classes)
+        assert y.shape == (*x.shape, num_classes)
+        for *i_list, j in ndindex(*shape, num_classes):
+            i = tuple(i_list)
+            assert float(y[(*i, j)]) == (int(x[i]) == j)
+
+    def test_basic(self, xp: ModuleType):
+        actual = one_hot(xp.asarray([0, 1, 2]), 3)
+        expected = xp.asarray([[1.0, 0.0, 0.0], [0.0, 1.0, 0.0], [0.0, 0.0, 1.0]])
+        xp_assert_equal(actual, expected)
+
+        actual = one_hot(xp.asarray([1, 2, 0]), 3)
+        expected = xp.asarray([[0.0, 1.0, 0.0], [0.0, 0.0, 1.0], [1.0, 0.0, 0.0]])
+        xp_assert_equal(actual, expected)
+
+    @pytest.mark.skip_xp_backend(
+        Backend.TORCH_GPU, reason="Puts Pytorch into a bad state."
+    )
+    def test_out_of_bound(self, xp: ModuleType):
+        # Undefined behavior. Either return zero, or raise.
+        try:
+            actual = one_hot(xp.asarray([-1, 3]), 3)
+        except IndexError:
+            return
+        expected = xp.asarray([[0.0, 0.0, 0.0], [0.0, 0.0, 0.0]])
+        xp_assert_equal(actual, expected)
+
+    @pytest.mark.parametrize(
+        "int_dtype",
+        ["int8", "int16", "int32", "int64", "uint8", "uint16", "uint32", "uint64"],
+    )
+    def test_int_types(self, xp: ModuleType, int_dtype: str):
+        dtype = getattr(xp, int_dtype)
+        x = xp.asarray([0, 1, 2], dtype=dtype)
+        actual = one_hot(x, 3)
+        expected = xp.asarray([[1.0, 0.0, 0.0], [0.0, 1.0, 0.0], [0.0, 0.0, 1.0]])
+        xp_assert_equal(actual, expected)
+
+    def test_custom_dtype(self, xp: ModuleType):
+        actual = one_hot(xp.asarray([0, 1, 2], dtype=xp.int32), 3, dtype=xp.bool)
+        expected = xp.asarray(
+            [[True, False, False], [False, True, False], [False, False, True]]
+        )
+        xp_assert_equal(actual, expected)
+
+    def test_axis(self, xp: ModuleType):
+        expected = xp.asarray([[0.0, 1.0, 0.0], [0.0, 0.0, 1.0], [1.0, 0.0, 0.0]]).T
+        actual = one_hot(xp.asarray([1, 2, 0]), 3, axis=0)
+        xp_assert_equal(actual, expected)
+
+        actual = one_hot(xp.asarray([1, 2, 0]), 3, axis=-2)
+        xp_assert_equal(actual, expected)
+
+    def test_non_integer(self, xp: ModuleType):
+        with pytest.raises(TypeError):
+            _ = one_hot(xp.asarray([1.0]), 3)
+
+    def test_device(self, xp: ModuleType, device: Device):
+        x = xp.asarray([0, 1, 2], device=device)
+        y = one_hot(x, 3)
+        assert get_device(y) == device
+
+
 @pytest.mark.skip_xp_backend(
     Backend.SPARSE, reason="read-only backend without .at support"
 )
