ENH: add new function one_hot #306

Merged: 1 commit, Jun 3, 2025

Conversation

Contributor

@NeilGirdhar NeilGirdhar commented May 29, 2025

Fixes #305

Questions:

  • What should the default output dtype be? Currently, it's the default float type on JAX/PyTorch, and float64 otherwise. Should I instead query the platform's default float type with xp.empty(()).dtype?

@lucascolley lucascolley changed the title Add one_hot ENH: add new function one_hot May 29, 2025
@lucascolley lucascolley added enhancement New feature or request new function labels May 29, 2025
@lucascolley lucascolley added this to the 0.8.0 milestone May 29, 2025
@NeilGirdhar NeilGirdhar force-pushed the onehot branch 3 times, most recently from 9b5f393 to 02544de Compare May 30, 2025 05:47
@lucascolley
Member

RE dtype, I think data-apis/array-api#848 will give us something slightly cleaner down the line.

In SciPy we have been using https://github.com/scipy/scipy/blob/main/scipy/_lib/_array_api.py#L399 with force_floating=True. Taking xp.empty(()).dtype is probably fine for now though.

@NeilGirdhar
Contributor Author

Yeah, I'm not 100% sure how those links will help in this case? We're not casting one thing to another kind, or promoting to float. We just want the default float dtype irrespective of what was passed in. You may want to consider giving names to:

xp.asarray(1j).dtype  # Default complex
xp.asarray(1).dtype  # Default int
xp.empty(()).dtype  # Default float

I'm not sure though, and these are one-liners.
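
For readers following along, the three one-liners above can be tried directly. This is only a minimal sketch that assumes NumPy stands in for the generic `xp` namespace; the variable names are illustrative and not part of any library.

```python
import numpy as np

xp = np  # assumption: NumPy plays the role of the Array API namespace

# Each default dtype is inferred from a throwaway array.
default_float = xp.empty(()).dtype      # default real floating dtype
default_int = xp.asarray(1).dtype       # default integer dtype (platform dependent)
default_complex = xp.asarray(1j).dtype  # default complex dtype

print(default_float, default_int, default_complex)
```

The integer width can differ between platforms and library versions, so only the float and complex results should be relied on to be the same everywhere for NumPy.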

@lucascolley
Member

> I'm not 100% sure how those links will help in this case?

the point is that instead of writing

dtype = xp.empty(()).dtype
x = xp.zeros(..., dtype=dtype)

we would have

x = xp.zeros(...)
x = xp.astype(x, 'real floating')

@NeilGirdhar
Contributor Author

NeilGirdhar commented May 30, 2025

Ah, right. Okay. I guess you mean askind, right?

Also, could you help me solve the Dask errors? This is all foreign to me.

And how do I make an array with a non-concrete size? (x.size=None)

@lucascolley
Member

> I guess you mean askind, right?

Nope, data-apis/array-api#848 suggests overloading the dtype parameter of astype for this. Feel free to comment over there if you disagree!

@NeilGirdhar
Contributor Author

NeilGirdhar commented May 30, 2025

> suggests overloading the dtype parameter of astype for this. Feel free to comment over there if you disagree!

Oh, no worries, I don't have a strong opinion. Just trying to keep up with all the planned changes 😄

Did you see my edits? I could use some guidance with the Dask errors.

@crusaderky
Contributor

> • What should the default output dtype be? Currently, it's the default float type on Jax/PyTorch, or else float64. Should I try to get the default float type on the platform: xp.empty(()).dtype?

to me this doesn't make much sense. Why shouldn't it be bool?

@NeilGirdhar
Contributor Author

NeilGirdhar commented May 30, 2025

> > • What should the default output dtype be? Currently, it's the default float type on Jax/PyTorch, or else float64. Should I try to get the default float type on the platform: xp.empty(()).dtype?
>
> to me this doesn't make much sense. Why shouldn't it be bool?

I'd rather follow what the libraries are doing than force double conversion for delegated code. If you do that, most people would end up having to write their own one_hot method in order to avoid it.

In general, the reason it's not bool is that these values often serve as inputs to machine learning algorithms.

@NeilGirdhar
Contributor Author

NeilGirdhar commented May 31, 2025

@lucascolley Ready for your review

@NeilGirdhar NeilGirdhar force-pushed the onehot branch 2 times, most recently from ecc8b40 to 535ae42 Compare May 31, 2025 03:54
@crusaderky
Contributor

crusaderky commented Jun 1, 2025

[EDIT] just realised that I have had an unsent review hanging for the last 3 days. Apologies.

What about x[..., None] == xp.arange(num_classes, dtype=x.dtype, device=_compat.device(x))?

that should be all you need to do to implement one_hot in a way that is both performant and Array API compliant.
(the numpy special case may be faster; needs benchmarking).
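
As a rough illustration of the broadcasting comparison suggested above, here is a minimal sketch. It assumes NumPy as the backend, omits the `device` argument, and uses a hypothetical helper name `one_hot_sketch`; it is not the implementation that was merged.

```python
import numpy as np


def one_hot_sketch(x, num_classes, *, dtype=None, xp=np):
    """Hypothetical helper illustrating the broadcast-and-compare idea."""
    if dtype is None:
        # Fall back to the namespace's default float dtype.
        dtype = xp.empty(()).dtype
    classes = xp.arange(num_classes, dtype=x.dtype)
    # Broadcasting: x.shape + (1,) against (num_classes,) gives a boolean mask
    # of shape x.shape + (num_classes,).
    mask = x[..., None] == classes
    # NumPy method used here; strict Array API code would call xp.astype(mask, dtype).
    return mask.astype(dtype)


labels = np.asarray([0, 2, 1])
encoded = one_hot_sketch(labels, 3)
print(encoded)
```

The comparison itself yields a boolean array, so the final cast is the only place where the output dtype question from earlier in the thread comes into play.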

@NeilGirdhar NeilGirdhar force-pushed the onehot branch 3 times, most recently from 0989cc4 to ab421e9 Compare June 2, 2025 06:07
Parameters
----------
x : array
    An array with integral dtype and concrete size (``x.size`` cannot be `None`).
Contributor

I think `None` sizes are supported with the new algorithm (needs a test, though).

@NeilGirdhar NeilGirdhar force-pushed the onehot branch 2 times, most recently from 60d3b3b to 620372c Compare June 2, 2025 12:42
Contributor

@crusaderky crusaderky left a comment

This is almost ready to go; a few nits below.

It's worth noting that the current implementation makes it impossible to write

symbols, idx = xp.unique_inverse(x)
xpx.one_hot(idx, symbols.size)

as a pattern to build a one-hot map of arbitrary symbols on Dask, unless you know in advance the maximum number of unique symbols.
This could be fixed in a follow-up but I'm unsure about real-life interest in it.
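
To make the pattern concrete, here is a minimal sketch on an eager backend where it does work. It assumes NumPy, uses `np.unique(..., return_inverse=True)` as a stand-in for `xp.unique_inverse`, and replaces the `xpx.one_hot` call with an inline broadcast comparison so the snippet has no extra dependencies.

```python
import numpy as np

x = np.asarray(["b", "a", "c", "a"])

# Stand-in for xp.unique_inverse: sorted unique symbols plus, for each
# element of x, the index of its symbol.
symbols, idx = np.unique(x, return_inverse=True)

# On NumPy the number of symbols is a concrete integer. On a lazy backend
# it may be unknown until compute time, which is the limitation noted above.
num_classes = symbols.size

# Inline broadcast comparison in place of the library call.
encoded = (idx[..., None] == np.arange(num_classes, dtype=idx.dtype)).astype(np.float64)

print(symbols)
print(encoded)
```

Each row of `encoded` marks the position of the corresponding input in `symbols`, so the mapping back to the original labels is `symbols[idx]`.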

@lucascolley
Copy link
Member

looks like there was a rebase hiccup

Member

@lucascolley lucascolley left a comment

thanks both, looks great!

@NeilGirdhar
Contributor Author

Thanks @lucascolley and @crusaderky for the thorough and quick reviews!

Contributor

@crusaderky crusaderky left a comment

Only one nit; see above

@NeilGirdhar NeilGirdhar force-pushed the onehot branch 3 times, most recently from 70024a6 to cf31c80 Compare June 3, 2025 08:17
@NeilGirdhar NeilGirdhar force-pushed the onehot branch 2 times, most recently from f2c0356 to fa0e4a4 Compare June 3, 2025 08:58
@lucascolley
Member

looks like there was a rebase hiccup

@lucascolley lucascolley merged commit 61b512e into data-apis:main Jun 3, 2025
9 checks passed
@lucascolley
Member

thanks again!

@NeilGirdhar NeilGirdhar deleted the onehot branch June 3, 2025 09:48