Skip to content

Commit acb6a36

Browse files
authored
Update MAE analysis script (#274)
Update script
1 parent 9497430 commit acb6a36

File tree

1 file changed

+28
-14
lines changed

1 file changed

+28
-14
lines changed

experiments/analysis.py experiments/mae_analysis.py

+28-14
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,8 @@
11
"""
2-
Script to generate a table comparing two run for MAE values for 48 hour 15 minute forecast
2+
Script to generate analysis of MAE values for multiple model forecasts
3+
4+
Does this for 48 hour horizon forecasts with 15 minute granularity
5+
36
"""
47

58
import argparse
@@ -10,15 +13,21 @@
1013
import wandb
1114

1215

13-
def main(runs: list[str], run_names: list[str]) -> None:
16+
def main(project: str, runs: list[str], run_names: list[str]) -> None:
1417
"""
15-
Compare two runs for MAE values for 48 hour 15 minute forecast
18+
Compare MAE values for multiple model forecasts for 48 hour horizon with 15 minute granularity
19+
20+
Args:
21+
project: name of W&B project
22+
runs: W&B ids of runs
23+
run_names: user specified names for runs
24+
1625
"""
1726
api = wandb.Api()
1827
dfs = []
1928
epoch_num = []
2029
for run in runs:
21-
run = api.run(f"openclimatefix/PROJECT/{run}")
30+
run = api.run(f"openclimatefix/{project}/{run}")
2231

2332
df = run.history(samples=run.lastHistoryStep + 1)
2433
# Get the columns that are in the format 'MAE_horizon/step_<number>/val`
@@ -88,36 +97,41 @@ def main(runs: list[str], run_names: list[str]) -> None:
8897
for idx, df in enumerate(dfs):
8998
print(f"{run_names[idx]}: {df.mean()*100:0.3f}")
9099

91-
# Plot the error on per timestep, and all timesteps
100+
# Plot the error per timestep
92101
plt.figure()
93102
for idx, df in enumerate(dfs):
94-
plt.plot(column_timesteps, df, label=f"{run_names[idx]}, epoch: {epoch_num[idx]}")
103+
plt.plot(
104+
column_timesteps, df, label=f"{run_names[idx]}, epoch: {epoch_num[idx]}", linestyle="-"
105+
)
95106
plt.legend()
96107
plt.xlabel("Timestep (minutes)")
97108
plt.ylabel("MAE %")
98109
plt.title("MAE % for each timestep")
99110
plt.savefig("mae_per_timestep.png")
100111
plt.show()
101112

102-
# Plot the error on per timestep, and grouped timesteps
113+
# Plot the error per grouped timestep
103114
plt.figure()
104115
for idx, run_name in enumerate(run_names):
105-
plt.plot(groups_df[run_name], label=f"{run_name}, epoch: {epoch_num[idx]}")
116+
plt.plot(
117+
groups_df[run_name],
118+
label=f"{run_name}, epoch: {epoch_num[idx]}",
119+
marker="o",
120+
linestyle="-",
121+
)
106122
plt.legend()
107123
plt.xlabel("Timestep (minutes)")
108124
plt.ylabel("MAE %")
109-
plt.title("MAE % for each timestep")
110-
plt.savefig("mae_per_timestep.png")
125+
plt.title("MAE % for each grouped timestep")
126+
plt.savefig("mae_per_grouped_timestep.png")
111127
plt.show()
112128

113129

114130
if __name__ == "__main__":
115131
parser = argparse.ArgumentParser()
116-
"5llq8iw6"
117-
parser.add_argument("--first_run", type=str, default="xdlew7ib")
118-
parser.add_argument("--second_run", type=str, default="v3mja33d")
132+
parser.add_argument("--project", type=str, default="")
119133
# Add arguments that is a list of strings
120134
parser.add_argument("--list_of_runs", nargs="+")
121135
parser.add_argument("--run_names", nargs="+")
122136
args = parser.parse_args()
123-
main(args.list_of_runs, args.run_names)
137+
main(args.project, args.list_of_runs, args.run_names)

0 commit comments

Comments
 (0)