
Commit 6ca2a63

work for TL2 TF2
1 parent f88f23e commit 6ca2a63

5 files changed: +206 -280 lines

examples/reinforcement_learning/tutorial_atari_pong.py (+14 -37)
@@ -29,16 +29,11 @@
 import time
 
 import numpy as np
-import tensorflow as tf
 
 import gym
+import tensorflow as tf
 import tensorlayer as tl
 
-## enable eager mode
-tf.enable_eager_execution()
-
-
-tf.logging.set_verbosity(tf.logging.DEBUG) # enable logging
 tl.logging.set_verbosity(tl.logging.DEBUG)
 
 # hyper-parameters
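
Note, not part of the diff: TensorFlow 2.x executes eagerly by default and no longer ships the tf.logging module, which is why the enable_eager_execution() and tf.logging lines above can simply be dropped while tl.logging stays. A minimal sanity check, assuming a TF 2.x install:

    import tensorflow as tf

    print(tf.__version__)          # expected to start with "2."
    print(tf.executing_eagerly())  # True by default in TF2; no enable_eager_execution() call needed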
@@ -52,7 +47,7 @@
 render = False # display the game environment
 # resume = True # load existing policy network
 model_file_name = "model_pong"
-np.set_printoptions(threshold=np.nan)
+np.set_printoptions(threshold=np.inf)
 
 
 def prepro(I):
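
Aside, not part of the diff: newer NumPy releases reject threshold=np.nan in set_printoptions, and np.inf is the documented way to disable array summarization, which is what this line is after. A quick illustration:

    import numpy as np

    np.set_printoptions(threshold=np.inf)  # never summarize; always print full arrays
    print(np.arange(2000))                 # printed in full rather than "[   0    1 ... 1998 1999]"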
@@ -73,35 +68,23 @@ def prepro(I):
 episode_number = 0
 
 xs, ys, rs = [], [], []
-# observation for training and inference
-# t_states = tf.placeholder(tf.float32, shape=[None, D])
-# policy network
 
+
+# policy network
 def get_model(inputs_shape):
     ni = tl.layers.Input(inputs_shape)
     nn = tl.layers.Dense(n_units=H, act=tf.nn.relu, name='hidden')(ni)
     nn = tl.layers.Dense(n_units=3, name='output')(nn)
     M = tl.models.Model(inputs=ni, outputs=nn, name="mlp")
     return M
+
+
 model = get_model([None, D])
 train_weights = model.trainable_weights
-# probs = model(t_states, is_train=True).outputs
-# sampling_prob = tf.nn.softmax(probs)
-
-# t_actions = tf.placeholder(tf.int32, shape=[None])
-# t_discount_rewards = tf.placeholder(tf.float32, shape=[None])
-# loss = tl.rein.cross_entropy_reward_loss(probs, t_actions, t_discount_rewards)
-optimizer = tf.train.RMSPropOptimizer(learning_rate, decay_rate)#.minimize(loss)
-
-# with tf.Session() as sess:
-# sess.run(tf.global_variables_initializer())
-# if resume: TODO
-# load_params = tl.files.load_npz(name=model_file_name+'.npz')
-# tl.files.assign_params(sess, load_params, network)
-# tl.files.load_and_assign_npz(sess, model_file_name + '.npz', network)
-# network.print_params()
-# network.print_layers()
-model.train() # set model to train mode (in case you add dropout into the model)
+
+optimizer = tf.optimizers.RMSprop(lr=learning_rate, decay=decay_rate)
+
+model.train() # set model to train mode (in case you add dropout into the model)
 
 start_time = time.time()
 game_number = 0
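
For context, not part of the diff: with TensorLayer 2 on TF2, the model built by get_model is called like a function and returns a tensor directly, so the network.outputs attribute and the tf.Session/placeholder plumbing removed above are no longer needed. A standalone sketch of that calling convention, reusing the same layer API as the diff (the D and H values are assumptions mirroring the tutorial's 80x80 input and 200 hidden units):

    import numpy as np
    import tensorflow as tf
    import tensorlayer as tl

    D, H = 80 * 80, 200  # assumed to match the tutorial's hyper-parameters

    ni = tl.layers.Input([None, D])
    nn = tl.layers.Dense(n_units=H, act=tf.nn.relu, name='hidden')(ni)
    nn = tl.layers.Dense(n_units=3, name='output')(nn)
    model = tl.models.Model(inputs=ni, outputs=nn, name="mlp")
    model.eval()  # TL2 models carry an explicit train()/eval() state; set it before calling

    logits = model(np.zeros((1, D), dtype=np.float32))  # eager tensor; no session, no .outputs
    probs = tf.nn.softmax(logits)
    print(probs.numpy())  # shape (1, 3): probabilities over actions [1, 2, 3]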
@@ -114,14 +97,12 @@ def get_model(inputs_shape):
     x = x.reshape(1, D)
     prev_x = cur_x
 
-    # prob = sess.run(sampling_prob, feed_dict={t_states: x})
-    _prob = model(x).outputs
+    _prob = model(x)
     prob = tf.nn.softmax(_prob)
 
     # action. 1: STOP 2: UP 3: DOWN
-    # action = np.random.choice([1,2,3], p=prob.flatten())
-    # action = tl.rein.choice_action_by_probs(prob.flatten(), [1, 2, 3])
-    # action = np.random.choice([1,2,3], p=prob.numpy())
+    # action = np.random.choice([1,2,3], p=prob.flatten())
+    # action = tl.rein.choice_action_by_probs(prob.flatten(), [1, 2, 3])
     action = tl.rein.choice_action_by_probs(prob[0].numpy(), [1, 2, 3])
 
     observation, reward, done, _ = env.step(action)
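
Aside, not part of the diff: as far as I recall, tl.rein.choice_action_by_probs samples one entry of the given action list in proportion to the supplied probabilities, i.e. roughly the commented-out np.random.choice variant kept above. A plain-NumPy illustration with a made-up probability vector:

    import numpy as np

    probs = np.array([0.2, 0.5, 0.3])              # hypothetical softmax output for one frame
    action = np.random.choice([1, 2, 3], p=probs)  # 1: STOP, 2: UP, 3: DOWN
    print(action)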
@@ -145,12 +126,8 @@ def get_model(inputs_shape):
 
         xs, ys, rs = [], [], []
 
-        # sess.run(train_op, feed_dict={t_states: epx, t_actions: epy, t_discount_rewards: disR})
-        # t_actions = tf.placeholder(tf.int32, shape=[None])
-        # t_discount_rewards = tf.placeholder(tf.float32, shape=[None])
-        # loss = tl.rein.cross_entropy_reward_loss(probs, t_actions, t_discount_rewards)
         with tf.GradientTape() as tape:
-            _prob = model(epx).outputs
+            _prob = model(epx)
             _loss = tl.rein.cross_entropy_reward_loss(_prob, epy, disR)
         grad = tape.gradient(_loss, train_weights)
         optimizer.apply_gradients(zip(grad, train_weights))
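
Note, not part of the diff: the GradientTape block above is the TF2 replacement for the removed session/placeholder training op, and, if I read tl.rein.cross_entropy_reward_loss correctly, the loss being taped is essentially reward-weighted softmax cross entropy (the REINFORCE surrogate). A hedged TF2 approximation, with logits/actions/rewards standing in for _prob, epy and disR:

    import tensorflow as tf

    def reward_weighted_ce(logits, actions, rewards):
        # per-step policy cross entropy, scaled by that step's discounted return;
        # gradients of this scalar w.r.t. the policy weights give the policy-gradient update
        ce = tf.nn.sparse_softmax_cross_entropy_with_logits(labels=actions, logits=logits)
        return tf.reduce_sum(ce * rewards)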
