Skip to content

Commit e58646b

Browse files
committedOct 22, 2021
Improve reward curve plot
1 parent dd5bc67 commit e58646b

File tree

1 file changed

+4
-0
lines changed

1 file changed

+4
-0
lines changed
 

‎src/reward_preprocessing/plot_reward_curves.py

+4
Original file line numberDiff line numberDiff line change
@@ -113,6 +113,8 @@ def plot_reward_curves(
113113
for row, (model_name, model_versions) in enumerate(models.items()):
114114
for col, (objective, model) in enumerate(model_versions.items()):
115115
ax[row, col].axhline(0, linewidth=0.2, color="black")
116+
for x in [200, 400, 600, 800]:
117+
ax[row, col].axvline(x, linewidth=0.2, color="gray")
116118
predicted_rewards = np.array([])
117119
for transitions_batch in dataloader:
118120
with torch.no_grad():
@@ -134,6 +136,8 @@ def plot_reward_curves(
134136
predicted_rewards,
135137
linewidth=0.4
136138
)
139+
ax[row, col].set_ylim(top=8, bottom=-6)
140+
ax[row, col].set_xlim(left=0, right=1000)
137141

138142
ax[row, col].set(
139143
title=f"{model_name} / {PRETTY_OBJECTIVE_NAMES[objective]}"

0 commit comments

Comments
 (0)