Skip to content

Commit ce468b7

Browse files
author
User
committed
update
1 parent 2b783e4 commit ce468b7

File tree

3 files changed

+21
-3
lines changed

3 files changed

+21
-3
lines changed

pytorch/plot_rl_rewards.py

+7-1
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,12 @@
1111

1212
print(f"average reward: {a.mean():.2f}, min: {a.min():.2f}, max: {a.max():.2f}")
1313

14-
plt.hist(a, bins=20)
14+
if args.mode == 'train':
15+
# show the training progress
16+
plt.plot(a)
17+
else:
18+
# test - show a histogram of rewards
19+
plt.hist(a, bins=20)
20+
1521
plt.title(args.mode)
1622
plt.show()

rl/plot_rl_rewards.py

+7-1
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,12 @@
1111

1212
print(f"average reward: {a.mean():.2f}, min: {a.min():.2f}, max: {a.max():.2f}")
1313

14-
plt.hist(a, bins=20)
14+
if args.mode == 'train':
15+
# show the training progress
16+
plt.plot(a)
17+
else:
18+
# test - show a histogram of rewards
19+
plt.hist(a, bins=20)
20+
1521
plt.title(args.mode)
1622
plt.show()

tf2.0/plot_rl_rewards.py

+7-1
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,12 @@
1111

1212
print(f"average reward: {a.mean():.2f}, min: {a.min():.2f}, max: {a.max():.2f}")
1313

14-
plt.hist(a, bins=20)
14+
if args.mode == 'train':
15+
# show the training progress
16+
plt.plot(a)
17+
else:
18+
# test - show a histogram of rewards
19+
plt.hist(a, bins=20)
20+
1521
plt.title(args.mode)
1622
plt.show()

0 commit comments

Comments
 (0)