WIP ENH: setdiff1d for Dask and jax.jit#124
Draft
crusaderky wants to merge 3 commits intodata-apis:mainfrom
Draft
WIP ENH: setdiff1d for Dask and jax.jit#124crusaderky wants to merge 3 commits intodata-apis:mainfrom
setdiff1d for Dask and jax.jit#124crusaderky wants to merge 3 commits intodata-apis:mainfrom
Conversation
2900169 to
0bc3adf
Compare
a952ede to
028441c
Compare
Contributor
Author
|
@rgommers confirmed offline his preference for delaying indefinitely. His reasoning is that at some point JAX should support unknown shapes and the issue of niche functions that can't work until then should be marginal enough not to warrant urgent attention. |
Member
Worth noting that this was hit for Dask in the cross-linked sklearn PR |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
Add this suggestion to a batch that can be applied as a single commit.This suggestion is invalid because no changes were made to the code.Suggestions cannot be applied while the pull request is closed.Suggestions cannot be applied while viewing a subset of changes.Only one suggestion per line can be applied in a batch.Add this suggestion to a batch that can be applied as a single commit.Applying suggestions on deleted lines is not supported.You must change the existing code in this line in order to create a valid suggestion.Outdated suggestions cannot be applied.This suggestion has been applied or marked resolved.Suggestions cannot be applied from pending reviews.Suggestions cannot be applied on multi-line comments.Suggestions cannot be applied while the pull request is queued to merge.Suggestion cannot be applied right now. Please check back later.
Closes #116
Needs more thorough unit tests + performance benchmarks.
This function's output is of unknown shape, so with the previous API it will never work in jax.jit.
There are a few options:
jax.jitand you need to hack your way around it with ENH:lazy_apply#86 (comment).I'm not a fan of this because UX is very painful as it forces the user to think in graphs.
fill_value. iff running inside the jax.jit, quietly return a longer array padded with it.I'm not happy about this because it causes jax.jit to quietly diverge from other backends and users will spend a lot of time debugging.
sizeandfill_value.sizebecomes mandatory when running inside jax.jit. This is the same design asjax.numpy.unique_values.This also allows having a known-shape output in Dask. However, implementing it for Dask is fairly complicated.
sizeandfill_value.sizeis mandatory when running inside jax.jit and disregarded otherwise. Again, this will cause bugs in the user code that only appear in jax.jit, but at least it demands an initial explicit user intervention. This is the simplest to implement; unsure on the UX. It also has the advantage of not sacrificing performance on other backends. If in the future jax.jit will support arrays of unknown size, it becomes easy to deprecate it as we said that the output size requested by the user may be disregarded anyway.My current favourite is (4).
@rgommers you previously said, talking about functions with the same problem in scipy, that you prefer (1) to (3) because of not being able to retract the API in the future. What's your opinion on (4)?
CC @lucascolley