|
1 | 1 | """
|
2 |
| -Script to generate a table comparing two run for MAE values for 48 hour 15 minute forecast |
| 2 | +Script to generate analysis of MAE values for multiple model forecasts |
| 3 | +
|
| 4 | +Does this for 48 hour horizon forecasts with 15 minute granularity |
| 5 | +
|
3 | 6 | """
|
4 | 7 |
|
5 | 8 | import argparse
|
|
10 | 13 | import wandb
|
11 | 14 |
|
12 | 15 |
|
13 |
| -def main(runs: list[str], run_names: list[str]) -> None: |
| 16 | +def main(project: str, runs: list[str], run_names: list[str]) -> None: |
14 | 17 | """
|
15 |
| - Compare two runs for MAE values for 48 hour 15 minute forecast |
| 18 | + Compare MAE values for multiple model forecasts for 48 hour horizon with 15 minute granularity |
| 19 | +
|
| 20 | + Args: |
| 21 | + project: name of W&B project |
| 22 | + runs: W&B ids of runs |
| 23 | + run_names: user specified names for runs |
| 24 | +
|
16 | 25 | """
|
17 | 26 | api = wandb.Api()
|
18 | 27 | dfs = []
|
19 | 28 | epoch_num = []
|
20 | 29 | for run in runs:
|
21 |
| - run = api.run(f"openclimatefix/PROJECT/{run}") |
| 30 | + run = api.run(f"openclimatefix/{project}/{run}") |
22 | 31 |
|
23 | 32 | df = run.history(samples=run.lastHistoryStep + 1)
|
24 | 33 | # Get the columns that are in the format 'MAE_horizon/step_<number>/val`
|
@@ -88,36 +97,41 @@ def main(runs: list[str], run_names: list[str]) -> None:
|
88 | 97 | for idx, df in enumerate(dfs):
|
89 | 98 | print(f"{run_names[idx]}: {df.mean()*100:0.3f}")
|
90 | 99 |
|
91 |
| - # Plot the error on per timestep, and all timesteps |
| 100 | + # Plot the error per timestep |
92 | 101 | plt.figure()
|
93 | 102 | for idx, df in enumerate(dfs):
|
94 |
| - plt.plot(column_timesteps, df, label=f"{run_names[idx]}, epoch: {epoch_num[idx]}") |
| 103 | + plt.plot( |
| 104 | + column_timesteps, df, label=f"{run_names[idx]}, epoch: {epoch_num[idx]}", linestyle="-" |
| 105 | + ) |
95 | 106 | plt.legend()
|
96 | 107 | plt.xlabel("Timestep (minutes)")
|
97 | 108 | plt.ylabel("MAE %")
|
98 | 109 | plt.title("MAE % for each timestep")
|
99 | 110 | plt.savefig("mae_per_timestep.png")
|
100 | 111 | plt.show()
|
101 | 112 |
|
102 |
| - # Plot the error on per timestep, and grouped timesteps |
| 113 | + # Plot the error per grouped timestep |
103 | 114 | plt.figure()
|
104 | 115 | for idx, run_name in enumerate(run_names):
|
105 |
| - plt.plot(groups_df[run_name], label=f"{run_name}, epoch: {epoch_num[idx]}") |
| 116 | + plt.plot( |
| 117 | + groups_df[run_name], |
| 118 | + label=f"{run_name}, epoch: {epoch_num[idx]}", |
| 119 | + marker="o", |
| 120 | + linestyle="-", |
| 121 | + ) |
106 | 122 | plt.legend()
|
107 | 123 | plt.xlabel("Timestep (minutes)")
|
108 | 124 | plt.ylabel("MAE %")
|
109 |
| - plt.title("MAE % for each timestep") |
110 |
| - plt.savefig("mae_per_timestep.png") |
| 125 | + plt.title("MAE % for each grouped timestep") |
| 126 | + plt.savefig("mae_per_grouped_timestep.png") |
111 | 127 | plt.show()
|
112 | 128 |
|
113 | 129 |
|
114 | 130 | if __name__ == "__main__":
|
115 | 131 | parser = argparse.ArgumentParser()
|
116 |
| - "5llq8iw6" |
117 |
| - parser.add_argument("--first_run", type=str, default="xdlew7ib") |
118 |
| - parser.add_argument("--second_run", type=str, default="v3mja33d") |
| 132 | + parser.add_argument("--project", type=str, default="") |
119 | 133 | # Add arguments that is a list of strings
|
120 | 134 | parser.add_argument("--list_of_runs", nargs="+")
|
121 | 135 | parser.add_argument("--run_names", nargs="+")
|
122 | 136 | args = parser.parse_args()
|
123 |
| - main(args.list_of_runs, args.run_names) |
| 137 | + main(args.project, args.list_of_runs, args.run_names) |
0 commit comments