|
| 1 | +import os |
| 2 | +from typing import Any, List, Tuple |
| 3 | + |
| 4 | +from eaf import ( |
| 5 | + get_empirical_attainment_surface, |
| 6 | + EmpiricalAttainmentFuncPlot, |
| 7 | +) |
| 8 | + |
| 9 | +import matplotlib.pyplot as plt |
| 10 | + |
| 11 | +import numpy as np |
| 12 | + |
| 13 | +from constants_for_ablation import ( |
| 14 | + BENCH_NAMES, |
| 15 | + CMAP, |
| 16 | + COLOR_LABEL_DICT, |
| 17 | + COSTS_SHAPE, |
| 18 | + DATASET_NAMES, |
| 19 | + LARGER_IS_BETTER_DICT, |
| 20 | + LEVELS, |
| 21 | + LINESTYLES_DICT, |
| 22 | + LOGSCALE_DICT, |
| 23 | + MARKER_DICT, |
| 24 | + N_SAMPLES, |
| 25 | + NAME_DICT, |
| 26 | + OBJ_LABEL_DICT, |
| 27 | + OBJ_NAMES_DICT, |
| 28 | +) |
| 29 | +from utils import get_costs, get_true_pareto_front_and_ref_point |
| 30 | + |
| 31 | + |
| 32 | +plt.rcParams["font.family"] = "Times New Roman" |
| 33 | +plt.rcParams["font.size"] = 18 |
| 34 | +plt.rcParams["mathtext.fontset"] = "stix" # The setting of math font |
| 35 | +PLOT_CHECK_MODE = False |
| 36 | +HV_MODE = True |
| 37 | + |
| 38 | + |
| 39 | +def disable_axis_label( |
| 40 | + ax: plt.Axes, |
| 41 | + set_xlabel: bool = False, |
| 42 | + set_ylabel: bool = False, |
| 43 | +) -> None: |
| 44 | + if not set_xlabel: |
| 45 | + ax.set_xlabel("") |
| 46 | + if not set_ylabel: |
| 47 | + ax.set_ylabel("") |
| 48 | + |
| 49 | + |
| 50 | +def plot_eaf( |
| 51 | + ax: plt.Axes, |
| 52 | + eaf_plot: EmpiricalAttainmentFuncPlot, |
| 53 | + obj_names: List[str], |
| 54 | + dataset_name: str, |
| 55 | + set_xlabel: bool, |
| 56 | + set_ylabel: bool, |
| 57 | + **kwargs |
| 58 | +) -> Tuple[List[Any], List[str]]: |
| 59 | + surfs_list, colors, labels, markers, linestyles = [], [], [], [], [] |
| 60 | + for opt_name, color_label in COLOR_LABEL_DICT.items(): |
| 61 | + color, label = color_label |
| 62 | + colors.append(color) |
| 63 | + labels.append(label) |
| 64 | + linestyles.append(LINESTYLES_DICT[opt_name]) |
| 65 | + markers.append(MARKER_DICT[opt_name]) |
| 66 | + costs = get_costs(obj_names, dataset_name, opt_name) |
| 67 | + surfs = get_empirical_attainment_surface(costs.copy(), levels=LEVELS, **kwargs) |
| 68 | + surfs_list.append(surfs) |
| 69 | + else: |
| 70 | + plot_kwargs = dict( |
| 71 | + colors=colors, labels=labels, linestyles=linestyles, markers=markers, markersize=3, alpha=0.5 |
| 72 | + ) |
| 73 | + lines = eaf_plot.plot_multiple_surface_with_band(ax, surfs_list=surfs_list, **plot_kwargs) |
| 74 | + label = "True Pareto front" |
| 75 | + lines.append(eaf_plot.plot_true_pareto_surface( |
| 76 | + ax, color="black", label=label, linestyle="--", marker="*", alpha=0.2 |
| 77 | + )) |
| 78 | + lines = lines[-1:] |
| 79 | + # labels.append(label) |
| 80 | + labels = labels[-1:] |
| 81 | + |
| 82 | + if set_xlabel: |
| 83 | + ax.set_xlabel(OBJ_LABEL_DICT[obj_names[0]]) |
| 84 | + if set_ylabel: |
| 85 | + ax.set_ylabel(OBJ_LABEL_DICT[obj_names[1]]) |
| 86 | + return lines, labels |
| 87 | + |
| 88 | + |
| 89 | +def plot_hv( |
| 90 | + ax: plt.Axes, |
| 91 | + eaf_plot: EmpiricalAttainmentFuncPlot, |
| 92 | + obj_names: List[str], |
| 93 | + dataset_name: str, |
| 94 | + log: bool, |
| 95 | + set_xlabel: bool, |
| 96 | + set_ylabel: bool, |
| 97 | + **kwargs |
| 98 | +) -> Tuple[List[Any], List[str]]: |
| 99 | + |
| 100 | + n_opts = len(COLOR_LABEL_DICT) |
| 101 | + costs_array, colors, labels, markers, linestyles = np.empty((n_opts, *COSTS_SHAPE)), [], [], [], [] |
| 102 | + for idx, (opt_name, color_label) in enumerate(COLOR_LABEL_DICT.items()): |
| 103 | + color, label = color_label |
| 104 | + colors.append(color) |
| 105 | + labels.append(label) |
| 106 | + markers.append(MARKER_DICT[opt_name]) |
| 107 | + linestyles.append(LINESTYLES_DICT[opt_name]) |
| 108 | + costs_array[idx] = get_costs(obj_names, dataset_name, opt_name) |
| 109 | + else: |
| 110 | + label = "True Pareto front" |
| 111 | + plot_kwargs = dict(colors=colors, labels=labels, markers=markers, markevery=5, linestyles=linestyles) |
| 112 | + lines = eaf_plot.plot_multiple_hypervolume2d_with_band(ax, costs_array, log=log, **plot_kwargs) |
| 113 | + lines.append( |
| 114 | + eaf_plot.plot_true_pareto_surface_hypervolume2d( |
| 115 | + ax, n_observations=N_SAMPLES, color="black", label=label, linestyle="--" |
| 116 | + ) |
| 117 | + ) |
| 118 | + labels.append(label) |
| 119 | + |
| 120 | + # Hack |
| 121 | + lines = lines[-1:] |
| 122 | + labels = labels[-1:] |
| 123 | + ax.set_ylim(ymin=0.88, ymax=1.01) |
| 124 | + |
| 125 | + disable_axis_label(ax, set_xlabel=set_xlabel, set_ylabel=set_ylabel) |
| 126 | + return lines, labels |
| 127 | + |
| 128 | + |
| 129 | +def add_colorbar(fig: plt.Figure, axes: List[List[plt.Axes]]) -> None: |
| 130 | + ZEROS = np.ones((2, 2)) |
| 131 | + levels = np.linspace(1.5, 5.0, 8) |
| 132 | + try: |
| 133 | + ax = axes[0][0] |
| 134 | + except TypeError: |
| 135 | + ax = axes[0] |
| 136 | + |
| 137 | + cb = ax.contourf(ZEROS, ZEROS, ZEROS + 5, levels=levels, cmap=CMAP) |
| 138 | + cbar = fig.colorbar(cb, ax=axes.ravel().tolist(), pad=0.025) |
| 139 | + cbar.ax.set_title("$\\eta$", y=1.01) |
| 140 | + |
| 141 | + |
| 142 | +def plot( |
| 143 | + ax: plt.Axes, |
| 144 | + bench_id: int, |
| 145 | + data_id: int, |
| 146 | + hv_mode: bool, |
| 147 | + set_xlabel: bool, |
| 148 | + set_ylabel: bool, |
| 149 | +) -> List[Any]: |
| 150 | + bench_name = BENCH_NAMES[bench_id] |
| 151 | + dataset_name = DATASET_NAMES[bench_name][data_id] |
| 152 | + obj_names = OBJ_NAMES_DICT[bench_name] |
| 153 | + kwargs = dict( |
| 154 | + larger_is_better_objectives=LARGER_IS_BETTER_DICT[bench_name], |
| 155 | + log_scale=LOGSCALE_DICT[bench_name], |
| 156 | + ) |
| 157 | + |
| 158 | + true_pf, ref_point = get_true_pareto_front_and_ref_point(obj_names, bench_name, dataset_name) |
| 159 | + eaf_plot = EmpiricalAttainmentFuncPlot(true_pareto_sols=true_pf, ref_point=ref_point, **kwargs) |
| 160 | + kwargs.update(set_xlabel=set_xlabel, set_ylabel=set_ylabel) |
| 161 | + if hv_mode: |
| 162 | + lines, labels = plot_hv(ax, eaf_plot, obj_names, dataset_name, log=False, **kwargs) |
| 163 | + else: |
| 164 | + lines, labels = plot_eaf(ax, eaf_plot, obj_names, dataset_name, **kwargs) |
| 165 | + |
| 166 | + ax.set_title(NAME_DICT[dataset_name]) |
| 167 | + ax.grid(which="minor", color="gray", linestyle=":") |
| 168 | + ax.grid(which="major", color="black") |
| 169 | + |
| 170 | + return lines, labels |
| 171 | + |
| 172 | + |
| 173 | +def plot_hv_for_hpolib(subplots_kwargs, legend_kwargs, hv_mode: bool) -> None: |
| 174 | + fig, axes = plt.subplots(**subplots_kwargs) |
| 175 | + for data_id in range(4): |
| 176 | + r, c = data_id // 2, data_id % 2 |
| 177 | + set_xlabel = r == 1 |
| 178 | + set_ylabel = c == 0 |
| 179 | + kwargs = dict( |
| 180 | + set_xlabel=set_xlabel, |
| 181 | + set_ylabel=set_ylabel, |
| 182 | + bench_id=0, |
| 183 | + data_id=data_id, |
| 184 | + hv_mode=hv_mode, |
| 185 | + ) |
| 186 | + lines, labels = plot(axes[r][c], **kwargs) |
| 187 | + else: |
| 188 | + axes[-1][0].legend(handles=lines, labels=labels, ncol=(len(labels) + 1) // 2, **legend_kwargs) |
| 189 | + |
| 190 | + add_colorbar(fig, axes) |
| 191 | + |
| 192 | + if hv_mode: |
| 193 | + if PLOT_CHECK_MODE: |
| 194 | + plt.show() |
| 195 | + else: |
| 196 | + plt.savefig("figs/hv2d-hpolib-ablation.png", bbox_inches='tight') |
| 197 | + else: |
| 198 | + if PLOT_CHECK_MODE: |
| 199 | + plt.show() |
| 200 | + else: |
| 201 | + plt.savefig("figs/eaf-hpolib-ablation.png", bbox_inches='tight') |
| 202 | + |
| 203 | + |
| 204 | +def plot_hv_for_nmt(subplots_kwargs, legend_kwargs, hv_mode: bool) -> None: |
| 205 | + fig, axes = plt.subplots(**subplots_kwargs) |
| 206 | + for data_id in range(3): |
| 207 | + set_ylabel = data_id == 0 |
| 208 | + kwargs = dict( |
| 209 | + set_xlabel=True, |
| 210 | + set_ylabel=set_ylabel, |
| 211 | + bench_id=1, |
| 212 | + data_id=data_id, |
| 213 | + hv_mode=hv_mode, |
| 214 | + ) |
| 215 | + lines, labels = plot(axes[data_id], **kwargs) |
| 216 | + else: |
| 217 | + axes[1].legend(handles=lines, labels=labels, ncol=(len(labels) + 1) // 2, **legend_kwargs) |
| 218 | + |
| 219 | + add_colorbar(fig, axes) |
| 220 | + |
| 221 | + if hv_mode: |
| 222 | + if PLOT_CHECK_MODE: |
| 223 | + plt.show() |
| 224 | + else: |
| 225 | + plt.savefig("figs/hv2d-nmt-ablation.png", bbox_inches='tight') |
| 226 | + else: |
| 227 | + if PLOT_CHECK_MODE: |
| 228 | + plt.show() |
| 229 | + else: |
| 230 | + plt.savefig("figs/eaf-nmt-ablation.png", bbox_inches='tight') |
| 231 | + |
| 232 | + |
| 233 | +if __name__ == "__main__": |
| 234 | + os.makedirs("figs/", exist_ok=True) |
| 235 | + subplots_kwargs = dict( |
| 236 | + nrows=2, |
| 237 | + ncols=2, |
| 238 | + sharex=HV_MODE, |
| 239 | + sharey=HV_MODE, |
| 240 | + figsize=(20, 10), |
| 241 | + gridspec_kw=dict( |
| 242 | + wspace=0.03 if HV_MODE else 0.09, |
| 243 | + hspace=0.125 if HV_MODE else 0.2, |
| 244 | + ) |
| 245 | + ) |
| 246 | + legend_kwargs = dict( |
| 247 | + loc='upper center', |
| 248 | + fontsize=20, |
| 249 | + bbox_to_anchor=(1.0, -0.16) if HV_MODE else (1.03, -0.16), |
| 250 | + fancybox=False, |
| 251 | + shadow=False, |
| 252 | + ) |
| 253 | + plot_hv_for_hpolib(subplots_kwargs, legend_kwargs, hv_mode=HV_MODE) |
| 254 | + |
| 255 | + subplots_kwargs.pop("nrows") |
| 256 | + subplots_kwargs.update(ncols=3, figsize=(15, 3) if HV_MODE else (20, 3.5)) |
| 257 | + legend_kwargs.update(bbox_to_anchor=(0.5, -0.3) if HV_MODE else (0.5, -0.22), fontsize=18) |
| 258 | + plot_hv_for_nmt(subplots_kwargs, legend_kwargs, hv_mode=HV_MODE) |
0 commit comments