Skip to content

Commit 326d874

Browse files
committed
crusadersky's idea
1 parent 7736d2c commit 326d874

File tree

2 files changed

+14
-18
lines changed

2 files changed

+14
-18
lines changed

src/array_api_extra/_delegation.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616
is_torch_array,
1717
is_torch_namespace,
1818
)
19+
from ._lib._utils._compat import device as get_device
1920
from ._lib._utils._helpers import asarrays
2021
from ._lib._utils._typing import Array, DType
2122

@@ -169,7 +170,9 @@ def one_hot(
169170
msg = "x must have an integral dtype."
170171
raise TypeError(msg)
171172
if dtype is None:
172-
dtype = xp.__array_namespace_info__().default_dtypes(device=get_device(x))["real floating"]
173+
dtype = xp.__array_namespace_info__().default_dtypes(device=get_device(x))[
174+
"real floating"
175+
]
173176
# Delegate where possible.
174177
if is_jax_namespace(xp):
175178
assert is_jax_array(x)
@@ -192,8 +195,6 @@ def one_hot(
192195
num_classes,
193196
dtype=dtype,
194197
xp=xp,
195-
supports_fancy_indexing=is_numpy_namespace(xp),
196-
supports_array_indexing=is_dask_namespace(xp),
197198
)
198199

199200
if axis != -1:

src/array_api_extra/_lib/_funcs.py

Lines changed: 10 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -380,8 +380,6 @@ def one_hot(
380380
/,
381381
num_classes: int,
382382
*,
383-
supports_fancy_indexing: bool = False,
384-
supports_array_indexing: bool = False,
385383
dtype: DType,
386384
xp: ModuleType,
387385
) -> Array: # numpydoc ignore=PR01,RT01
@@ -394,19 +392,16 @@ def one_hot(
394392
# specification.
395393
msg = "x must have a concrete size."
396394
raise TypeError(msg)
397-
out = xp.zeros((x.size, num_classes), dtype=dtype, device=_compat.device(x))
398-
x_flattened = xp.reshape(x, (-1,))
399-
if supports_fancy_indexing:
400-
out = at(out)[xp.arange(x_size), x_flattened].set(1)
401-
else:
402-
for i in range(x_size):
403-
x_i = x_flattened[i]
404-
if not supports_array_indexing:
405-
x_i = int(x_i)
406-
out = at(out)[i, x_i].set(1)
407-
if x.ndim != 1:
408-
out = xp.reshape(out, (*x.shape, num_classes))
409-
return out
395+
# TODO: Benchmark whether this is faster on the numpy backend:
396+
# x_flattened = xp.reshape(x, (-1,))
397+
# out = xp.zeros((x.size, num_classes), dtype=dtype, device=_compat.device(x))
398+
# at(out)[xp.arange(x_size), x_flattened].set(1)
399+
# if x.ndim != 1:
400+
# out = xp.reshape(out, (*x.shape, num_classes))
401+
out = x[..., None] == xp.arange(
402+
num_classes, dtype=x.dtype, device=_compat.device(x)
403+
)
404+
return xp.astype(out, dtype)
410405

411406

412407
def create_diagonal(

0 commit comments

Comments
 (0)