Merge pull request #29 from databricks-industry-solutions/neuralforecast-update

update neuralforecast
ryuta-yoshimatsu authored May 18, 2024
2 parents 1da747e + 8413164 commit 61e89d0
Showing 4 changed files with 45 additions and 85 deletions.
67 changes: 41 additions & 26 deletions forecasting_sa/Forecaster.py
@@ -27,7 +27,8 @@
     TimestampType,
     BinaryType,
     ArrayType,
-    MapType,
+    IntegerType,
+    LongType,
 )
 from pyspark.sql.functions import lit, avg, min, max, col, posexplode, collect_list, to_date
 from forecasting_sa.models.abstract_model import ForecastingRegressor
@@ -168,7 +169,8 @@ def ensemble(self):
         )

         (
-            aggregated_df.withColumn("run_id", lit(self.run_id))
+            aggregated_df.withColumn(self.conf["group_id"], col(self.conf["group_id"]).cast(StringType()))
+            .withColumn("run_id", lit(self.run_id))
             .withColumn("run_date", lit(self.run_date))
             .withColumn("use_case", lit(self.conf["use_case_name"]))
             .withColumn("model", lit("ensemble"))
@@ -177,14 +179,15 @@
             .saveAsTable(self.conf["ensemble_scoring_output"])
         )

-    def prepare_data_for_global_model(self, model_conf: DictConfig, path: str) \
-            -> pd.DataFrame:
-        df = self.resolve_source(path)
-        df = df.toPandas()
-        df, removed = DataQualityChecks(df, self.conf).run()
-        if model_conf.get("data_prep", "none") == "none":
-            df[self.conf["group_id"]] = df[self.conf["group_id"]].astype(str)
-        return df, removed
+    def prepare_data_for_global_model(self, model_conf: DictConfig, mode: str):
+        src_df = self.resolve_source("train_data")
+        src_df, removed = DataQualityChecks(src_df, self.conf, self.spark).run()
+        if mode == "scoring":
+            score_df = self.resolve_source("scoring_data")
+            score_df = score_df.where(~col(self.conf["group_id"]).isin(removed))
+            src_df = src_df.unionByName(score_df, allowMissingColumns=True)
+        src_df = src_df.toPandas()
+        return src_df, removed

     def split_df_train_val(self, df: pd.DataFrame) -> Tuple[pd.DataFrame, pd.DataFrame]:
         """Splits df into train and data, based on backtest months and prediction length.
@@ -219,7 +222,7 @@ def train_models(self):
            try:
                model = self.model_registry.get_model(model_name)
                # Get training and scoring data
-               hist_df, removed = self.prepare_data_for_global_model(model_conf, "train_data")
+               hist_df, removed = self.prepare_data_for_global_model(model_conf, "training")
                train_df, val_train_df = self.split_df_train_val(hist_df)
                # Train and evaluate new models - results are saved to MLFlow
                print(f"Training model: {model}")
@@ -265,7 +268,7 @@ def train_global_model(
            mlflow.log_params(
                flatten_nested_parameters(OmegaConf.to_object(tuned_params))
            )
-           train_metrics = self.backtest_global_model(
+           self.backtest_global_model(
                model=tuned_model,
                train_df=train_df,
                val_df=val_df,
@@ -367,14 +370,28 @@ def backtest_global_model(
                start=train_df[self.conf["date_col"]].max(),
                retrain=self.conf["backtest_retrain"]))

-       res_sdf = self.spark.createDataFrame(res_pdf)\
-           .withColumn("backtest_window_start_date", to_date("backtest_window_start_date"))\
-           .withColumn("actual", col("actual").cast("array<double>"))
+       group_id_dtype = IntegerType() \
+           if train_df[self.conf["group_id"]].dtype == 'int' else StringType()
+       print(f"group_id_dtype: {group_id_dtype}")
+       print(f"train_df[self.conf['group_id']].dtype: {train_df[self.conf['group_id']].dtype}")
+       schema = StructType(
+           [
+               StructField(self.conf["group_id"], group_id_dtype),
+               StructField("backtest_window_start_date", DateType()),
+               StructField("metric_name", StringType()),
+               StructField("metric_value", DoubleType()),
+               StructField("forecast", ArrayType(DoubleType())),
+               StructField("actual", ArrayType(DoubleType())),
+               StructField("model_pickle", BinaryType()),
+           ]
+       )
+       res_sdf = self.spark.createDataFrame(res_pdf, schema)

        if write:
            if self.conf.get("evaluation_output", None):
                (
-                   res_sdf.withColumn("run_id", lit(self.run_id))
+                   res_sdf.withColumn(self.conf["group_id"], col(self.conf["group_id"]).cast(StringType()))
+                   .withColumn("run_id", lit(self.run_id))
                    .withColumn("run_date", lit(self.run_date))
                    .withColumn("model", lit(model.params.name))
                    .withColumn("use_case", lit(self.conf["use_case_name"]))
@@ -442,15 +459,15 @@ def evaluate_local_model(self, model_conf):
        )
        if self.conf.get("evaluation_output", None) is not None:
            (
-               res_sdf.withColumn("run_id", lit(self.run_id))
+               res_sdf.withColumn(self.conf["group_id"], col(self.conf["group_id"]).cast(StringType()))
+               .withColumn("run_id", lit(self.run_id))
                .withColumn("run_date", lit(self.run_date))
                .withColumn("model", lit(model_conf["name"]))
                .withColumn("use_case", lit(self.conf["use_case_name"]))
                .withColumn("model_uri", lit(""))
                .write.mode("append")
                .saveAsTable(self.conf.get("evaluation_output"))
            )
-
        res_df = (
            res_sdf.groupby(["metric_name"])
            .mean("metric_value")
@@ -506,7 +523,7 @@ def evaluate_one_model(
    def evaluate_global_model(self, model_conf):
        mlflow_client = mlflow.tracking.MlflowClient()
        with mlflow.start_run(experiment_id=self.experiment_id):
-           hist_df, removed = self.prepare_data_for_global_model(model_conf, "train_data")
+           hist_df, removed = self.prepare_data_for_global_model(model_conf, "evaluating")
            train_df, val_df = self.split_df_train_val(hist_df)
            model_name = model_conf["name"]
            mlflow.set_tag("model_name", model_conf["name"])
@@ -599,12 +616,9 @@ def score_local_model(self, model_conf):
            src_df.groupby(self.conf["group_id"])
            .applyInPandas(score_one_model_fn, schema=output_schema)
        )
-       if not isinstance(res_sdf.schema[self.conf["group_id"]].dataType, StringType):
-           res_sdf = res_sdf.withColumn(
-               self.conf["group_id"], col(self.conf["group_id"]).cast(StringType())
-           )
        (
-           res_sdf.withColumn("run_id", lit(self.run_id))
+           res_sdf.withColumn(self.conf["group_id"], col(self.conf["group_id"]).cast(StringType()))
+           .withColumn("run_id", lit(self.run_id))
            .withColumn("run_date", lit(self.run_date))
            .withColumn("use_case", lit(self.conf["use_case_name"]))
            .withColumn("model", lit(model_conf["name"]))
@@ -641,7 +655,7 @@ def score_one_model(
    def score_global_model(self, model_conf):
        print(f"Running scoring for {model_conf['name']}...")
        champion_model, champion_model_uri = self.get_model_for_scoring(model_conf)
-       score_df, removed = self.prepare_data_for_global_model(model_conf, "scoring_data")
+       score_df, removed = self.prepare_data_for_global_model(model_conf, "scoring")
        prediction_df, model_fitted = champion_model.forecast(score_df)
        if prediction_df[self.conf["date_col"]].dtype.type != np.datetime64:
            prediction_df[self.conf["date_col"]] = prediction_df[
@@ -656,7 +670,8 @@
                collect_list(self.conf["date_col"]).alias(self.conf["date_col"]),
                collect_list(self.conf["target"]).alias(self.conf["target"]))
        (
-           sdf.withColumn("model", lit(model_conf["name"]))
+           sdf.withColumn(self.conf["group_id"], col(self.conf["group_id"]).cast(StringType()))
+           .withColumn("model", lit(model_conf["name"]))
            .withColumn("run_id", lit(self.run_id))
            .withColumn("run_date", lit(self.run_date))
            .withColumn("use_case", lit(self.conf["use_case_name"]))
16 changes: 3 additions & 13 deletions forecasting_sa/data_quality_checks.py
@@ -18,16 +18,11 @@ class DataQualityChecks:
    """
    def __init__(
        self,
-       df: Union[pd.DataFrame, pyspark.sql.DataFrame],
+       df: pyspark.sql.DataFrame,
        conf: DictConfig,
        spark: SparkSession = None,
    ):
