-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathtrain.py
68 lines (58 loc) · 2.16 KB
/
train.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
# import tensorflow as tf
import numpy as np
import pygame
from Env import Env
from reinforcement_learning.dq_agent import QNNAgent
# --- Experiment setup -------------------------------------------------------
env = Env(display=True)       # project environment; display=True opens a pygame window
agent = QNNAgent(env)         # project DQN agent; exposes get_action / train / model / eps

train = True                  # set False to load a previous model instead (loading is TODO)
reward = 0                    # last step reward; kept defined even if no step runs
total_rewards = []            # one accumulated reward total per repetition

repetition = 10               # repetitions to run in this session
repetitions = repetition      # lifetime repetition count, used only in the run name
episodes = 100                # episodes per repetition
name = f"DQN-Agent-{repetitions}x{episodes}"

if train:
    # Results tend to even out at around 10 repetitions
    for rep in range(repetition):
        total_reward = 0      # reward accumulated over every step of this repetition
        for ep in range(episodes):
            observations = env.reset()
            done = False
            steps = 0
            while not done:
                steps += 1
                action = agent.get_action(observations.reshape((-1, 24)))
                next_observations, reward = env.update_step(action)

                # Hard cap on episode length so a stuck agent cannot loop forever.
                if steps > 10000:
                    env.done = True

                # Keep the pygame window responsive; allow quitting via the
                # window close button or ESC.
                for event in pygame.event.get():
                    if event.type == pygame.QUIT:
                        # BUGFIX: was `running = False`, which nothing read —
                        # closing the window previously had no effect.
                        env.done = True
                        break
                    if event.type == pygame.KEYDOWN:
                        if event.key == pygame.K_ESCAPE:
                            env.done = True

                done = env.done
                agent.train(observations.reshape((-1, 24)),
                            next_observations.reshape((-1, 24)),
                            action, reward, done)
                observations = next_observations

                # BUGFIX: accumulate every step's reward. The original added
                # `reward` once per episode (only the final step's value).
                # NOTE(review): assumes update_step returns a per-step reward,
                # not a running total — confirm against Env.update_step.
                total_reward += reward

            print(f"Step: {steps}\t Repetition: {rep}\t Total Rewards: {total_rewards}\nEpisode: {ep}\t Total Reward: {total_reward}\t Action: {action}\t eps: {agent.eps}")
        total_rewards.append(total_reward)
    agent.model.save(".\\dqn_agent.h5")
# else-branch (loading a saved model and resuming) is not implemented yet.