Skip to content

Commit 28dbed3

Browse files
committed
Backup for the old laptop
1 parent 384dd6b commit 28dbed3

File tree

2 files changed

+365
-0
lines changed

2 files changed

+365
-0
lines changed

viz/ablation_study.py

+258
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,258 @@
1+
import os
2+
from typing import Any, List, Tuple
3+
4+
from eaf import (
5+
get_empirical_attainment_surface,
6+
EmpiricalAttainmentFuncPlot,
7+
)
8+
9+
import matplotlib.pyplot as plt
10+
11+
import numpy as np
12+
13+
from constants_for_ablation import (
14+
BENCH_NAMES,
15+
CMAP,
16+
COLOR_LABEL_DICT,
17+
COSTS_SHAPE,
18+
DATASET_NAMES,
19+
LARGER_IS_BETTER_DICT,
20+
LEVELS,
21+
LINESTYLES_DICT,
22+
LOGSCALE_DICT,
23+
MARKER_DICT,
24+
N_SAMPLES,
25+
NAME_DICT,
26+
OBJ_LABEL_DICT,
27+
OBJ_NAMES_DICT,
28+
)
29+
from utils import get_costs, get_true_pareto_front_and_ref_point
30+
31+
32+
plt.rcParams["font.family"] = "Times New Roman"
33+
plt.rcParams["font.size"] = 18
34+
plt.rcParams["mathtext.fontset"] = "stix" # The setting of math font
35+
PLOT_CHECK_MODE = False
36+
HV_MODE = True
37+
38+
39+
def disable_axis_label(
40+
ax: plt.Axes,
41+
set_xlabel: bool = False,
42+
set_ylabel: bool = False,
43+
) -> None:
44+
if not set_xlabel:
45+
ax.set_xlabel("")
46+
if not set_ylabel:
47+
ax.set_ylabel("")
48+
49+
50+
def plot_eaf(
51+
ax: plt.Axes,
52+
eaf_plot: EmpiricalAttainmentFuncPlot,
53+
obj_names: List[str],
54+
dataset_name: str,
55+
set_xlabel: bool,
56+
set_ylabel: bool,
57+
**kwargs
58+
) -> Tuple[List[Any], List[str]]:
59+
surfs_list, colors, labels, markers, linestyles = [], [], [], [], []
60+
for opt_name, color_label in COLOR_LABEL_DICT.items():
61+
color, label = color_label
62+
colors.append(color)
63+
labels.append(label)
64+
linestyles.append(LINESTYLES_DICT[opt_name])
65+
markers.append(MARKER_DICT[opt_name])
66+
costs = get_costs(obj_names, dataset_name, opt_name)
67+
surfs = get_empirical_attainment_surface(costs.copy(), levels=LEVELS, **kwargs)
68+
surfs_list.append(surfs)
69+
else:
70+
plot_kwargs = dict(
71+
colors=colors, labels=labels, linestyles=linestyles, markers=markers, markersize=3, alpha=0.5
72+
)
73+
lines = eaf_plot.plot_multiple_surface_with_band(ax, surfs_list=surfs_list, **plot_kwargs)
74+
label = "True Pareto front"
75+
lines.append(eaf_plot.plot_true_pareto_surface(
76+
ax, color="black", label=label, linestyle="--", marker="*", alpha=0.2
77+
))
78+
lines = lines[-1:]
79+
# labels.append(label)
80+
labels = labels[-1:]
81+
82+
if set_xlabel:
83+
ax.set_xlabel(OBJ_LABEL_DICT[obj_names[0]])
84+
if set_ylabel:
85+
ax.set_ylabel(OBJ_LABEL_DICT[obj_names[1]])
86+
return lines, labels
87+
88+
89+
def plot_hv(
90+
ax: plt.Axes,
91+
eaf_plot: EmpiricalAttainmentFuncPlot,
92+
obj_names: List[str],
93+
dataset_name: str,
94+
log: bool,
95+
set_xlabel: bool,
96+
set_ylabel: bool,
97+
**kwargs
98+
) -> Tuple[List[Any], List[str]]:
99+
100+
n_opts = len(COLOR_LABEL_DICT)
101+
costs_array, colors, labels, markers, linestyles = np.empty((n_opts, *COSTS_SHAPE)), [], [], [], []
102+
for idx, (opt_name, color_label) in enumerate(COLOR_LABEL_DICT.items()):
103+
color, label = color_label
104+
colors.append(color)
105+
labels.append(label)
106+
markers.append(MARKER_DICT[opt_name])
107+
linestyles.append(LINESTYLES_DICT[opt_name])
108+
costs_array[idx] = get_costs(obj_names, dataset_name, opt_name)
109+
else:
110+
label = "True Pareto front"
111+
plot_kwargs = dict(colors=colors, labels=labels, markers=markers, markevery=5, linestyles=linestyles)
112+
lines = eaf_plot.plot_multiple_hypervolume2d_with_band(ax, costs_array, log=log, **plot_kwargs)
113+
lines.append(
114+
eaf_plot.plot_true_pareto_surface_hypervolume2d(
115+
ax, n_observations=N_SAMPLES, color="black", label=label, linestyle="--"
116+
)
117+
)
118+
labels.append(label)
119+
120+
# Hack
121+
lines = lines[-1:]
122+
labels = labels[-1:]
123+
ax.set_ylim(ymin=0.88, ymax=1.01)
124+
125+
disable_axis_label(ax, set_xlabel=set_xlabel, set_ylabel=set_ylabel)
126+
return lines, labels
127+
128+
129+
def add_colorbar(fig: plt.Figure, axes: List[List[plt.Axes]]) -> None:
130+
ZEROS = np.ones((2, 2))
131+
levels = np.linspace(1.5, 5.0, 8)
132+
try:
133+
ax = axes[0][0]
134+
except TypeError:
135+
ax = axes[0]
136+
137+
cb = ax.contourf(ZEROS, ZEROS, ZEROS + 5, levels=levels, cmap=CMAP)
138+
cbar = fig.colorbar(cb, ax=axes.ravel().tolist(), pad=0.025)
139+
cbar.ax.set_title("$\\eta$", y=1.01)
140+
141+
142+
def plot(
143+
ax: plt.Axes,
144+
bench_id: int,
145+
data_id: int,
146+
hv_mode: bool,
147+
set_xlabel: bool,
148+
set_ylabel: bool,
149+
) -> List[Any]:
150+
bench_name = BENCH_NAMES[bench_id]
151+
dataset_name = DATASET_NAMES[bench_name][data_id]
152+
obj_names = OBJ_NAMES_DICT[bench_name]
153+
kwargs = dict(
154+
larger_is_better_objectives=LARGER_IS_BETTER_DICT[bench_name],
155+
log_scale=LOGSCALE_DICT[bench_name],
156+
)
157+
158+
true_pf, ref_point = get_true_pareto_front_and_ref_point(obj_names, bench_name, dataset_name)
159+
eaf_plot = EmpiricalAttainmentFuncPlot(true_pareto_sols=true_pf, ref_point=ref_point, **kwargs)
160+
kwargs.update(set_xlabel=set_xlabel, set_ylabel=set_ylabel)
161+
if hv_mode:
162+
lines, labels = plot_hv(ax, eaf_plot, obj_names, dataset_name, log=False, **kwargs)
163+
else:
164+
lines, labels = plot_eaf(ax, eaf_plot, obj_names, dataset_name, **kwargs)
165+
166+
ax.set_title(NAME_DICT[dataset_name])
167+
ax.grid(which="minor", color="gray", linestyle=":")
168+
ax.grid(which="major", color="black")
169+
170+
return lines, labels
171+
172+
173+
def plot_hv_for_hpolib(subplots_kwargs, legend_kwargs, hv_mode: bool) -> None:
174+
fig, axes = plt.subplots(**subplots_kwargs)
175+
for data_id in range(4):
176+
r, c = data_id // 2, data_id % 2
177+
set_xlabel = r == 1
178+
set_ylabel = c == 0
179+
kwargs = dict(
180+
set_xlabel=set_xlabel,
181+
set_ylabel=set_ylabel,
182+
bench_id=0,
183+
data_id=data_id,
184+
hv_mode=hv_mode,
185+
)
186+
lines, labels = plot(axes[r][c], **kwargs)
187+
else:
188+
axes[-1][0].legend(handles=lines, labels=labels, ncol=(len(labels) + 1) // 2, **legend_kwargs)
189+
190+
add_colorbar(fig, axes)
191+
192+
if hv_mode:
193+
if PLOT_CHECK_MODE:
194+
plt.show()
195+
else:
196+
plt.savefig("figs/hv2d-hpolib-ablation.png", bbox_inches='tight')
197+
else:
198+
if PLOT_CHECK_MODE:
199+
plt.show()
200+
else:
201+
plt.savefig("figs/eaf-hpolib-ablation.png", bbox_inches='tight')
202+
203+
204+
def plot_hv_for_nmt(subplots_kwargs, legend_kwargs, hv_mode: bool) -> None:
205+
fig, axes = plt.subplots(**subplots_kwargs)
206+
for data_id in range(3):
207+
set_ylabel = data_id == 0
208+
kwargs = dict(
209+
set_xlabel=True,
210+
set_ylabel=set_ylabel,
211+
bench_id=1,
212+
data_id=data_id,
213+
hv_mode=hv_mode,
214+
)
215+
lines, labels = plot(axes[data_id], **kwargs)
216+
else:
217+
axes[1].legend(handles=lines, labels=labels, ncol=(len(labels) + 1) // 2, **legend_kwargs)
218+
219+
add_colorbar(fig, axes)
220+
221+
if hv_mode:
222+
if PLOT_CHECK_MODE:
223+
plt.show()
224+
else:
225+
plt.savefig("figs/hv2d-nmt-ablation.png", bbox_inches='tight')
226+
else:
227+
if PLOT_CHECK_MODE:
228+
plt.show()
229+
else:
230+
plt.savefig("figs/eaf-nmt-ablation.png", bbox_inches='tight')
231+
232+
233+
if __name__ == "__main__":
234+
os.makedirs("figs/", exist_ok=True)
235+
subplots_kwargs = dict(
236+
nrows=2,
237+
ncols=2,
238+
sharex=HV_MODE,
239+
sharey=HV_MODE,
240+
figsize=(20, 10),
241+
gridspec_kw=dict(
242+
wspace=0.03 if HV_MODE else 0.09,
243+
hspace=0.125 if HV_MODE else 0.2,
244+
)
245+
)
246+
legend_kwargs = dict(
247+
loc='upper center',
248+
fontsize=20,
249+
bbox_to_anchor=(1.0, -0.16) if HV_MODE else (1.03, -0.16),
250+
fancybox=False,
251+
shadow=False,
252+
)
253+
plot_hv_for_hpolib(subplots_kwargs, legend_kwargs, hv_mode=HV_MODE)
254+
255+
subplots_kwargs.pop("nrows")
256+
subplots_kwargs.update(ncols=3, figsize=(15, 3) if HV_MODE else (20, 3.5))
257+
legend_kwargs.update(bbox_to_anchor=(0.5, -0.3) if HV_MODE else (0.5, -0.22), fontsize=18)
258+
plot_hv_for_nmt(subplots_kwargs, legend_kwargs, hv_mode=HV_MODE)

viz/constants_for_ablation.py

+107
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,107 @@
1+
import matplotlib.pyplot as plt
2+
import numpy as np
3+
4+
5+
CMAP = plt.get_cmap("rainbow")
6+
7+
8+
N_SAMPLES = 100
9+
N_INIT = N_SAMPLES * 5 // 100 # N_INIT = 5
10+
N_RUNS = 20
11+
N_OBJ = 2
12+
COSTS_SHAPE = (N_RUNS, N_SAMPLES, N_OBJ)
13+
# LEVELS = [N_RUNS // 4, N_RUNS // 2, (3 * N_RUNS) // 4]
14+
LEVELS = [N_RUNS // 2] * 3
15+
Q, DF = [0.10, 0.15][0], [1.5, 2.0, 2.5, 3.0, 3.5, 4.0, 4.5, 5.0]
16+
COLORS = [CMAP(v) for v in np.linspace(0, 1, len(DF))]
17+
NO_WAMRSTART = ["no-warmstart-", ""][1]
18+
META_LEARN_TPE = [f"{NO_WAMRSTART}tpe_q={Q:.2f}_df={df:.1f}" for df in DF]
19+
NORMAL_TPE = f"normal_tpe_q={Q:.2f}"
20+
HPOLIB = "hpolib"
21+
NMT = "nmt"
22+
HPOBENCH = "hpobench"
23+
NAVAL = "naval_propulsion"
24+
PARKINSONS = "parkinsons_telemonitoring"
25+
PROTEIN = "protein_structure"
26+
SLICE = "slice_localization"
27+
SO = "so_en"
28+
SW = "sw_en"
29+
TL = "tl_en"
30+
31+
COLOR_LABEL_DICT = {
32+
name: (color, f"$\\eta = {name.split('df=')[-1]}$")
33+
for color, name in zip(COLORS, META_LEARN_TPE)
34+
}
35+
LINESTYLES_DICT = {
36+
name: "solid"
37+
for name in META_LEARN_TPE
38+
}
39+
MARKER_DICT = {
40+
name: "*"
41+
for name in META_LEARN_TPE
42+
}
43+
BENCH_NAMES = [HPOLIB, NMT, HPOBENCH]
44+
DATASET_NAMES = {
45+
HPOLIB: [
46+
NAVAL,
47+
PARKINSONS,
48+
PROTEIN,
49+
SLICE,
50+
],
51+
NMT: [
52+
SO,
53+
SW,
54+
TL,
55+
],
56+
HPOBENCH: [
57+
"credit_g",
58+
"vehicle",
59+
"kc1",
60+
"phoneme",
61+
"blood_transfusion",
62+
"australian",
63+
"car",
64+
"segment",
65+
]
66+
}
67+
NAME_DICT = {
68+
NAVAL: "Naval Propulsion",
69+
PARKINSONS: "Parkinsons Telemonitoring",
70+
PROTEIN: "Protein Structure",
71+
SLICE: "Slice Localization",
72+
SO: "Somali to English",
73+
SW: "Swahili to English",
74+
TL: "Tagalog to English",
75+
}
76+
77+
RUNTIME = "runtime"
78+
DECODING_TIME = "decoding_time"
79+
VALID_MSE = "valid_mse"
80+
BLEU = "bleu"
81+
OBJ_NAMES_DICT = {
82+
HPOLIB: [RUNTIME, VALID_MSE],
83+
NMT: [DECODING_TIME, BLEU],
84+
HPOBENCH: ["precision", "bal_acc"]
85+
}
86+
OBJ_LABEL_DICT = {
87+
RUNTIME: "Runtime",
88+
VALID_MSE: "Validation MSE",
89+
DECODING_TIME: "Runtime",
90+
BLEU: "BLEU",
91+
}
92+
LARGER_IS_BETTER_DICT = {
93+
HPOLIB: None,
94+
NMT: [1],
95+
HPOBENCH: [0, 1],
96+
}
97+
LOGSCALE_DICT = {
98+
HPOLIB: [1],
99+
NMT: None,
100+
HPOBENCH: None,
101+
}
102+
TICK_PARAMS = dict(
103+
labelleft=False,
104+
labelbottom=False,
105+
left=False,
106+
bottom=False,
107+
)

0 commit comments

Comments
 (0)