"""
Script to generate a table comparing runs' MAE values for the 48-hour, 15-minute-resolution forecast.
"""

import argparse

import matplotlib.pyplot as plt
import numpy as np
import wandb


def main(runs: list[str], run_names: list[str]) -> None:
    """
    Compare runs' MAE values for the 48-hour, 15-minute-resolution forecast.
    """
    api = wandb.Api()
    dfs = []
    for run_id in runs:
        run = api.run(f"openclimatefix/india/{run_id}")

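        # run.history() returns a (sampled) pandas DataFrame of logged metrics:
        # one row per logged step, one column per metric key.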
        df = run.history()
        # Keep the columns of the form `MAE_horizon/step_<number>/val`
        mae_cols = [col for col in df.columns if "MAE_horizon/step_" in col and "val" in col]
        # Sort by numeric step so that e.g. step_10 sorts after step_2 rather than before it
        mae_cols.sort(key=lambda col: int(col.split("_")[-1].split("/")[0]))
        df = df[mae_cols]
        # Drop all rows that are entirely NaN
        df = df.dropna(how="all")
        # Find the row (logged validation step) with the smallest mean MAE across all horizons
        min_row_mean = np.inf
        min_row_idx = 0
        for idx, (_, row) in enumerate(df.iterrows()):
            if row.mean() < min_row_mean:
                min_row_mean = row.mean()
                min_row_idx = idx
        df = df.iloc[min_row_idx]
        # Convert each column's step index to a forecast horizon in minutes (15-minute steps)
        column_timesteps = [int(col.split("_")[-1].split("/")[0]) * 15 for col in mae_cols]
        dfs.append(df)
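    # dfs now holds one Series per run: the per-horizon validation MAE values
    # from that run's best (lowest mean MAE) validation step.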
    # Forecast-horizon groupings in minutes (inclusive ranges)
    groupings = [
        [0, 0],
        [15, 15],
        [30, 45],
        [45, 60],
        [60, 120],
        [120, 240],
        [240, 360],
        [360, 480],
        [480, 720],
        [720, 1440],
        [1440, 2880],
    ]
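    # Build a markdown table: one "<run name> MAE %" column per run, one row per horizon grouping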
    header = "| Timestep |"
    separator = "| --- |"
    for run_name in run_names:
        header += f" {run_name} MAE % |"
        separator += " --- |"
    print(header)
    print(separator)
    for grouping in groupings:
        group_string = f"| {grouping[0]}-{grouping[1]} minutes |"
        # Select indices from column_timesteps that fall within the grouping, inclusive
        group_idx = [
            idx
            for idx, timestep in enumerate(column_timesteps)
            if timestep >= grouping[0] and timestep <= grouping[1]
        ]
        for df in dfs:
            group_string += f" {df.iloc[group_idx].mean() * 100.0:0.3f} |"
        print(group_string)

    # Plot the per-timestep MAE for each run (as a percentage, to match the table)
    plt.figure()
    for idx, df in enumerate(dfs):
        plt.plot(column_timesteps, df * 100, label=run_names[idx])
    plt.legend()
    plt.xlabel("Timestep (minutes)")
    plt.ylabel("MAE %")
    plt.title("MAE % for each timestep")
    plt.savefig("mae_per_timestep.png")
    plt.show()


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--first_run", type=str, default="xdlew7ib")
    parser.add_argument("--second_run", type=str, default="v3mja33d")
    # Optionally compare an arbitrary list of runs, with matching display names
    parser.add_argument("--list_of_runs", nargs="+")
    parser.add_argument("--run_names", nargs="+")
    args = parser.parse_args()
    # Fall back to the two default runs, and to the run IDs as display names, if not provided
    runs = args.list_of_runs or [args.first_run, args.second_run]
    run_names = args.run_names or runs
    main(runs, run_names)
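
# Example invocation (a sketch: the filename is a placeholder for wherever this script lives,
# and the run names are just display labels; the run IDs are the defaults above):
#   python compare_forecast_mae.py --list_of_runs xdlew7ib v3mja33d --run_names baseline candidate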