ENH: add new function one_hot #306
|
RE dtype, I think data-apis/array-api#848 will give us something slightly cleaner down the line. In SciPy we have been using https://github.com/scipy/scipy/blob/main/scipy/_lib/_array_api.py#L399 with |
|
Yeah, I'm not 100% sure how those links will help in this case? We're not casting one thing to another kind, or promoting to float. We just want the default float dtype irrespective of what was passed in. You may want to consider giving names to: I'm not sure though, and these are one-liners. |
the point is that instead of writing
    dtype = xp.empty(()).dtype
    x = xp.zeros(..., dtype=dtype)
we would have
    x = xp.zeros(...)
    x = xp.astype(x, 'real floating') |
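For readers following along, here is a minimal sketch of the first pattern mentioned above (querying the default float dtype from an empty 0-d array). It uses NumPy as a stand-in for `xp`, which is an assumption made only for illustration; the `'real floating'` form of `astype` is a proposal under discussion and is deliberately not shown as runnable code.

```python
import numpy as np

xp = np  # assumption: NumPy stands in for any array API compatible namespace

# An array created without an explicit dtype carries the namespace's
# default floating-point dtype, so a 0-d empty array is a cheap way to query it.
default_float = xp.empty(()).dtype

# Reuse that dtype explicitly when allocating the output.
x = xp.zeros((2, 3), dtype=default_float)
```

With NumPy the default is `float64`; other namespaces may default to a different float dtype, which is the reason for querying it rather than hard-coding it.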
|
Ah, right. Okay. I guess you mean Also, could you help me solve the Dask errors? This is all foreign to me. And how do I make an array with a non-concrete size? ( |
Nope, data-apis/array-api#848 suggests overloading the |
Oh, no worries, I don't have a strong opinion. Just trying to keep up with all the planned changes 😄 Did you see my edits? I could use some guidance with the Dask errors. |
to me this doesn't make much sense. Why shouldn't it be |
I'd rather follow what the libraries are doing than force double conversion for delegated code. If you do that, most people would end up having to write their own In general, the reason it's not bool is because these values often serve as the inputs to machine learning algorithms. |
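To illustrate why a float result is convenient as an ML input, here is a rough NumPy sketch. `one_hot_sketch` is a hypothetical helper written for this example only; it is not the function added in this PR and its signature is an assumption.

```python
import numpy as np


def one_hot_sketch(x, num_classes, dtype=None):
    """Hypothetical helper for illustration; not the PR's implementation."""
    if dtype is None:
        # Fall back to the default float dtype rather than bool.
        dtype = np.empty(()).dtype
    # Compare each index against 0..num_classes-1 along a new trailing axis.
    mask = np.asarray(x)[..., np.newaxis] == np.arange(num_classes)
    return mask.astype(dtype)


encoded = one_hot_sketch([0, 2, 1], 3)

# A float encoding can be fed straight into arithmetic such as a matmul,
# without the caller converting from bool first.
weights = np.array([0.5, 1.5, 2.5])
scores = encoded @ weights
```

A bool result would force most callers to add their own conversion step before this kind of arithmetic.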
|
@lucascolley Ready for your review |
|
[EDIT] just realised that I have had an unsent review hanging for the last 3 days. Apologies. What about that should be all you need to do to implement |
| Parameters | ||
| ---------- | ||
| x : array | ||
| An array with integral dtype and concrete size (``x.size`` cannot be `None`). |
I think with the new algorithm, `None` sizes are supported (needs a test though)
| ] | ||
| # Delegate where possible. | ||
| if is_jax_namespace(xp): | ||
| assert is_jax_array(x) |
| assert is_jax_array(x) |
We typically don't do this. It's needlessly expensive.
I agree that we should be consistent with the rest of the package, so I'll remove it. However, you may want to revisit this policy, since with jitted code such calls are basically free. This is true for JAX and PyTorch, which are the backends I added assertions for. Also, assertions disappear when optimizations are turned on.
The assertions also have the benefit of improving type checking.
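As a generic illustration of the two points above (type narrowing, and assertions being skipped under optimization), here is a small sketch in plain Python. The function `scale` is hypothetical and unrelated to the package's code.

```python
def scale(x: object) -> float:
    # For a static type checker, this assert narrows `x` from `object` to
    # `float`, so the multiplication below type-checks.
    # When Python runs with the -O flag, assert statements are not executed,
    # so the check costs nothing at runtime in that mode.
    assert isinstance(x, float)
    return x * 2.0


result = scale(1.5)
```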
crusaderky
left a comment
This is almost ready to go; a few nits below.
It's worth noting that the current implementation makes it impossible to write
    symbols, idx = xp.unique_inverse(x)
    xpx.one_hot(idx, symbols.size)
as a pattern to build a one-hot map of arbitrary symbols on Dask, unless you know in advance the maximum number of unique symbols.
This could be fixed in a follow-up but I'm unsure about real-life interest in it.
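For context, the pattern described above looks roughly like this on an eager backend. This is a sketch using plain NumPy, with `np.unique(..., return_inverse=True)` standing in for `xp.unique_inverse` and a broadcast comparison standing in for `xpx.one_hot`; both substitutions are assumptions made to keep the example self-contained.

```python
import numpy as np

x = np.array(["b", "a", "c", "a"])

# Map each element to the index of its symbol in the sorted unique values.
symbols, idx = np.unique(x, return_inverse=True)

# On NumPy the number of unique symbols is known immediately. On a lazy
# backend such as Dask this size may be unknown until computation, which is
# the limitation described above.
num_classes = symbols.size

# One row per input element, one column per symbol.
encoded = (idx[:, np.newaxis] == np.arange(num_classes)).astype(np.float64)
```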
|
looks like there was a rebase hiccup |
lucascolley
left a comment
thanks both, looks great!
|
Thanks @lucascolley and @crusaderky for the thorough and quick reviews! |
crusaderky
left a comment
Only one nit; see above
|
|
thanks again! |
Fixes #305
Questions:
xp.empty(()).dtype?