
PyMCModel._data_setter hard-codes X/y data nodes #663

@drbenvincent

Description

PyMCModel._data_setter() assumes the underlying PyMC model uses data nodes named "X" and "y", and it always injects a zero-filled y with shape (n_obs, n_treated_units) during prediction. This is a tight coupling to the current internal model templates and can silently break if a subclass uses different data node names or additional observed variables. The method documentation acknowledges this assumption but the interface does not allow customization, so custom models can fail at predict time even if they otherwise follow the PyMCModel API.

Steps to Reproduce

  1. Create a PyMCModel subclass that uses different data node names (e.g., "X_design" and "y_obs").
  2. Fit the model and call predict().
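```python
import pymc as pm
from causalpy.pymc_models import PyMCModel


class CustomModel(PyMCModel):
    def build_model(self, X, y, coords):
        with self:
            self.add_coords(coords)
            # Data nodes deliberately named something other than "X" and "y"
            X_design = pm.Data("X_design", X)
            y_obs = pm.Data("y_obs", y)
            ...


# predict() calls _data_setter(), which updates only "X" and "y"
```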

Expected Behavior

predict() should update the data nodes that the subclass actually defines, or the base class should allow subclasses to override the data node names in a supported way. At minimum, the base behavior should be safe for subclasses that rename nodes.

Actual Behavior

_data_setter() calls pm.set_data({"X": X, "y": zeros(...)}), which fails or silently leaves the real data nodes untouched if they are named differently, causing incorrect predictions or runtime errors.
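To make the two failure modes concrete without depending on PyMC, here is a minimal sketch. `FakeModel` is a hypothetical dict-backed stand-in for a PyMC model, and its `set_data` raising `KeyError` for an unknown name is an assumption meant to mirror a failed name lookup; `hard_coded_data_setter` only imitates the hard-coded "X"/"y" update described above, not CausalPy's actual code.

```python
import numpy as np


class FakeModel:
    """Hypothetical stand-in for a PyMC model: data nodes are a plain dict."""

    def __init__(self, **data_nodes):
        self.data = dict(data_nodes)

    def set_data(self, new_data):
        for name, value in new_data.items():
            if name not in self.data:
                # Assumed to mirror the lookup failure for an unknown node name.
                raise KeyError(name)
            self.data[name] = value


def hard_coded_data_setter(model, X, n_treated_units=1):
    """Imitates the hard-coded "X"/"y" update described in this issue."""
    n_obs = X.shape[0]
    model.set_data({"X": X, "y": np.zeros((n_obs, n_treated_units))})


X_new = np.ones((3, 2))

# Failure mode 1: no node called "X" exists, so the update raises.
renamed = FakeModel(X_design=np.zeros((5, 2)), y_obs=np.zeros((5, 1)))
try:
    hard_coded_data_setter(renamed, X_new)
except KeyError as err:
    print("KeyError:", err)

# Failure mode 2: "X" and "y" exist, but an additional observation-indexed
# node is never updated and silently keeps its training-length data.
extra = FakeModel(X=np.zeros((5, 2)), y=np.zeros((5, 1)), w=np.zeros(5))
hard_coded_data_setter(extra, X_new)
print(extra.data["X"].shape, extra.data["w"].shape)  # (3, 2) (5,)
```

The second case is the more dangerous one: nothing raises, yet the model now mixes 3-row and 5-row data.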

Environment

  • CausalPy version: 0.7.0 (or current main)
  • Python version: 3.11+
  • OS: macOS (darwin)

Proposed Solution (if known)

More thought and decision making is needed here to balance flexibility, safety, and API stability. Possible directions:

  • Explicit configuration: add class attributes or init args like data_node_names = {"X": "X", "y": "y"} so subclasses can override without re-implementing _data_setter.
  • Override contract: make _data_setter abstract or clearly documented for subclasses to override, with a helper utility for common cases (e.g., set_data_with_obs_ind).
  • Graph inspection: discover pm.Data containers programmatically from the model graph and update any matching dims; safer but more complex and may be brittle across PyMC versions.
  • Model-level registration: provide a register_data_nodes() helper during build_model() to record which data containers should be updated for prediction.
  • Validate on fit: store the names of data nodes used in build_model() and assert they exist during predict(), raising a clear error if "X"/"y" are missing instead of silently misbehaving.
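The explicit-configuration and validation directions can be combined. The sketch below is hypothetical: `SketchModel` and `validate_data_nodes` are not part of CausalPy, `data_node_names` is the attribute floated in the first bullet, and a plain dict stands in for the model's `pm.Data` nodes. It also validates at predict time rather than recording names during fit. Subclasses override a name mapping instead of re-implementing `_data_setter`, and a missing node produces a descriptive error before any update is attempted.

```python
import numpy as np


class SketchModel:
    """Hypothetical base class; `data` stands in for the model's pm.Data nodes."""

    # Subclasses override this mapping instead of re-implementing _data_setter.
    data_node_names = {"X": "X", "y": "y"}

    def __init__(self, **data_nodes):
        self.data = dict(data_nodes)

    def validate_data_nodes(self):
        """Raise a descriptive error if a configured node is missing."""
        missing = [n for n in self.data_node_names.values() if n not in self.data]
        if missing:
            raise ValueError(
                f"{type(self).__name__} expects data nodes {missing}, "
                f"but the model defines {sorted(self.data)}; "
                "override `data_node_names` to match build_model()."
            )

    def _data_setter(self, X, n_treated_units=1):
        self.validate_data_nodes()
        n_obs = X.shape[0]
        names = self.data_node_names
        # In a real PyMCModel this would be a single pm.set_data call
        # keyed by the mapped names; the outcome gets a zero placeholder.
        self.data[names["X"]] = X
        self.data[names["y"]] = np.zeros((n_obs, n_treated_units))


class CustomModel(SketchModel):
    data_node_names = {"X": "X_design", "y": "y_obs"}


model = CustomModel(X_design=np.zeros((5, 2)), y_obs=np.zeros((5, 1)))
model._data_setter(np.ones((3, 2)))
print(model.data["X_design"].shape, model.data["y_obs"].shape)  # (3, 2) (3, 1)

# A subclass that renames nodes but forgets to update the mapping fails loudly.
try:
    SketchModel(X_design=np.zeros((5, 2)))._data_setter(np.ones((3, 2)))
except ValueError as err:
    print(err)
```

The design keeps the default behavior unchanged for existing models while turning the silent failure into an explicit one.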

Additional Context

Current implementation:

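```python
# Excerpt from PyMCModel._data_setter(); X, new_no_of_observations,
# n_treated_units and obs_coords are local variables of the method (not shown).
pm.set_data(
    {"X": X, "y": np.zeros((new_no_of_observations, n_treated_units))},
    coords={"obs_ind": obs_coords},
)
```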
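For the override-contract direction, the call above could be factored into a helper that subclasses reuse with their own node names. This is a hypothetical sketch: `set_data_with_obs_ind` is only the name suggested in the proposal list, the setter is passed in as an argument so the example runs without PyMC, and defaulting `obs_ind` to `0..n_obs-1` is an assumption that may not match how `obs_coords` is computed in CausalPy.

```python
import numpy as np


def set_data_with_obs_ind(
    set_data, X, x_name="X", y_name="y", n_treated_units=1, obs_coords=None
):
    """Hypothetical helper: update a predictor node and a placeholder outcome node.

    `set_data` must accept (new_data, coords=...), as pm.set_data does.
    """
    X = np.asarray(X)
    n_obs = X.shape[0]
    if obs_coords is None:
        # Assumption: observations are labelled 0..n_obs-1.
        obs_coords = np.arange(n_obs)
    if len(obs_coords) != n_obs:
        raise ValueError("obs_coords must have one entry per row of X")
    new_data = {
        x_name: X,
        # Zero-filled outcome placeholder, mirroring the current _data_setter.
        y_name: np.zeros((n_obs, n_treated_units)),
    }
    set_data(new_data, coords={"obs_ind": obs_coords})


# Stand-in for pm.set_data that just records what it was given.
calls = []


def record_set_data(new_data, coords=None):
    calls.append((new_data, coords))


set_data_with_obs_ind(
    record_set_data, np.ones((4, 3)), "X_design", "y_obs", n_treated_units=2
)
new_data, coords = calls[-1]
print(sorted(new_data), new_data["y_obs"].shape, coords["obs_ind"].tolist())
# ['X_design', 'y_obs'] (4, 2) [0, 1, 2, 3]
```

Inside a subclass's model context (for example under `with self:`), the real call would pass `pm.set_data` as the first argument along with the subclass's own node names.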

Metadata

  • Assignees: No one assigned
  • Labels: refactor (Refactor, clean up, or improvement with no visible changes to the user)
  • Type: No type
  • Projects: No projects
  • Milestone: No milestone
  • Relationships: None yet
  • Development: No branches or pull requests