Skip to content

Commit 5c13b4a

Browse files
committed
replace CustomBasicLSTMCell to original BasicLSTMCell
1 parent 78fe566 commit 5c13b4a

6 files changed

+21
-147
lines changed

a3c.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -49,7 +49,7 @@ def log_uniform(lo, hi, rate):
4949
if USE_LSTM:
5050
global_network = GameACLSTMNetwork(ACTION_SIZE, -1, device)
5151
else:
52-
global_network = GameACFFNetwork(ACTION_SIZE, device)
52+
global_network = GameACFFNetwork(ACTION_SIZE, -1, device)
5353

5454

5555
training_threads = []

a3c_display.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,7 @@ def choose_action(pi_values):
2626
if USE_LSTM:
2727
global_network = GameACLSTMNetwork(ACTION_SIZE, -1, device)
2828
else:
29-
global_network = GameACFFNetwork(ACTION_SIZE, device)
29+
global_network = GameACFFNetwork(ACTION_SIZE, -1, device)
3030

3131
learning_rate_input = tf.placeholder("float")
3232

a3c_training_thread.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -34,7 +34,7 @@ def __init__(self,
3434
if USE_LSTM:
3535
self.local_network = GameACLSTMNetwork(ACTION_SIZE, thread_index, device)
3636
else:
37-
self.local_network = GameACFFNetwork(ACTION_SIZE, device)
37+
self.local_network = GameACFFNetwork(ACTION_SIZE, thread_index, device)
3838

3939
self.local_network.prepare_loss(ENTROPY_BETA)
4040

a3c_visualize.py

+1-8
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,7 @@
2727
if USE_LSTM:
2828
global_network = GameACLSTMNetwork(ACTION_SIZE, -1, device)
2929
else:
30-
global_network = GameACFFNetwork(ACTION_SIZE, device)
30+
global_network = GameACFFNetwork(ACTION_SIZE, -1, device)
3131

3232
training_threads = []
3333

@@ -40,13 +40,6 @@
4040
clip_norm = GRAD_NORM_CLIP,
4141
device = device)
4242

43-
# for i in range(PARALLEL_SIZE):
44-
# training_thread = A3CTrainingThread(i, global_network, 1.0,
45-
# learning_rate_input,
46-
# grad_applier, MAX_TIME_STEP,
47-
# device = device)
48-
# training_threads.append(training_thread)
49-
5043
sess = tf.Session()
5144
init = tf.initialize_all_variables()
5245
sess.run(init)

custom_lstm.py

-125
This file was deleted.

game_ac_network.py

+17-11
Original file line numberDiff line numberDiff line change
@@ -1,16 +1,17 @@
11
# -*- coding: utf-8 -*-
22
import tensorflow as tf
33
import numpy as np
4-
from custom_lstm import CustomBasicLSTMCell
54

65
# Actor-Critic Network Base Class
76
# (Policy network and Value network)
87
class GameACNetwork(object):
98
def __init__(self,
109
action_size,
10+
thread_index, # -1 for global
1111
device="/cpu:0"):
12-
self._device = device
1312
self._action_size = action_size
13+
self._thread_index = thread_index
14+
self._device = device
1415

1516
def prepare_loss(self, entropy_beta):
1617
with tf.device(self._device):
@@ -94,10 +95,12 @@ def _conv2d(self, x, W, stride):
9495
class GameACFFNetwork(GameACNetwork):
9596
def __init__(self,
9697
action_size,
98+
thread_index, # -1 for global
9799
device="/cpu:0"):
98-
GameACNetwork.__init__(self, action_size, device)
99-
100-
with tf.device(self._device):
100+
GameACNetwork.__init__(self, action_size, thread_index, device)
101+
102+
scope_name = "net_" + str(self._thread_index)
103+
with tf.device(self._device), tf.variable_scope(scope_name) as scope:
101104
self.W_conv1, self.b_conv1 = self._conv_variable([8, 8, 4, 16]) # stride=4
102105
self.W_conv2, self.b_conv2 = self._conv_variable([4, 4, 16, 32]) # stride=2
103106

@@ -149,16 +152,17 @@ def __init__(self,
149152
action_size,
150153
thread_index, # -1 for global
151154
device="/cpu:0" ):
152-
GameACNetwork.__init__(self, action_size, device)
155+
GameACNetwork.__init__(self, action_size, thread_index, device)
153156

154-
with tf.device(self._device):
157+
scope_name = "net_" + str(self._thread_index)
158+
with tf.device(self._device), tf.variable_scope(scope_name) as scope:
155159
self.W_conv1, self.b_conv1 = self._conv_variable([8, 8, 4, 16]) # stride=4
156160
self.W_conv2, self.b_conv2 = self._conv_variable([4, 4, 16, 32]) # stride=2
157161

158162
self.W_fc1, self.b_fc1 = self._fc_variable([2592, 256])
159163

160164
# lstm
161-
self.lstm = CustomBasicLSTMCell(256, state_is_tuple=True)
165+
self.lstm = tf.nn.rnn_cell.BasicLSTMCell(256, state_is_tuple=True)
162166

163167
# weight for policy output layer
164168
self.W_fc2, self.b_fc2 = self._fc_variable([256, action_size])
@@ -187,8 +191,6 @@ def __init__(self,
187191
self.initial_lstm_state = tf.nn.rnn_cell.LSTMStateTuple(self.initial_lstm_state0,
188192
self.initial_lstm_state1)
189193

190-
scope = "net_" + str(thread_index)
191-
192194
# Unrolling LSTM up to LOCAL_T_MAX time steps. (= 5time steps.)
193195
# When episode terminates unrolling time steps becomes less than LOCAL_TIME_STEP.
194196
# Unrolling step size is applied via self.step_size placeholder.
@@ -212,6 +214,10 @@ def __init__(self,
212214
v_ = tf.matmul(lstm_outputs, self.W_fc3) + self.b_fc3
213215
self.v = tf.reshape( v_, [-1] )
214216

217+
scope.reuse_variables()
218+
self.W_lstm = tf.get_variable("BasicLSTMCell/Linear/Matrix")
219+
self.b_lstm = tf.get_variable("BasicLSTMCell/Linear/Bias")
220+
215221
self.reset_state()
216222

217223
def reset_state(self):
@@ -259,6 +265,6 @@ def get_vars(self):
259265
return [self.W_conv1, self.b_conv1,
260266
self.W_conv2, self.b_conv2,
261267
self.W_fc1, self.b_fc1,
262-
self.lstm.matrix, self.lstm.bias,
268+
self.W_lstm, self.b_lstm,
263269
self.W_fc2, self.b_fc2,
264270
self.W_fc3, self.b_fc3]

0 commit comments

Comments
 (0)