-       if isinstance(df, pd.DataFrame):
-           self.type = "pandas"
-           self.df = df
-       else:
-           self.type = "spark"
-           self.df = df.toPandas()
+       self.df = df.toPandas()
        self.conf = conf
        self.spark = spark

@@ -157,11 +152,9 @@ def run(self) -> tuple[Union[pd.DataFrame, pyspark.sql.DataFrame], list]:
            conf=self.conf,
            max_date=self.df[self.conf["date_col"]].max(),
        )
-
        clean_df = self.df.groupby(self.conf["group_id"]).apply(
            _multiple_checks_func
        )
-
        if isinstance(clean_df.index, pd.MultiIndex):
            clean_df = clean_df.drop(
                columns=[self.conf["group_id"]], errors="ignore"
@@ -185,8 +178,5 @@
        if clean_df.empty:
            raise Exception("None of the time series passed the data quality checks.")
        print(f"Finished data quality checks...")
-
-       if self.type == "spark":
-           clean_df = self.spark.createDataFrame(clean_df)
-
+       clean_df = self.spark.createDataFrame(clean_df)
        return clean_df, removed
2 changes: 1 addition & 1 deletion forecasting_sa/models/abstract_model.py
@@ -129,6 +129,6 @@ def calculate_metrics(
            "curr_date": curr_date,
            "metric_name": self.params["metric"],
            "metric_value": metric_value,
-           "forecast": pred_df[self.params["target"]].to_numpy(),
+           "forecast": pred_df[self.params["target"]].to_numpy("float"),
            "actual": val_df[self.params["target"]].to_numpy(),
            "model_pickle": cloudpickle.dumps(model_fitted)}