Adjust MrVI running hyperparameters for 1000s of samples #3145

Open
PierreBoyeau opened this issue Jan 14, 2025 · 7 comments · May be fixed by #3146
@PierreBoyeau
Contributor

By default, MrVI's core functions rely on vmap, which speeds up execution but increases memory usage.
This memory cost is not sustainable in scenarios with 1000s of samples.
A first step toward addressing this would be to disable vmap by default. @justjhong @canergen what do you think?

PierreBoyeau self-assigned this on Jan 14, 2025
PierreBoyeau linked a pull request on Jan 14, 2025 that will close this issue
@justjhong
Contributor

I think most users will rely on the defaults without paying attention to arguments like use_vmap. How about we automatically disable vmap when the number of samples exceeds 1000? We could then display a warning saying that we did so, unless the user explicitly passed in use_vmap=True (e.g., "vmap parallelized execution has been disabled automatically since the number of samples exceeds 1000. If you would still like to use vmap, explicitly pass in use_vmap=True"). To get this behavior, we would make the default use_vmap=None and treat that case specially.
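
A minimal sketch of the None-as-default idea described above, in plain Python. This is an illustration under stated assumptions, not the scvi-tools implementation: `resolve_use_vmap` and `VMAP_SAMPLE_LIMIT` are hypothetical names, and the threshold of 1000 is simply the number suggested in this comment.

```python
import warnings

# Hypothetical threshold, taken from the suggestion in the comment above.
VMAP_SAMPLE_LIMIT = 1000


def resolve_use_vmap(use_vmap, n_samples, limit=VMAP_SAMPLE_LIMIT):
    """Turn a tri-state use_vmap (True, False, or None) into a concrete bool."""
    if use_vmap is not None:
        # The user made an explicit choice, so respect it without warning.
        return bool(use_vmap)
    if n_samples > limit:
        # No explicit choice and many samples: fall back to the slower,
        # lower-memory path and tell the user how to override it.
        warnings.warn(
            "vmap parallelized execution has been disabled automatically since "
            f"the number of samples exceeds {limit}. If you would still like to "
            "use vmap, explicitly pass in use_vmap=True.",
            UserWarning,
            stacklevel=2,
        )
        return False
    return True
```

Using None as a sentinel is what lets the function distinguish "the user asked for vmap" from "the user did not say anything", which a plain boolean default cannot do.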

@canergen
Member

canergen commented Jan 14, 2025

Ideally, you would loop over the array along the vmap dimension in chunks of 100 (I'm not sure whether there is a built-in way to do this). I assume vmap still gives a speed-up for large sample sizes, but the speed-up likely plateaus beyond some chunk size.
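
The chunked loop suggested here could look roughly like the sketch below. It is written in plain Python so it runs without JAX: `map_in_chunks` is a hypothetical helper, and `square_all` is only a stand-in for a function that has been vectorized over the sample dimension (for example with `jax.vmap`).

```python
def map_in_chunks(batched_fn, items, chunk_size=100):
    """Apply batched_fn to consecutive chunks of items and concatenate results."""
    if chunk_size < 1:
        raise ValueError("chunk_size must be a positive integer")
    results = []
    for start in range(0, len(items), chunk_size):
        chunk = items[start:start + chunk_size]
        # Only one chunk is in flight at a time, so peak memory is bounded
        # by chunk_size rather than by the total number of samples.
        results.extend(batched_fn(chunk))
    return results


def square_all(batch):
    # Stand-in for a vectorized (vmap-ed) computation over a batch of samples.
    return [x * x for x in batch]


# 250 items with chunk_size=100 gives chunks of 100, 100, and 50.
out = map_in_chunks(square_all, list(range(250)), chunk_size=100)
```

The chunk size trades memory for speed: larger chunks keep more of the vectorized speed-up, smaller chunks lower the peak memory.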

@PierreBoyeau
Contributor Author

Thanks for the feedback. I added two things:

  1. more informative tracebacks to let users know use_vmap=False could fix OOM errors.
  2. a change in the default to use_vmap='auto' to automatically determine whether vmap makes sense.
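
As a rough illustration of the first item, a wrapper can catch a memory failure and re-raise it with a hint about use_vmap. This is a sketch under assumptions, not the code from the linked pull request: `run_with_vmap_hint` is a hypothetical name, and Python's built-in `MemoryError` stands in for whatever exception the real backend raises on OOM.

```python
def run_with_vmap_hint(fn, *args, use_vmap=True, **kwargs):
    """Call fn and, if it runs out of memory under vmap, add a helpful hint."""
    try:
        return fn(*args, use_vmap=use_vmap, **kwargs)
    except MemoryError as err:
        if not use_vmap:
            # vmap was already off, so the hint would not help; re-raise as is.
            raise
        # Chain the original error so the full traceback is preserved.
        raise MemoryError(
            "Ran out of memory while using vmap. "
            "Passing use_vmap=False may fix this, at the cost of speed."
        ) from err
```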

@canergen I have tried these batched vmaps in the VIVS code. One problem is that they significantly hurt code readability. I would prefer to avoid implementing these strategies, given how packed _model.py is. Let me know what you think!

@canergen
Member

Sounds reasonable. How long does it take now for 1000 samples and 10k cells?

@VladimirShitov

Hey! I hope you don't mind me jumping into this discussion. I am running MrVI on large-scale datasets (700k to 1.5 million cells, hundreds of samples). The method seems to work nicely, but the scaling makes me very sad. Running get_local_sample_distances() often fails with OOM errors even on a powerful compute node, and the estimated running time ranges from hours to tens of hours.

Do you have any recommendations on how to run the method efficiently?

Also related to #3166

@justjhong
Contributor

Hi @VladimirShitov, thanks for your comment. We think JAX updates have caused both issues (memory and time scaling). I opened a separate issue for the time scaling problem: #3179. We will have to figure out how to debug the problem with JAX's updates or, as a temporary fix, pin to an older version of JAX.

@VladimirShitov

Thanks @justjhong ! Looking forward to the updates :) Also, lmk if I can help with testing scalability with the datasets I have at hand
