@@ -1,16 +1,17 @@
 # -*- coding: utf-8 -*-
 import tensorflow as tf
 import numpy as np
-from custom_lstm import CustomBasicLSTMCell

 # Actor-Critic Network Base Class
 # (Policy network and Value network)
 class GameACNetwork(object):
   def __init__(self,
                action_size,
+               thread_index, # -1 for global
                device="/cpu:0"):
-    self._device = device
     self._action_size = action_size
+    self._thread_index = thread_index
+    self._device = device

   def prepare_loss(self, entropy_beta):
     with tf.device(self._device):
@@ -94,10 +95,12 @@ def _conv2d(self, x, W, stride):
 class GameACFFNetwork(GameACNetwork):
   def __init__(self,
                action_size,
+               thread_index, # -1 for global
                device="/cpu:0"):
-    GameACNetwork.__init__(self, action_size, device)
-
-    with tf.device(self._device):
+    GameACNetwork.__init__(self, action_size, thread_index, device)
+
+    scope_name = "net_" + str(self._thread_index)
+    with tf.device(self._device), tf.variable_scope(scope_name) as scope:
       self.W_conv1, self.b_conv1 = self._conv_variable([8, 8, 4, 16])  # stride=4
       self.W_conv2, self.b_conv2 = self._conv_variable([4, 4, 16, 32]) # stride=2

@@ -149,16 +152,17 @@ def __init__(self,
                action_size,
                thread_index, # -1 for global
                device="/cpu:0"):
-    GameACNetwork.__init__(self, action_size, device)
+    GameACNetwork.__init__(self, action_size, thread_index, device)

-    with tf.device(self._device):
+    scope_name = "net_" + str(self._thread_index)
+    with tf.device(self._device), tf.variable_scope(scope_name) as scope:
       self.W_conv1, self.b_conv1 = self._conv_variable([8, 8, 4, 16])  # stride=4
       self.W_conv2, self.b_conv2 = self._conv_variable([4, 4, 16, 32]) # stride=2

       self.W_fc1, self.b_fc1 = self._fc_variable([2592, 256])

       # lstm
-      self.lstm = CustomBasicLSTMCell(256, state_is_tuple=True)
+      self.lstm = tf.nn.rnn_cell.BasicLSTMCell(256, state_is_tuple=True)

       # weight for policy output layer
       self.W_fc2, self.b_fc2 = self._fc_variable([256, action_size])
@@ -187,8 +191,6 @@ def __init__(self,
       self.initial_lstm_state = tf.nn.rnn_cell.LSTMStateTuple(self.initial_lstm_state0,
                                                               self.initial_lstm_state1)

-      scope = "net_" + str(thread_index)
-
       # Unrolling LSTM up to LOCAL_T_MAX time steps. (= 5 time steps.)
       # When an episode terminates, the unrolling step count becomes less than LOCAL_T_MAX.
       # Unrolling step size is applied via self.step_size placeholder.
@@ -212,6 +214,10 @@ def __init__(self,
       v_ = tf.matmul(lstm_outputs, self.W_fc3) + self.b_fc3
       self.v = tf.reshape( v_, [-1] )

+      scope.reuse_variables()
+      self.W_lstm = tf.get_variable("BasicLSTMCell/Linear/Matrix")
+      self.b_lstm = tf.get_variable("BasicLSTMCell/Linear/Bias")
+
     self.reset_state()

   def reset_state(self):
@@ -259,6 +265,6 @@ def get_vars(self):
     return [self.W_conv1, self.b_conv1,
             self.W_conv2, self.b_conv2,
             self.W_fc1, self.b_fc1,
-            self.lstm.matrix, self.lstm.bias,
+            self.W_lstm, self.b_lstm,
             self.W_fc2, self.b_fc2,
             self.W_fc3, self.b_fc3]
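
Note on usage: the per-thread variable scopes ("net_<thread_index>") introduced above keep each worker's parameters separate, and get_vars() now also returns the LSTM kernel and bias (W_lstm, b_lstm) fetched via tf.get_variable, so the list covers every trainable parameter. A minimal sketch of how the variable lists of a global network (thread_index = -1) and a worker network could be paired to build a sync operation follows; GameACFFNetwork comes from this file, while ACTION_SIZE and the sync_op wiring are illustrative assumptions, not part of this commit (the LSTM variant would be synced the same way).

# Illustrative sketch (not part of this commit): copy the global network's
# parameters into a worker's local copy using get_vars().
import tensorflow as tf

ACTION_SIZE = 4  # placeholder; the real value comes from the game's action space

global_network = GameACFFNetwork(ACTION_SIZE, -1, device="/cpu:0")  # scope "net_-1"
local_network  = GameACFFNetwork(ACTION_SIZE,  0, device="/cpu:0")  # scope "net_0"

# One assign op per (source, destination) variable pair; tf.group bundles
# them so a single sess.run() refreshes every local parameter.
sync_op = tf.group(*[dst.assign(src)
                     for src, dst in zip(global_network.get_vars(),
                                         local_network.get_vars())])

# A worker would run the op before each local rollout:
#   sess.run(sync_op)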