Skip to content

Commit 78fe566

Browse files
committed
remove accum_trainer
1 parent 7367edd commit 78fe566

7 files changed

+102
-279
lines changed

a3c_display.py

Lines changed: 5 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -18,19 +18,7 @@
1818
from constants import USE_LSTM
1919

2020
def choose_action(pi_values):
21-
values = []
22-
sum = 0.0
23-
for rate in pi_values:
24-
sum = sum + rate
25-
value = sum
26-
values.append(value)
27-
28-
r = random.random() * sum
29-
for i in range(len(values)):
30-
if values[i] >= r:
31-
return i;
32-
#fail safe
33-
return len(values)-1
21+
return np.random.choice(range(len(pi_values)), p=pi_values)
3422

3523
# use CPU for display tool
3624
device = "/cpu:0"
@@ -49,15 +37,6 @@ def choose_action(pi_values):
4937
clip_norm = GRAD_NORM_CLIP,
5038
device = device)
5139

52-
# training_threads = []
53-
# for i in range(PARALLEL_SIZE):
54-
# training_thread = A3CTrainingThread(i, global_network, 1.0,
55-
# learning_rate_input,
56-
# grad_applier,
57-
# 8000000,
58-
# device = device)
59-
# training_threads.append(training_thread)
60-
6140
sess = tf.Session()
6241
init = tf.initialize_all_variables()
6342
sess.run(init)
@@ -78,5 +57,8 @@ def choose_action(pi_values):
7857
action = choose_action(pi_values)
7958
game_state.process(action)
8059

81-
game_state.update()
60+
if game_state.terminal:
61+
game_state.reset()
62+
else:
63+
game_state.update()
8264

a3c_training_thread.py

Lines changed: 20 additions & 36 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,6 @@
55
import time
66
import sys
77

8-
from accum_trainer import AccumTrainer
98
from game_state import GameState
109
from game_state import ACTION_SIZE
1110
from game_ac_network import GameACFFNetwork, GameACLSTMNetwork
@@ -39,18 +38,18 @@ def __init__(self,
3938

4039
self.local_network.prepare_loss(ENTROPY_BETA)
4140

42-
# TODO: don't need accum trainer anymore with batch
43-
self.trainer = AccumTrainer(device)
44-
self.trainer.prepare_minimize( self.local_network.total_loss,
45-
self.local_network.get_vars() )
46-
47-
self.accum_gradients = self.trainer.accumulate_gradients()
48-
self.reset_gradients = self.trainer.reset_gradients()
49-
41+
with tf.device(device):
42+
var_refs = [v.ref() for v in self.local_network.get_vars()]
43+
self.gradients = tf.gradients(
44+
self.local_network.total_loss, var_refs,
45+
gate_gradients=False,
46+
aggregation_method=None,
47+
colocate_gradients_with_ops=False)
48+
5049
self.apply_gradients = grad_applier.apply_gradients(
5150
global_network.get_vars(),
52-
self.trainer.get_accum_grad_list() )
53-
51+
self.gradients )
52+
5453
self.sync = self.local_network.sync_from(global_network)
5554

5655
self.game_state = GameState(113 * thread_index)
@@ -71,25 +70,14 @@ def _anneal_learning_rate(self, global_time_step):
7170
return learning_rate
7271

7372
def choose_action(self, pi_values):
74-
values = []
75-
sum = 0.0
76-
for rate in pi_values:
77-
sum = sum + rate
78-
value = sum
79-
values.append(value)
80-
81-
r = random.random() * sum
82-
for i in range(len(values)):
83-
if values[i] >= r:
84-
return i;
85-
#fail safe
86-
return len(values)-1
73+
return np.random.choice(range(len(pi_values)), p=pi_values)
8774

8875
def _record_score(self, sess, summary_writer, summary_op, score_input, score, global_t):
8976
summary_str = sess.run(summary_op, feed_dict={
9077
score_input: score
9178
})
9279
summary_writer.add_summary(summary_str, global_t)
80+
summary_writer.flush()
9381

9482
def set_start_time(self, start_time):
9583
self.start_time = start_time
@@ -102,9 +90,6 @@ def process(self, sess, global_t, summary_writer, summary_op, score_input):
10290

10391
terminal_end = False
10492

105-
# reset accumulated gradients
106-
sess.run( self.reset_gradients )
107-
10893
# copy weights from shared to local
10994
sess.run( self.sync )
11095

@@ -182,33 +167,32 @@ def process(self, sess, global_t, summary_writer, summary_op, score_input):
182167
batch_td.append(td)
183168
batch_R.append(R)
184169

170+
cur_learning_rate = self._anneal_learning_rate(global_t)
171+
185172
if USE_LSTM:
186173
batch_si.reverse()
187174
batch_a.reverse()
188175
batch_td.reverse()
189176
batch_R.reverse()
190177

191-
sess.run( self.accum_gradients,
178+
sess.run( self.apply_gradients,
192179
feed_dict = {
193180
self.local_network.s: batch_si,
194181
self.local_network.a: batch_a,
195182
self.local_network.td: batch_td,
196183
self.local_network.r: batch_R,
197184
self.local_network.initial_lstm_state: start_lstm_state,
198-
self.local_network.step_size : [len(batch_a)] } )
185+
self.local_network.step_size : [len(batch_a)],
186+
self.learning_rate_input: cur_learning_rate } )
199187
else:
200-
sess.run( self.accum_gradients,
188+
sess.run( self.apply_gradients,
201189
feed_dict = {
202190
self.local_network.s: batch_si,
203191
self.local_network.a: batch_a,
204192
self.local_network.td: batch_td,
205-
self.local_network.r: batch_R} )
193+
self.local_network.r: batch_R,
194+
self.learning_rate_input: cur_learning_rate} )
206195

207-
cur_learning_rate = self._anneal_learning_rate(global_t)
208-
209-
sess.run( self.apply_gradients,
210-
feed_dict = { self.learning_rate_input: cur_learning_rate } )
211-
212196
if (self.thread_index == 0) and (self.local_t - self.prev_local_t >= PERFORMANCE_LOG_INTERVAL):
213197
self.prev_local_t += PERFORMANCE_LOG_INTERVAL
214198
elapsed_time = time.time() - self.start_time

accum_trainer.py

Lines changed: 0 additions & 62 deletions
This file was deleted.

accum_trainer_test.py

Lines changed: 0 additions & 89 deletions
This file was deleted.

constants.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -17,5 +17,5 @@
1717
ENTROPY_BETA = 0.01 # entropy regurarlization constant
1818
MAX_TIME_STEP = 10 * 10**7
1919
GRAD_NORM_CLIP = 40.0 # gradient norm clipping
20-
USE_GPU = False # To use GPU, set True
21-
USE_LSTM = False # True for A3C LSTM, False for A3C FF
20+
USE_GPU = True # To use GPU, set True
21+
USE_LSTM = True # True for A3C LSTM, False for A3C FF

0 commit comments

Comments (0)