Hey Mike! I don't usually monitor the discussion page, so I missed this. I've set up notifications now, though. Statespace models get slow for two reasons:

1. The logp is a Kalman filter: a recursive loop over the timesteps, where each step depends on the output of the previous one, so the loop can't be parallelized.
2. Every step of that loop does dense linear algebra on the hidden states: several matrix multiplications plus a matrix solve, so the per-step cost grows quickly with the number of hidden states.
Now, that's all fine, but then add to this:
Just as an example, the default NUTS settings allow for a max tree depth of 10. The tree depth controls how long the Hamiltonian physics simulation is allowed to carry on; at each tree node you're doing one logp and one gradient evaluation. So in the worst case, to get one sample, you have to do 2^10 = 1024 logp+gradient evaluations. Say you want 2000 samples, you have 100 time steps, and you have 10 hidden states. In the absolute worst case you're inverting a 10x10 matrix 2000 * 1024 * 100 ≈ 200 million times for the logp computation, and since the gradient of a solve itself involves a solve, that's another ~200 million inversions. And that's just the solves; there are also a bunch of matrix multiplications of roughly the same cost.

You might think "wow, that's a lot of linear algebra, a GPU should be great at that", but then the loop comes back to bite you. You can't parallelize the loop over the T timesteps, nor can you parallelize over the 2000 draws (sort of, you can have multiple chains of course, but set that aside).

Ok, that context aside, why does JAX help? It all boils down to the loop. JAX seems to have an extremely good implementation of scan, the loop construct everything above runs inside.

So if you want more speed, what can we do?
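To make the scan point concrete, here's a minimal sketch of a Kalman-filter-style recursion written with `jax.lax.scan`. This is a toy illustration, not the actual pymc statespace implementation, and all the names are made up; the point is that the carry makes the sequential dependency explicit, so step t can't start until step t-1 finishes, no matter how much hardware you have:

```python
import jax
import jax.numpy as jnp


def kalman_logp(y, Z, H, T_mat, Q, a0, P0):
    """Toy Kalman-filter log-likelihood (up to a constant).

    y: (T, p) observations, Z: (p, k) design, H: (p, p) obs noise,
    T_mat: (k, k) transition, Q: (k, k) state noise,
    a0: (k,) and P0: (k, k) initial state mean/covariance.
    """

    def step(carry, y_t):
        a, P = carry                        # predicted state mean/cov at time t
        v = y_t - Z @ a                     # innovation
        F = Z @ P @ Z.T + H                 # innovation covariance
        K = jnp.linalg.solve(F, Z @ P).T    # Kalman gain: one solve per step
        ll_t = -0.5 * (jnp.linalg.slogdet(F)[1] + v @ jnp.linalg.solve(F, v))
        a_filt = a + K @ v
        P_filt = P - K @ F @ K.T
        # one-step-ahead prediction: the next iteration needs this carry,
        # which is why the T iterations are inherently sequential
        a_next = T_mat @ a_filt
        P_next = T_mat @ P_filt @ T_mat.T + Q
        return (a_next, P_next), ll_t

    _, lls = jax.lax.scan(step, (a0, P0), y)
    return lls.sum()
```

Because the whole filter is one scan, `jax.jit` compiles the loop body once and `jax.grad` differentiates straight through it; that fast compiled loop body is where JAX earns its speed, but nothing can make the T iterations run in parallel.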
So. That was a longer answer than perhaps you were hoping for. It also has essentially no actionable help for your situation, aside from the advice to switch to approximate inference (which is indeed my suggestion for you). But maybe this big list will encourage you to contribute :)
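Since I'm pointing you at approximate inference, here's a minimal sketch of what that route looks like with PyMC's built-in ADVI. The toy model is a stand-in for your statespace model, and the iteration counts are arbitrary:

```python
import pymc as pm

with pm.Model() as model:  # stand-in; use your statespace model here
    x = pm.Normal("x", 0, 1, shape=10)
    pm.Normal("y", mu=x.sum(), sigma=1.0, observed=3.0)

    # mean-field VI: one gradient evaluation per optimization step,
    # versus up to 1024 per draw in the worst-case NUTS scenario above
    approx = pm.fit(n=30_000, method="advi")
    idata = approx.sample(1_000)  # draw from the fitted approximation
```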
You may want to try running with numba / nutpie. JAX isn't always that great on CPU. You can benchmark how long a logp_dlogp eval takes and calibrate your expectations from there.
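As a rough illustration of that benchmark (a sketch using a toy model in place of the statespace model, assuming the `compile_logp`/`compile_dlogp` API of PyMC >= 4):

```python
import timeit

import pymc as pm

with pm.Model() as model:  # stand-in; use your statespace model here
    x = pm.Normal("x", 0, 1, shape=10)
    pm.Normal("y", mu=x.sum(), sigma=1.0, observed=3.0)

logp = model.compile_logp()    # compiled log-probability function
dlogp = model.compile_dlogp()  # compiled gradient function
point = model.initial_point()  # dict of starting values

n = 200
dt = timeit.timeit(lambda: (logp(point), dlogp(point)), number=n) / n
print(f"~{dt * 1e6:.0f} us per logp+dlogp evaluation")
```

Multiplying the per-eval time by draws times the typical tree size gives a rough floor on total sampling time; if that number already comes out to hours, no sampler setting will save you.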
I think this is a killer feature, and the draft PR already seems pretty good. What do we need to do to push this forward? I would be happy to help :)
Hi folks:
I've been experimenting with the pymc statespace module. One of the downsides I've read about is that, at present, statespace is incompatible with the faster samplers. That said, I've also read that JAX could be fast if only I had the right setup.
I'm using an AMD GPU and an 8-core AMD CPU (16 threads with hyperthreading), but I'm on a Windows machine. Is there a sure way to speed up JAX, other than having a simple, well-specified model?
For example, I've seen that running a Linux Docker image with a special AMD GPU config could help. I can't afford to get an Nvidia card right now.
I'll point out that this is my current script config:
```python
import jax

# force JAX onto the CPU backend
jax.config.update("jax_platform_name", "cpu")

# must run before JAX initializes its devices
import numpyro
numpyro.set_host_device_count(8)
```
Thank you so much for your time!
-Mike