@@ -5,7 +5,6 @@
import time
import sys

-from accum_trainer import AccumTrainer
from game_state import GameState
from game_state import ACTION_SIZE
from game_ac_network import GameACFFNetwork, GameACLSTMNetwork
@@ -39,18 +38,18 @@ def __init__(self,

    self.local_network.prepare_loss(ENTROPY_BETA)

-    # TODO: don't need accum trainer anymore with batch
-    self.trainer = AccumTrainer(device)
-    self.trainer.prepare_minimize( self.local_network.total_loss,
-                                   self.local_network.get_vars() )
-
-    self.accum_gradients = self.trainer.accumulate_gradients()
-    self.reset_gradients = self.trainer.reset_gradients()
-
+    with tf.device(device):
+      var_refs = [v.ref() for v in self.local_network.get_vars()]
+      self.gradients = tf.gradients(
+        self.local_network.total_loss, var_refs,
+        gate_gradients=False,
+        aggregation_method=None,
+        colocate_gradients_with_ops=False)
+
    self.apply_gradients = grad_applier.apply_gradients(
      global_network.get_vars(),
-      self.trainer.get_accum_grad_list() )
-
+      self.gradients )
+
    self.sync = self.local_network.sync_from(global_network)

    self.game_state = GameState(113 * thread_index)
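Note on the hunk above: the removed AccumTrainer is replaced by a direct tf.gradients call. The returned list is ordered like local_network.get_vars(), so grad_applier can pair each local gradient with the matching variable of the shared global network. A minimal TF1-style sketch of that local-gradients/global-update pattern (the toy variables and the plain GradientDescentOptimizer are illustrative stand-ins, not this repo's RMSPropApplier):

    import tensorflow as tf

    local_w  = tf.Variable([1.0, 2.0])   # stands in for a local network variable
    global_w = tf.Variable([1.0, 2.0])   # shared variable of the same shape

    loss  = tf.reduce_sum(tf.square(local_w))
    grads = tf.gradients(loss, [local_w])            # one tensor per variable

    lr = tf.placeholder(tf.float32)                  # like learning_rate_input
    train = tf.train.GradientDescentOptimizer(lr).apply_gradients(
        zip(grads, [global_w]))                      # local grads -> global vars

    with tf.Session() as sess:
        sess.run(tf.global_variables_initializer())
        sess.run(train, feed_dict={lr: 0.1})         # updates global_w only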
@@ -71,25 +70,14 @@ def _anneal_learning_rate(self, global_time_step):
    return learning_rate

  def choose_action(self, pi_values):
-    values = []
-    sum = 0.0
-    for rate in pi_values:
-      sum = sum + rate
-      value = sum
-      values.append(value)
-
-    r = random.random() * sum
-    for i in range(len(values)):
-      if values[i] >= r:
-        return i;
-    #fail safe
-    return len(values)-1
+    return np.random.choice(range(len(pi_values)), p=pi_values)

  def _record_score(self, sess, summary_writer, summary_op, score_input, score, global_t):
    summary_str = sess.run(summary_op, feed_dict={
      score_input: score
    })
    summary_writer.add_summary(summary_str, global_t)
+    summary_writer.flush()

  def set_start_time(self, start_time):
    self.start_time = start_time
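Two notes on the hunk above. First, np.random.choice draws index i with probability pi_values[i], replacing the hand-rolled cumulative-sum sampler; unlike the old fail-safe branch, it raises ValueError when the probabilities do not sum to (approximately) 1. Second, summary_writer.flush() forces buffered score summaries out to disk so TensorBoard picks them up promptly. A quick empirical check of the sampler (the draw count is arbitrary):

    import numpy as np

    pi_values = np.array([0.1, 0.2, 0.7])        # e.g. softmax policy output
    draws = [np.random.choice(range(len(pi_values)), p=pi_values)
             for _ in range(10000)]
    print(np.bincount(draws) / 10000.0)          # approaches [0.1, 0.2, 0.7]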
@@ -102,9 +90,6 @@ def process(self, sess, global_t, summary_writer, summary_op, score_input):

    terminal_end = False

-    # reset accumulated gradients
-    sess.run( self.reset_gradients )
-
    # copy weights from shared to local
    sess.run( self.sync )

@@ -182,33 +167,32 @@ def process(self, sess, global_t, summary_writer, summary_op, score_input):
      batch_td.append(td)
      batch_R.append(R)

+    cur_learning_rate = self._anneal_learning_rate(global_t)
+
    if USE_LSTM:
      batch_si.reverse()
      batch_a.reverse()
      batch_td.reverse()
      batch_R.reverse()

-      sess.run( self.accum_gradients,
+      sess.run( self.apply_gradients,
                feed_dict = {
                  self.local_network.s: batch_si,
                  self.local_network.a: batch_a,
                  self.local_network.td: batch_td,
                  self.local_network.r: batch_R,
                  self.local_network.initial_lstm_state: start_lstm_state,
-                  self.local_network.step_size : [len(batch_a)] } )
+                  self.local_network.step_size : [len(batch_a)],
+                  self.learning_rate_input: cur_learning_rate } )
    else:
-      sess.run( self.accum_gradients,
+      sess.run( self.apply_gradients,
                feed_dict = {
                  self.local_network.s: batch_si,
                  self.local_network.a: batch_a,
                  self.local_network.td: batch_td,
-                  self.local_network.r: batch_R } )
+                  self.local_network.r: batch_R,
+                  self.learning_rate_input: cur_learning_rate } )

-    cur_learning_rate = self._anneal_learning_rate(global_t)
-
-    sess.run( self.apply_gradients,
-              feed_dict = { self.learning_rate_input: cur_learning_rate } )
-
    if (self.thread_index == 0) and (self.local_t - self.prev_local_t >= PERFORMANCE_LOG_INTERVAL):
      self.prev_local_t += PERFORMANCE_LOG_INTERVAL
      elapsed_time = time.time() - self.start_time
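With the hunk above, the old accumulate-then-apply two-step collapses into a single sess.run of apply_gradients: the annealed learning rate is computed up front and fed through the learning_rate_input placeholder in the same feed_dict as the batch, which is also why reset_gradients disappeared from process() earlier. A self-contained sketch of feeding a per-step learning rate in one fetch (toy model; tf.train.RMSPropOptimizer stands in for this repo's RMSPropApplier):

    import tensorflow as tf

    x  = tf.placeholder(tf.float32, [None, 1])
    lr = tf.placeholder(tf.float32)              # like learning_rate_input
    w  = tf.Variable([[0.0]])
    loss  = tf.reduce_mean(tf.square(tf.matmul(x, w) - 1.0))
    train = tf.train.RMSPropOptimizer(lr).minimize(loss)

    with tf.Session() as sess:
        sess.run(tf.global_variables_initializer())
        # One fetch runs the forward pass, the gradients and the update,
        # with the annealed learning rate supplied in the same feed_dict.
        sess.run(train, feed_dict={x: [[1.0], [2.0]], lr: 7e-4})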