-
Notifications
You must be signed in to change notification settings - Fork 14
/
Copy pathrun_train_ai.py
37 lines (30 loc) · 1.04 KB
/
run_train_ai.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
import gym
import gym_tetris
from statistics import mean, median
from gym_tetris.ai.QNetwork import QNetwork
def _print_round_report(total_games, total_steps, epsilon, rewards, scores):
    """Print cumulative training counters and statistics for the latest round.

    Each statistics line is formatted as ``<reward stat> / <score stat>``,
    computed over the ``rewards`` and ``scores`` returned by the most recent
    ``QNetwork.train`` call (presumably one entry per episode — confirm
    against ``QNetwork.train``).
    """
    print("==================")
    print("* Total Games: ", total_games)
    print("* Total Steps: ", total_steps)
    print("* Epsilon: ", epsilon)
    print("*")
    if not rewards or not scores:
        # sum/len, median, mean, min and max all raise on an empty sequence;
        # report the situation instead of crashing the training loop.
        print("* No episode results this round")
    else:
        print("* Average: ", sum(rewards) / len(rewards), "/", sum(scores) / len(scores))
        print("* Median: ", median(rewards), "/", median(scores))
        print("* Mean: ", mean(rewards), "/", mean(scores))
        print("* Min: ", min(rewards), "/", min(scores))
        print("* Max: ", max(rewards), "/", max(scores))
    print("==================")


def main(episodes_per_round=25, max_rounds=None):
    """Train the Tetris Q-network in rounds, saving and reporting after each.

    Creates the ``tetris-v1`` environment with ``action_mode=1``, builds a
    ``QNetwork`` and calls its ``load()`` (presumably restoring a previous
    checkpoint — confirm in ``QNetwork``). Each round then calls
    ``network.train`` and ``network.save()`` and prints cumulative
    game/step counters plus reward/score statistics.

    Args:
        episodes_per_round: Value passed as ``episodes`` to ``network.train``
            on every round (default 25, as before).
        max_rounds: Stop after this many rounds. ``None`` (the default) keeps
            training until the process is interrupted, e.g. with Ctrl-C.
    """
    env = gym.make("tetris-v1", action_mode=1)
    try:
        network = QNetwork()
        network.load()
        total_games = 0
        total_steps = 0
        rounds_done = 0
        while max_rounds is None or rounds_done < max_rounds:
            steps, rewards, scores = network.train(env, episodes=episodes_per_round)
            rounds_done += 1
            total_games += len(scores)
            total_steps += steps
            # Save every round so an interruption loses at most one round of work.
            network.save()
            _print_round_report(total_games, total_steps, network.epsilon, rewards, scores)
    except KeyboardInterrupt:
        # With max_rounds=None there is no other way out of the loop, so treat
        # Ctrl-C as a normal stop rather than dumping a traceback.
        print("Training interrupted.")
    finally:
        # Always release the environment, whether training ended normally,
        # was interrupted, or raised.
        env.close()
# Start training only when this file is executed directly, not when imported.
if __name__ == "__main__":
    main()