JAX 0.8 and PyMC v5 compatibility#164
JAX 0.8 and PyMC v5 compatibility#164MilesCranmer wants to merge 21 commits intoexoplanet-dev:mainfrom
Conversation
dfm
left a comment
There was a problem hiding this comment.
Thank you - this looks great! Some small comments/questions inline.
| typedef typename LowRank::Scalar Scalar; | ||
| typedef typename Eigen::internal::plain_col_type<Coeffs>::type CoeffVector; | ||
| typedef typename Eigen::Matrix<Scalar, LowRank::ColsAtCompileTime, RightHandSide::ColsAtCompileTime> Inner; | ||
| typedef typename Eigen::Matrix<Scalar, Eigen::Dynamic, Eigen::Dynamic, Eigen::RowMajor> Inner; |
There was a problem hiding this comment.
From my experience, this change will dramatically impact performance because Eigen won't be able to generate properly vectorized code for small systems. It's really useful to compile for specific sizes! Why did you make this change?
There was a problem hiding this comment.
Thanks, I see. I couldn't get it working initially but this change seemed to do it. I didn't know it would hurt performance though so I'll fix it now.
| typedef typename LowRank::Scalar Scalar; | ||
| typedef typename Eigen::internal::plain_col_type<Coeffs>::type CoeffVector; | ||
| typedef typename Eigen::Matrix<Scalar, LowRank::ColsAtCompileTime, RightHandSide::ColsAtCompileTime> Inner; | ||
| typedef typename Eigen::Matrix<Scalar, Eigen::Dynamic, Eigen::Dynamic, Eigen::RowMajor> Inner; |
Co-authored-by: Dan Foreman-Mackey <dfm@dfm.io>
Co-authored-by: Dan Foreman-Mackey <dfm@dfm.io>
Co-authored-by: Dan Foreman-Mackey <dfm@dfm.io>
Co-authored-by: Dan Foreman-Mackey <dfm@dfm.io>
Co-authored-by: Dan Foreman-Mackey <dfm@dfm.io>
for more information, see https://pre-commit.ci
Co-authored-by: Dan Foreman-Mackey <dfm@dfm.io>
|
Really struggling to get it working with |
|
Sorry, turns out I won't have enough bandwidth to fix the PR for fixed sized matrices (seems hard), so I will have to leave things as-is for now. Feel free to take this PR as-is if you are okay with the speed hit to get things up to JAX 0.8, or we can just point people to this PR if they need it |
|
No problem! I think it shouldn't be a big deal since it's just for the "general" cases that are only used for predictions. Log prob calculations should still be fast. I'll try and get this merged soon - thanks!! |
|
Oops sorry I forgot to re-run generate.py |
|
Ping @dfm let me know if anything else is left! |
|
It looks like all the CI is failing for various reasons. Can you take a look at those? I'm not totally sure how to have them automatically run for you, but I'll try to be faster to press the button, and you could plausibly run them on your own fork by temporarily adding: here: It also looks like we'll need to update the Python version that we're using on ReadTheDocs. I think that should be a simple as just bumping it here: Line 11 in e7974e4 Do you mind doing that too? |
I think if you put me as collaborator status in the org it might do this? I think it's just a user trust scopes thing. (No need to give me merge rights though) |
|
Good idea! I've invited you - let me know if that works (or doesn't). |
|
@MilesCranmerBot can you make a new PR based on this one and try to get the tests working again? |
|
Actually ugh it might need to make the PR to my account's fork so I can get the CI tested via this PR |
* fix: jax 0.8 compatibility * fix: get pymc working on v5 * deps: lower to python 3.10 * fix: JIT compat of `_do_compute` * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * deps: require python 3.11 for jaxlib compat * ci: skip bad nox arg * ci: fix nox target version * deps: fix missing jaxlib install * ci: fix missing jax install for pymc tests * fix: generator will no longer add lines of spaces Co-authored-by: Dan Foreman-Mackey <dfm@dfm.io> * fix: re-generated xla_ops without api.h Co-authored-by: Dan Foreman-Mackey <dfm@dfm.io> * refactor: clean up error handling Co-authored-by: Dan Foreman-Mackey <dfm@dfm.io> * fix: `Primitive` import Co-authored-by: Dan Foreman-Mackey <dfm@dfm.io> * build: update cmake per review Co-authored-by: Dan Foreman-Mackey <dfm@dfm.io> * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * refactor: put axis into template Co-authored-by: Dan Foreman-Mackey <dfm@dfm.io> * fix: add back nrhs==1 branch * style: remove extra padding * deps: remove upper bound * chore: re-run generator * Fix JAX 0.8 primitive impl path and PyMC JAX registration * fix(jax): avoid removed jax.lib.xla_client import * build: require jax in PEP517 build env for JAX extension * refactor(jax): drop pre-0.8 apply_primitive fallback * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * build: avoid requiring jax on Python<3.11 for docs * Upgrade RTDs Python version from 3.10 to 3.11 --------- Co-authored-by: MilesCranmer <miles.cranmer@gmail.com> Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> Co-authored-by: Dan Foreman-Mackey <dfm@dfm.io> Co-authored-by: Dan Foreman-Mackey <foreman.mackey@gmail.com>
This PR upgrades celerite2 to JAX 0.8.x. I updated the jinja templates and re-generated the cpp files. I also upgraded PyMC to v5 and ditched compatibility with PyMC3 since stuff wasn't working anyways and it will be easier to maintain.
I'm unfamiliar with a lot of the lower level JAX stuff, so I am not confident about some of this PR, especially the FFI stuff. A good look over by someone else would be helpful.
Also, PyMC v5 seemed to need this
@jax_funcify.register(_CeleriteOp)thing but I am not 100% sure about this.Paging @dfm for review.