Skip to content

Add model_to_minibatch transformation to convert all pm.Data to pm.Minibatch #7785

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Draft
wants to merge 1 commit into
base: main
Choose a base branch
from

Conversation

jessegrabowski
Copy link
Member

@jessegrabowski jessegrabowski commented May 15, 2025

Description

A pain point for me when testing different algorithms (e.g. MCMC vs VI) is that I don't want to write a 2nd version of the model with pm.Minibatch on the data.

This PR adds a model transformation that does that for the user. It's the reverse of the remove_minibatched_nodes transformer that @zaxtax implemented recently.

This is a WIP, it doesn't actually work now, because I can't figure out how to rebuild the observed variable with the total_size set correctly. Help wanted.

Related Issue

  • Closes #
  • Related to #

Checklist

Type of change

  • New feature / enhancement
  • Bug fix
  • Documentation
  • Maintenance
  • Other (please specify):

📚 Documentation preview 📚: https://pymc--7785.org.readthedocs.build/en/7785/

@jessegrabowski jessegrabowski requested a review from zaxtax May 15, 2025 12:28
@ricardoV94
Copy link
Member

This is a WIP, it doesn't actually work now, because I can't figure out how to rebuild the observed variable with the total_size set correctly. Help wanted.

You can use the lower level utility:

def create_minibatch_rv(

Then make that a vanilla observed RV

@ricardoV94
Copy link
Member

Ah you already did that, so your question is how to get total size? Grab the batch shape of the variable and constant fold it without raising if it can't be fully folded

@jessegrabowski
Copy link
Member Author

My real issue was not understanding what needs to be the key and value in the replacements, between:

  1. The model variable
  2. The memo variable
  3. The fgraph variable

@ricardoV94
Copy link
Member

ricardoV94 commented May 15, 2025

the best is usual to replace the whole fgraph ModelObservedRV by a new one. You probably have to discard any dims on the batch dimension which is an input to that op

@jessegrabowski
Copy link
Member Author

I don't really understand what that answer means

@ricardoV94
Copy link
Member

dprint the fgraph and it will perhaps be more obvious what I am mumbling

@jessegrabowski
Copy link
Member Author

The problem i was running into was that I ended up with two beta RVs after doing the replace. Beta was the only RV implicated in the ModelObservedRV sub-graph

@zaxtax
Copy link
Contributor

zaxtax commented May 15, 2025 via email

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

3 participants