Update to fast sampling notebook #794
base: main
Conversation
Pull Request Overview
This pull request updates the fast sampling notebook to provide clearer guidance on using various NUTS samplers in PyMC and their performance characteristics while also updating dependency versions in the configuration file.
- Updated dependency versions and configuration in pixi.toml
- Revised and expanded sampling examples in the fast sampling notebook, including performance comparison details
- Enhanced installation requirements and advanced usage instructions
Reviewed Changes
Copilot reviewed 4 out of 4 changed files in this pull request and generated 1 comment.
File | Description
---|---
pixi.toml | Updated dependency versions and switched the configuration section from `[project]` to `[workspace]`
examples/samplers/fast_sampling_with_jax_and_numba.myst.md | Revised sampler documentation, restructured performance comparisons, and updated kernel and watermark configurations
Comments suppressed due to low confidence (1)
pixi.toml:1
- Changing the configuration section from [project] to [workspace] may alter the expected build behavior. Please verify and update related tooling and documentation to ensure compatibility with this new structure.
[workspace]
> While this sampler can compile the underlying model to different backends (C, Numba, or JAX) using PyTensor's compilation system via the `compile_kwargs` parameter, it maintains Python overhead that can limit performance for large models

One of the masters can correct me, but I think that the Python overhead limits performance on *smaller* models. For big models, most of the compute time is going to be spent inside the `logp` or `dlogp` function, so doing the Python looping will be a lower relative cost.
Also don't forget the poor Torch backend :)
Nutpie is PyMC's cutting-edge performance sampler. Written in Rust, it eliminates Python overhead and provides exceptional performance for continuous models. The Numba backend typically offers the highest performance for most use cases, while the JAX backend excels with very large models and provides GPU acceleration capabilities. Nutpie is particularly well-suited for production workflows where sampling speed is critical.
I'd also mention that Nutpie has the SOTA NUTS adaptation algorithm, so it gets into the typical set much faster and you can get away with many fewer tuning steps as a result. That means even more speedup!
torch is just a compile mode, yes?
Yep
Yes, @jessegrabowski is correct that the performance penalty should mostly affect small models (ignoring questions of adaptation, which depend on the algorithm rather than the implementation language).
We've found many models where C is actually faster.
I know this is just generic advice, but I'd try to tell people that the default PyMC backend remains a great starting point, but then you can tinker and specialize if needed. Some random thoughts:
- PyMC is robust and flexible, and you can compile to all of the backends. Works 100% out of the box for all PyMC/PyTensor features. Numba is really great, but can have longer compile times. JAX can deadlock if you don't carefully set the `mp_ctx`, and requires `freeze_dims_and_data`.
- Nutpie is powerful and fast, but relatively new and still has sharp edges. Continuous models only! Requires gradients to be well defined everywhere, so it can fail to start on gnarly models without hand-selecting initial points. Has access to both Numba and JAX modes. JAX mode is very good for scan models.
- JAX samplers (numpyro/blackjax) offer speedups similar to Nutpie, but without the SOTA tuning algorithm. Numpyro is more battle-tested. Blackjax doesn't have a progress bar, so I never use it.
I wonder if we want to leave Blackjax out of this altogether.
We have it so it's good to mention. But I don't think it falls neatly into the schema of "do this one in case x and this one in case y".
It would be more interesting to talk about if we allowed access to arbitrary blackjax step samplers. But that's obviously nothing to do with this PR.
I wouldn't suggest numpyro for small models, if it's small just go with the default, unless you really need to fit them fast for some reason. In that case you should probably use nutpie with numba + pre-compile the model once. That's a bit out of scope here.
Line #11: `multiprocessing.set_start_method("spawn", force=True)`

Comment on why this is necessary?

Also would be good to add a warning filter for the `RuntimeWarning: os.fork() was called` warnings that JAX spams out.
Ditching this.
It was inherited from the previous version of the notebook. No idea. Open to a better (non-simulated) suggestion.
I think just to give the reader an idea of what the data look like. Not needed; just something I did not remove from the current version of the notebook.
If it's inherited I'm not going to make a fuss, but I'd still consider removing it.
You should set `xla_force_num_host_devices=8` to make the comparison fair. You're sampling sequentially in this example.
I thought `chain_method="parallel"` was the default?
OK, I've added `numpyro.set_host_device_count(8)`, which seems to do the trick.
I also thought it was the default, but apparently not.

Internally, `numpyro.set_host_device_count(8)` just does `os.environ["XLA_FLAGS"] = '--xla_force_host_platform_device_count=8'`. My preference is that we show setting that flag directly, rather than using the numpyro function. For a long time I thought that numpyro function did something special, so I was importing that whole package even when I was sampling with nutpie.
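The direct-flag approach could be sketched like this (a sketch, not the notebook's code; the flag must be set before JAX is first imported, and 8 devices is just an example matching the discussion above):

```python
import os

# Expose 8 "devices" on the host so 8 chains can sample in parallel.
# This must run before jax (or anything that imports jax, such as numpyro)
# is imported, because XLA reads the flag at startup. Appending preserves
# any flags already set in the environment.
os.environ["XLA_FLAGS"] = (
    os.environ.get("XLA_FLAGS", "")
    + " --xla_force_host_platform_device_count=8"
).strip()
```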
You don't really need a scenario for this though! I pretty much always use numba mode these days.

Would be nice to do the timings on all 3 backends and show the results in a little table.

You can also mention that you can globally request PyTensor to compile everything to a specific backend by putting e.g.:

    import pytensor
    mode = pytensor.compile.mode.get_mode('NUMBA')
    pytensor.config.mode = mode

at the top of a script/notebook. Then it's not required to pass `compile_kwargs`; it will default to Numba always (including for post-sampling stuff like `sample_posterior_predictive`, which can be important).
I think it would be good to just run all these and make a little table with timings (compile, sample, wall, es/s)
Thought about that, but figured it would be clearer to have each model run in its own cell. Let me think about how best to do this.
I also think every model should be run in its own cell. But we could still collect timings as variables by doing `time.time()` before and after the call to sample. There also might be some profiler magic @ricardoV94 or @aseyboldt know to get the compile time vs the sampling time. My proposed method would just lump them together, which will bias against numba for example.
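The before-and-after timing idea could be sketched like this (the `timed` helper and the stand-in workload are assumptions for illustration; in the notebook each timed call would wrap a real `pm.sample(...)` in its own cell, and compile time and sampling time are lumped together as noted above):

```python
import time

# Collect wall-clock timings per backend in a dict, so they can later be
# assembled into a comparison table.
timings = {}

def timed(label, fn, *args, **kwargs):
    """Run fn, record its elapsed wall-clock time under `label`."""
    start = time.time()
    result = fn(*args, **kwargs)
    timings[label] = time.time() - start
    return result

# In the notebook, a cell would look something like (hypothetical):
#   idata_numba = timed("nutpie-numba", pm.sample, nuts_sampler="nutpie")
# Here we time a cheap stand-in workload so the sketch is runnable.
timed("stand-in", sum, range(1_000_000))
```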
Lumping them together isn't terrible, as it will give an idea of timing in real-world usage. I guess I can increase the number of samples to more heavily weight sampling time.
Would be good to mention that many step samplers compile their own functions that are not the same as the logp function compiled by the model, and that `compile_kwargs` are not currently propagated to these. So in cases where you're really optimizing for speed, it can be necessary to manually declare step samplers so you can pass `compile_kwargs` to them.

For the record, `BinaryGibbsMetropolis` is not such a sampler -- it will respect `compile_kwargs` passed to `pm.sample`. But not all will (e.g. `Metropolis` and `Slice`).
ricardoV94 commented on 2025-05-30T12:42:32Z (via ReviewNB):

> This sampler is required when working with models that contain discrete variables, as it's the only option that supports non-gradient based samplers like Slice and Metropolis.

I would rewrite as,

ricardoV94 commented on 2025-05-30T12:42:33Z (via ReviewNB): How much is this numpyro ESS due to luck (random seed)? I find it suspicious it does the best, and if it does I'm sure we can get @aseyboldt to come and make nutpie kick ass instead
Provides updated guidance on how and when to use the various NUTS samplers.
Helpful links
📚 Documentation preview 📚: https://pymc-examples--794.org.readthedocs.build/en/794/