Skip to content

Commit aa78891

Browse files
committed
fix for column pollution
1 parent ab12f7d commit aa78891

File tree

2 files changed

+721
-39
lines changed

2 files changed

+721
-39
lines changed

python/src/robyn/modeling/ridge/ridge_data_builder.py

+26-13
Original file line numberDiff line numberDiff line change
@@ -12,19 +12,32 @@ def __init__(self, mmm_data, featurized_mmm_data):
1212
self.logger = logging.getLogger(__name__)
1313

1414
def _prepare_data(self, params: Dict[str, float]) -> Tuple[pd.DataFrame, pd.Series]:
15-
# Get the dependent variable
16-
# Check if 'dep_var' is in columns
17-
if "dep_var" in self.featurized_mmm_data.dt_mod.columns:
18-
# Rename 'dep_var' to the specified value
19-
self.featurized_mmm_data.dt_mod = self.featurized_mmm_data.dt_mod.rename(
20-
columns={"dep_var": self.mmm_data.mmmdata_spec.dep_var}
21-
)
22-
y = self.featurized_mmm_data.dt_mod[self.mmm_data.mmmdata_spec.dep_var]
23-
24-
# Select all columns except the dependent variable
25-
X = self.featurized_mmm_data.dt_mod.drop(
26-
columns=[self.mmm_data.mmmdata_spec.dep_var]
27-
)
15+
"""Prepare data for ridge regression, handling dependent variable and excluding date columns"""
16+
# Get the dependent variable, handling both possible column names
17+
dep_var = self.mmm_data.mmmdata_spec.dep_var
18+
if dep_var not in self.featurized_mmm_data.dt_mod.columns:
19+
# If dep_var column doesn't exist, try 'dep_var'
20+
if 'dep_var' in self.featurized_mmm_data.dt_mod.columns:
21+
# Rename 'dep_var' to the specified value
22+
self.featurized_mmm_data.dt_mod = self.featurized_mmm_data.dt_mod.rename(
23+
columns={"dep_var": dep_var}
24+
)
25+
y = self.featurized_mmm_data.dt_mod[dep_var]
26+
else:
27+
raise KeyError(f"Could not find dependent variable column. Expected either '{dep_var}' or 'dep_var' in columns: {self.featurized_mmm_data.dt_mod.columns.tolist()}")
28+
else:
29+
y = self.featurized_mmm_data.dt_mod[dep_var]
30+
31+
# Select all columns except the dependent variable and date columns
32+
exclude_cols = ['ds'] # Always exclude 'ds' if present
33+
if dep_var in self.featurized_mmm_data.dt_mod.columns:
34+
exclude_cols.append(dep_var)
35+
if 'dep_var' in self.featurized_mmm_data.dt_mod.columns:
36+
exclude_cols.append('dep_var')
37+
38+
# Only drop columns that actually exist in the dataframe
39+
exclude_cols = [col for col in exclude_cols if col in self.featurized_mmm_data.dt_mod.columns]
40+
X = self.featurized_mmm_data.dt_mod.drop(columns=exclude_cols)
2841

2942
# Convert date columns to numeric (number of days since the earliest date)
3043
date_columns = X.select_dtypes(include=["datetime64", "object"]).columns

0 commit comments

Comments
 (0)