From 25ae5da3096b6d09733c43084a9fbf787b63acaa Mon Sep 17 00:00:00 2001 From: Jerome Dockes Date: Fri, 19 Jun 2026 17:43:47 +0200 Subject: [PATCH] add a pixi task for extracting functions from feature_engineering.py --- .gitignore | 1 + content/python_files/extract_defs.py | 74 ------------------- content/python_files/feature_engineering.py | 2 +- .../python_files/next_horizon_prediction.py | 73 +++++------------- pixi.toml | 9 ++- utils/extract_defs.py | 35 +++++++++ 6 files changed, 60 insertions(+), 134 deletions(-) delete mode 100644 content/python_files/extract_defs.py create mode 100755 utils/extract_defs.py diff --git a/.gitignore b/.gitignore index e757fa6..8fe30a1 100644 --- a/.gitignore +++ b/.gitignore @@ -84,6 +84,7 @@ content/results # jupyter-book build book/_build content/notebooks +content/python_files/*_lib.py # jupyterlite build jupyterlite/.jupyterlite.doit.db diff --git a/content/python_files/extract_defs.py b/content/python_files/extract_defs.py deleted file mode 100644 index db7ab47..0000000 --- a/content/python_files/extract_defs.py +++ /dev/null @@ -1,74 +0,0 @@ -import ast -import sys -from pathlib import Path - - -def extract_defs(source, names=None): - """Extract imports and function/class definitions, discarding other statements.""" - selected = set(names or []) - lines = source.splitlines(keepends=True) - tree = ast.parse(source) - selected_nodes = [] - - for node in tree.body: - if isinstance(node, (ast.FunctionDef, ast.ClassDef)): - if selected and node.name not in selected: - continue - selected_nodes.append(node) - - used_names = set() - if selected: - for node in selected_nodes: - for child in ast.walk(node): - if isinstance(child, ast.Name): - used_names.add(child.id) - - import_snippets = [] - def_snippets = [] - for node in tree.body: - if isinstance(node, (ast.Import, ast.ImportFrom)): - if not selected: - import_snippets.append("".join(lines[node.lineno - 1 : node.end_lineno])) - continue - - imported_names = [] - for alias in node.names: - if isinstance(node, ast.Import): - imported_names.append(alias.asname or alias.name.split(".")[0]) - else: - imported_names.append(alias.asname or alias.name) - if any(name in used_names for name in imported_names): - import_snippets.append("".join(lines[node.lineno - 1 : node.end_lineno])) - elif isinstance(node, (ast.FunctionDef, ast.ClassDef)): - if selected and node.name not in selected: - continue - lineno = node.decorator_list[0].lineno if node.decorator_list else node.lineno - def_snippets.append("".join(lines[lineno - 1 : node.end_lineno])) - out = "".join(sorted(set(import_snippets))) + "\n\n" + "\n\n".join(def_snippets) - return out - - -if __name__ == "__main__": - args = sys.argv[1:] - file_name = args[0] if args else "feature_engineering.py" - out_file = None - names = None - - i = 1 - while i < len(args): - if args[i] == "--out" and i + 1 < len(args): - out_file = args[i + 1] - i += 2 - continue - if args[i] == "--names" and i + 1 < len(args): - names = [name for name in args[i + 1].split(",") if name] - i += 2 - continue - i += 1 - - source = Path(file_name).read_text(encoding="utf-8") - extracted = extract_defs(source, names=names) - if out_file: - Path(out_file).write_text(extracted, encoding="utf-8") - else: - print(extracted) diff --git a/content/python_files/feature_engineering.py b/content/python_files/feature_engineering.py index b34afb8..d0017aa 100644 --- a/content/python_files/feature_engineering.py +++ b/content/python_files/feature_engineering.py @@ -297,7 +297,7 @@ def get_X_y(prediction_time, electricity_load_history, horizons, mode=skrub.eval def add_target_time(df, horizon): return df.with_columns( - (pl.col("target_time") + pl.duration(hours=horizon)).alias("target_time") + (pl.col("prediction_time") + pl.duration(hours=horizon)).alias("target_time") ) def add_lagged_features(df, electricity_load_history, horizon): """ diff --git a/content/python_files/next_horizon_prediction.py b/content/python_files/next_horizon_prediction.py index dd0ac57..e7c5353 100644 --- a/content/python_files/next_horizon_prediction.py +++ b/content/python_files/next_horizon_prediction.py @@ -25,9 +25,6 @@ from sklearn.ensemble import HistGradientBoostingRegressor -from extract_defs import extract_defs - - from tutorial_helpers import ( plot_lorenz_curve, @@ -36,68 +33,34 @@ plot_binned_residuals, collect_cv_predictions, ) - +from feature_engineering_lib import ( + time_range, + get_data_dir, + load_electricity_history_data, + resample, + get_X_y, + add_target_time, + add_lagged_features, + fetch_city_weather, + add_weather, + add_calendar_and_holidays, + add_features, +) # Ignore warnings from pkg_resources triggered by Python 3.13's multiprocessing. warnings.filterwarnings("ignore", category=UserWarning, module="pkg_resources") - -def load_feature_defs(): - source = Path(__file__).with_name("feature_engineering.py").read_text(encoding="utf-8") - namespace = {} - exec( - extract_defs( - source, - names=[ - "time_range", - "get_data_dir", - "load_electricity_history_data", - "resample", - "get_X_y", - "add_target_time", - "add_lagged_features", - "fetch_city_weather", - "add_weather", - "add_calendar_and_holidays", - "add_features", - ], - ), - namespace, - ) - return namespace - - -feature_defs = load_feature_defs() -time_range = feature_defs["time_range"] -load_electricity_history_data = feature_defs["load_electricity_history_data"] -resample = feature_defs["resample"] -get_X_y = feature_defs["get_X_y"] -add_features = feature_defs["add_features"] -fetch_city_weather = feature_defs["fetch_city_weather"] - -def load_or_cache(cache_file, builder): - cache_path = Path("results") / cache_file - cache_path.parent.mkdir(parents=True, exist_ok=True) - if cache_path.exists(): - with cache_path.open("rb") as f: - return cloudpickle.load(f) - obj = builder() - with cache_path.open("wb") as f: - cloudpickle.dump(obj, f) - return obj - # %% [markdown] # # For now, let's focus on the last horizon (1 hour) to train a model # predicting the electricity load at the next 1 hour. # %% -TIME_HORIZON = 1 # Focus on next step prediction -electricity_load_history = load_or_cache( - "electricity_load_history.pkl", - lambda: skrub.as_data_op(load_electricity_history_data).skb.set_name( - "load_electricity_load_data" - )().skb.apply_func(resample), +TIME_HORIZON = 1 # Focus on next step prediction +electricity_load_history = ( + skrub.as_data_op(load_electricity_history_data) + .skb.set_name("load_electricity_load_data")() + .skb.apply_func(resample) ) range_start = skrub.var("start", "2021-03-23") diff --git a/pixi.toml b/pixi.toml index 94b5d0a..910e902 100644 --- a/pixi.toml +++ b/pixi.toml @@ -52,11 +52,12 @@ python-libarchive-c = "*" execute-file-creating-pickle = { cmd = "python feature_engineering.py", cwd = "content/python_files" } create-notebooks-dir = { cmd = "mkdir -p ./content/notebooks" } copy-pickled-pipelines = { cmd = "cp ./content/python_files/*.pkl ./content/notebooks/", depends-on = ["create-notebooks-dir", "execute-file-creating-pickle"] } -copy-tutorial-helpers = { cmd = "cp ./content/python_files/tutorial_helpers.py ./content/notebooks/tutorial_helpers.py", depends-on = ["create-notebooks-dir"] } +scrape-function-definitions = { cmd = "python utils/extract_defs.py" } +copy-tutorial-helpers = { cmd = "cp ./content/python_files/tutorial_helpers.py ./content/notebooks/tutorial_helpers.py && cp ./content/python_files/*_lib.py ./content/notebooks/", depends-on = ["create-notebooks-dir", "scrape-function-definitions"] } copy-parallel-coordinates-plots = { cmd = "cp ./content/python_files/*.json ./content/notebooks/", depends-on = ["create-notebooks-dir"] } -convert-to-notebooks = { cmd = "jupytext --to notebook ./content/python_files/*.py && mv ./content/python_files/*.ipynb ./content/notebooks", depends-on = ["create-notebooks-dir", "copy-tutorial-helpers", "copy-parallel-coordinates-plots"] } -convert-to-executed-notebooks = { cmd = "jupytext --to notebook --execute ./content/python_files/*.py && mv ./content/python_files/*.ipynb ./content/notebooks", depends-on = ["create-notebooks-dir", "copy-pickled-pipelines", "copy-tutorial-helpers", "copy-parallel-coordinates-plots"] } -build-book = { cmd = "jupyter-book build book", depends-on = ["convert-to-notebooks", "copy-tutorial-helpers", "copy-parallel-coordinates-plots"] } +convert-to-notebooks = { cmd = "jupytext --to notebook ./content/python_files/*.py && rm ./content/python_files/*_lib.ipynb && mv ./content/python_files/*.ipynb ./content/notebooks", depends-on = ["create-notebooks-dir", "copy-tutorial-helpers", "scrape-function-definitions", "copy-parallel-coordinates-plots"] } +convert-to-executed-notebooks = { cmd = "jupytext --to notebook --execute ./content/python_files/*.py && mv ./content/python_files/*.ipynb ./content/notebooks", depends-on = ["create-notebooks-dir", "copy-pickled-pipelines", "copy-tutorial-helpers", "scrape-function-definitions", "copy-parallel-coordinates-plots"] } +build-book = { cmd = "jupyter-book build book", depends-on = ["convert-to-notebooks", "copy-tutorial-helpers", "scrape-function-definitions", "copy-parallel-coordinates-plots"] } build-jupyterlite = { cmd = "jupyter lite build --contents content --output-dir dist", cwd = "jupyterlite", depends-on = ["convert-to-notebooks"] } serve-jupyterlite = { cmd = "python -m http.server", cwd = "jupyterlite/dist", depends-on = ["build-jupyterlite"] } diff --git a/utils/extract_defs.py b/utils/extract_defs.py new file mode 100755 index 0000000..81bb61f --- /dev/null +++ b/utils/extract_defs.py @@ -0,0 +1,35 @@ +# /usr/bin/env python + +import argparse +import ast +import sys +from pathlib import Path + + +def extract_defs(source): + """Extract imports & function and class definitions, discarding other statements.""" + lines = source.splitlines(keepends=True) + tree = ast.parse(source) + import_snippets = [] + def_snippets = [] + for node in tree.body: + if isinstance(node, (ast.Import, ast.ImportFrom)): + import_snippets.append("".join(lines[node.lineno - 1 : node.end_lineno])) + elif isinstance(node, (ast.FunctionDef, ast.ClassDef)): + lineno = node.decorator_list[0].lineno if node.decorator_list else node.lineno + def_snippets.append("".join(lines[lineno - 1 : node.end_lineno])) + out = "".join(sorted(set(import_snippets))) + "\n\n" + "\n\n".join(def_snippets) + return out + + +if __name__ == "__main__": + repo = Path(__file__).parents[1] + for script_path in (repo / "content" / "python_files").glob("*.py"): + if script_path.stem.endswith("_lib"): + continue + source = script_path.read_text("utf-8") + output_path = script_path.with_stem(script_path.stem + "_lib") + output_path.write_text(extract_defs(source), "utf-8") + sys.stderr.write( + f"Extracted {script_path.relative_to(repo)} -> {output_path.relative_to(repo)}\n" + )