Description
PyMCModel._data_setter() assumes the underlying PyMC model uses data nodes named "X" and "y", and it always injects a zero-filled y with shape (n_obs, n_treated_units) during prediction. This is a tight coupling to the current internal model templates and can silently break if a subclass uses different data node names or additional observed variables. The method documentation acknowledges this assumption but the interface does not allow customization, so custom models can fail at predict time even if they otherwise follow the PyMCModel API.
Steps to Reproduce
- Create a PyMCModel subclass that uses different data node names (e.g., "X_design" and "y_obs").
- Fit the model and call predict().
import pymc as pm
from causalpy.pymc_models import PyMCModel

class CustomModel(PyMCModel):
    def build_model(self, X, y, coords):
        with self:
            self.add_coords(coords)
            X_design = pm.Data("X_design", X)
            y_obs = pm.Data("y_obs", y)
            ...

# predict() calls _data_setter(), which updates only "X" and "y"
Expected Behavior
predict() should update the data nodes that the subclass actually defines, or the base class should allow subclasses to override the data node names in a supported way. At minimum, the base behavior should be safe for subclasses that rename nodes.
Actual Behavior
_data_setter() calls pm.set_data({"X": X, "y": zeros(...)}), which fails or silently leaves the real data nodes untouched if they are named differently, causing incorrect predictions or runtime errors.
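To make the two failure modes concrete, here is a dependency-free sketch. FakeModel and base_data_setter are hypothetical stand-ins, not CausalPy or PyMC code, and the sketch assumes that setting an unknown data node name raises a KeyError. It shows a renamed model failing outright and a model with an extra data node keeping stale values without any error.

```python
class FakeModel:
    """Hypothetical stand-in for a model holding named data containers."""

    def __init__(self, **nodes):
        self.nodes = dict(nodes)

    def set_data(self, new_data):
        # Assumption: an unknown node name raises, as a dict lookup would.
        for name, value in new_data.items():
            if name not in self.nodes:
                raise KeyError(f"no data node named {name!r}")
            self.nodes[name] = value


def base_data_setter(model, X, n_obs, n_treated_units=1):
    # Mirrors the hard-coded "X" and "y" names of the current _data_setter().
    zeros = [[0.0] * n_treated_units for _ in range(n_obs)]
    model.set_data({"X": X, "y": zeros})


# Case 1: renamed nodes, so the update fails at predict time.
renamed = FakeModel(X_design=[[1.0]], y_obs=[[2.0]])
try:
    base_data_setter(renamed, X=[[9.0]], n_obs=1)
    raised = False
except KeyError:
    raised = True

# Case 2: an extra node "Z" is never updated and silently keeps training data.
extra = FakeModel(X=[[1.0]], y=[[2.0]], Z=[[3.0]])
base_data_setter(extra, X=[[9.0]], n_obs=1)
```

In the second case the call succeeds, which is the harder problem to notice: predictions would be computed from a mix of new and old data.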
Environment
- CausalPy version: 0.7.0 (or current main)
- Python version: 3.11+
- OS: macOS (darwin)
Proposed Solution (if known)
More thought and decision making is needed here to balance flexibility, safety, and API stability. Possible directions:
- Explicit configuration: add class attributes or init args like data_node_names = {"X": "X", "y": "y"} so subclasses can override without re-implementing _data_setter.
- Override contract: make _data_setter abstract or clearly documented for subclasses to override, with a helper utility for common cases (e.g., set_data_with_obs_ind).
- Graph inspection: discover pm.Data containers programmatically from the model graph and update any matching dims; safer but more complex and may be brittle across PyMC versions.
- Model-level registration: provide a register_data_nodes() helper during build_model() to record which data containers should be updated for prediction.
- Validate on fit: store the names of data nodes used in build_model() and assert they exist during predict(), raising a clear error if "X"/"y" are missing instead of silently misbehaving.
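As a rough illustration of how the explicit-configuration and validation ideas could combine, here is a minimal sketch in plain Python. BaseModel is a hypothetical stand-in for PyMCModel, its nodes dict plays the role of the named pm.Data containers, and data_node_names is the proposed attribute, which does not exist today.

```python
class BaseModel:
    """Hypothetical stand-in for PyMCModel; `nodes` mimics named data containers."""

    # Proposed (not existing) attribute: logical role -> data node name.
    data_node_names = {"X": "X", "y": "y"}

    def __init__(self):
        self.nodes = {}

    def _data_setter(self, X, n_treated_units=1):
        names = self.data_node_names
        missing = sorted(n for n in names.values() if n not in self.nodes)
        if missing:
            # Fail loudly instead of leaving the real nodes untouched.
            raise ValueError(
                f"data nodes {missing} not found; override data_node_names"
            )
        self.nodes[names["X"]] = X
        # Zero placeholder with shape (n_obs, n_treated_units).
        self.nodes[names["y"]] = [[0.0] * n_treated_units for _ in X]


class CustomModel(BaseModel):
    # The subclass only declares its names; _data_setter is inherited.
    data_node_names = {"X": "X_design", "y": "y_obs"}

    def build_model(self, X, y):
        self.nodes["X_design"] = X
        self.nodes["y_obs"] = y


model = CustomModel()
model.build_model(X=[[1.0], [2.0]], y=[[0.1], [0.2]])
model._data_setter(X=[[3.0], [4.0], [5.0]])
```

A subclass that renames its nodes but forgets to override data_node_names would get the ValueError at predict time rather than incorrect predictions.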
Additional Context
Current implementation:
pm.set_data(
    {"X": X, "y": np.zeros((new_no_of_observations, n_treated_units))},
    coords={"obs_ind": obs_coords},
)
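In contrast to the hard-coded names above, a registration-based variant might look roughly like the following sketch. It is plain Python with no PyMC dependency: RegisteringModel is a hypothetical stand-in, and register_data_nodes() is the proposed helper, not an existing API. Each registered node carries a small function that produces its prediction-time value.

```python
class RegisteringModel:
    """Hypothetical stand-in; `nodes` mimics the model's named data containers."""

    def __init__(self):
        self.nodes = {}
        # name -> callable(X, n_obs) returning the prediction-time value
        self._prediction_nodes = {}

    def register_data_nodes(self, **updaters):
        unknown = sorted(set(updaters) - set(self.nodes))
        if unknown:
            raise KeyError(f"cannot register unknown data nodes: {unknown}")
        self._prediction_nodes.update(updaters)

    def _data_setter(self, X):
        n_obs = len(X)
        # Update exactly the nodes the model registered, nothing hard-coded.
        for name, make_value in self._prediction_nodes.items():
            self.nodes[name] = make_value(X, n_obs)


model = RegisteringModel()
model.nodes.update(X_design=[[1.0], [2.0]], y_obs=[[0.5], [0.7]])
model.register_data_nodes(
    X_design=lambda X, n_obs: X,
    y_obs=lambda X, n_obs: [[0.0] for _ in range(n_obs)],
)
model._data_setter([[3.0], [4.0], [5.0]])
```

The trade-off is that build_model() has to remember to register every node, but additional observed variables are handled without touching _data_setter.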