
Commit d72d749: merge commit "update"
Parents: d1ab5bc + 0fba8ab


80 files changed (+160, -143 lines)

.gitignore

Lines changed: 2 additions & 1 deletion
@@ -1,2 +1,3 @@
 **/data
-*.pyc
+*.pyc
+metaworld
Binary file not shown.
Binary file not shown.

cs285pkg/cs285/agents/pg_agent.py

Lines changed: 1 addition & 12 deletions
@@ -4,7 +4,7 @@
 from cs285.policies.MLP_policy import MLPPolicyPG
 from cs285.infrastructure.replay_buffer import ReplayBuffer
 from cs285.infrastructure.utils import normalize
-from cs285.policies.PSPPolicy import PSPPolicy
+
 
 class PGAgent(BaseAgent):
     def __init__(self, env, agent_params):
@@ -29,17 +29,6 @@ def __init__(self, env, agent_params):
             nn_baseline=self.agent_params["nn_baseline"],
         )
 
-        # self.actor = PSPPolicy(
-        #     self.agent_params["ac_dim"],
-        #     self.agent_params["ob_dim"],
-        #     self.agent_params["n_layers"],
-        #     self.agent_params["size"],
-        #     period=agent_params['period'],
-        #     discrete=self.agent_params["discrete"],
-        #     learning_rate=self.agent_params["learning_rate"],
-        #     nn_baseline=self.agent_params["nn_baseline"],
-        # )
-
         # replay buffer
         self.replay_buffer = ReplayBuffer(1000000)
 

Binary file not shown.

cs285pkg/cs285/infrastructure/psp_layer.py

Lines changed: 52 additions & 18 deletions
@@ -9,22 +9,56 @@
 class BinaryHashLinear(nn.Module):
     def __init__(self, n_in, n_out, period, key_pick='hash', learn_key=True):
         super(BinaryHashLinear, self).__init__()
-        # self.key_pick = key_pick
-        # w = nn.init.xavier_normal_(torch.empty(n_in, n_out))
-        # rand_01 = np.random.binomial(p=.5, n=1, size=(n_in, period)).astype(np.float32)
-        # o = torch.from_numpy(rand_01*2 - 1)
-        # self.n_in = n_in
-        # self.n_out = n_out
-        #
-        # self.w = nn.Parameter(w)
-        # self.bias = nn.Parameter(torch.zeros(n_out))
-        # self.o = nn.Parameter(o)
-        self.linear = nn.Linear(n_in, n_out)
-        # if not learn_key:
-        #     self.o.requires_grad = False
-
-    def forward(self, x, time=0):
-        # o = self.o[:, int(time)]
-        # m = x*o
-        r = self.linear(x)
+        self.key_pick = key_pick
+        w = nn.init.xavier_normal_(torch.empty(n_in, n_out))
+        rand_01 = np.random.binomial(p=.5, n=1, size=(n_in, period)).astype(np.float32)
+        o = torch.from_numpy(rand_01*2 - 1)
+
+        self.w = nn.Parameter(w)
+        self.bias = nn.Parameter(torch.zeros(n_out))
+        self.o = nn.Parameter(o)
+        if not learn_key:
+            self.o.requires_grad = False
+
+    def forward(self, x, time):
+        o = self.o[:, int(time)]
+        m = x*o
+        r = torch.mm(m, self.w)
         return r
+
+
+class HashLinear(nn.Module):
+    '''Complex layer with complex diagonal contexts'''
+    def __init__(self, n_in, n_out, period=2, key_pick='hash', learn_key=True):
+        super(HashLinear, self).__init__()
+        self.key_pick = key_pick
+        w_r = nn.init.xavier_normal_(torch.empty(n_in, n_out))
+        w_phi = torch.Tensor(n_in, n_out).uniform_(-np.pi, np.pi)
+        o_r = torch.ones(period, n_in)
+        o_phi = torch.Tensor(period, n_in).uniform_(-np.pi, np.pi)
+
+        self.w = nn.Parameter(torch.stack(from_polar(w_r, w_phi)))
+        self.bias = nn.Parameter(torch.zeros(n_out))
+        self.o = nn.Parameter(torch.stack(from_polar(o_r, o_phi)))
+        if not learn_key:
+            self.o.requires_grad = False
+
+    def forward(self, x_a, x_b, time):
+        net_time = int(time) % self.o.shape[1]
+        o = self.o[:, net_time]
+        o_a = o[0].unsqueeze(0)
+        o_b = o[1].unsqueeze(0)
+        m_a = x_a*o_a - x_b*o_b
+        m_b = x_b*o_a + x_a*o_b
+
+        w_a = self.w[0]
+        w_b = self.w[1]
+        r_a = torch.mm(m_a, w_a) - torch.mm(m_b, w_b)
+        r_b = torch.mm(m_b, w_a) + torch.mm(m_a, w_b)
+        return r_a + self.bias, r_b
+
+
+def from_polar(r, phi):
+    a = r*torch.cos(phi)
+    b = r*torch.sin(phi)
+    return a, b
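
Note (not part of the commit): the new HashLinear stores complex weights and context keys as stacked (real, imaginary) tensors and expands each complex product into four real matrix multiplications. A quick standalone check of that arithmetic against PyTorch's native complex tensors (available in recent PyTorch versions); all names below are local to the snippet and the bias is omitted:

import torch

torch.manual_seed(0)
n_in, n_out, batch = 4, 3, 5
x_a, x_b = torch.randn(batch, n_in), torch.randn(batch, n_in)
o_a, o_b = torch.randn(1, n_in), torch.randn(1, n_in)
w_a, w_b = torch.randn(n_in, n_out), torch.randn(n_in, n_out)

# real-arithmetic path, exactly as in HashLinear.forward
m_a = x_a * o_a - x_b * o_b
m_b = x_b * o_a + x_a * o_b
r_a = torch.mm(m_a, w_a) - torch.mm(m_b, w_b)
r_b = torch.mm(m_b, w_a) + torch.mm(m_a, w_b)

# same computation with native complex tensors: (x * o) @ w
r = (torch.complex(x_a, x_b) * torch.complex(o_a, o_b)) @ torch.complex(w_a, w_b)
print(torch.allclose(r.real, r_a, atol=1e-5), torch.allclose(r.imag, r_b, atol=1e-5))  # True True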
Lines changed: 27 additions & 0 deletions
@@ -0,0 +1,27 @@
+import numpy as np
+import scipy
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+from torch.nn.modules.utils import _pair
+
+
+class BinaryHashLinear(nn.Module):
+    def __init__(self, n_in, n_out, period, key_pick='hash', learn_key=True):
+        super(BinaryHashLinear, self).__init__()
+        self.key_pick = key_pick
+        w = nn.init.xavier_normal_(torch.empty(n_in, n_out))
+        rand_01 = np.random.binomial(p=.5, n=1, size=(n_in, period)).astype(np.float32)
+        o = torch.from_numpy(rand_01*2 - 1)
+
+        self.w = nn.Parameter(w)
+        self.bias = nn.Parameter(torch.zeros(n_out))
+        self.o = nn.Parameter(o)
+        if not learn_key:
+            self.o.requires_grad = False
+
+    def forward(self, x, time):
+        o = self.o[:, int(time)]
+        m = x*o
+        r = torch.mm(m, self.w)
+        return r
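
Note (not part of the commit): a self-contained sketch of what BinaryHashLinear.forward computes. Each of the `period` tasks gets a fixed ±1 context key that modulates the input elementwise before the shared weight matrix, so one set of weights is superposed across tasks. The dimensions here are made up for illustration:

import numpy as np
import torch
import torch.nn as nn

torch.manual_seed(0)
n_in, n_out, period = 4, 3, 2               # period = number of tasks superposed

w = nn.init.xavier_normal_(torch.empty(n_in, n_out))                  # one shared weight matrix
rand_01 = np.random.binomial(p=.5, n=1, size=(n_in, period)).astype(np.float32)
o = torch.from_numpy(rand_01 * 2 - 1)       # one +/-1 context key per task

x = torch.randn(5, n_in)                    # batch of 5 observations
out_task0 = torch.mm(x * o[:, 0], w)        # equivalent to forward(x, time=0)
out_task1 = torch.mm(x * o[:, 1], w)        # equivalent to forward(x, time=1)
print(out_task0.shape, out_task1.shape)     # torch.Size([5, 3]) twice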

cs285pkg/cs285/infrastructure/psp_layer2.py

Lines changed: 0 additions & 32 deletions
This file was deleted.
Lines changed: 4 additions & 25 deletions
@@ -1,5 +1,5 @@
-#from cs285.infrastructure.psp_layer import *
-from cs285.infrastructure.psp_layer2 import *
+from cs285.infrastructure.psp_layer import *
+
 
 class HashNet(nn.Module):
     def __init__(self, input_dim, output_dim, layer_size,
@@ -29,12 +29,10 @@ def forward(self, x, time):
                 r = self.activation(r)
             r = layer(r, time)
             preactivations.append(r)
-            r = nn.Identity()(r)
 
         return r, None, preactivations
 
 
-
 class ComplexHashNet(HashNet):
     def forward(self, x, time):
         preactivations = []
@@ -43,26 +41,7 @@ def forward(self, x, time):
             if layer_i > 0:
                 r_a = self.activation(r_a)
                 r_b = self.activation(r_b)
-            r_a = layer(r_a, time)
-            r_b = layer(r_b, time)
+            r_a, r_b = layer(r_a, r_b, time)
             preactivations.append(r_a)
-            preactivations.append(r_b)
-        return r_a, r_b, preactivations
-
 
-class MLP(nn.Module):
-
-    def __init__(self, input_size, output_size, n_layers, size, activation, output_activation):
-        super(MLP, self).__init__()
-        self.n_layers = n_layers
-        self.linears = nn.ModuleList([nn.Linear(input_size, size)])
-        self.linears.extend([nn.Linear(size, size) for i in range(0, self.n_layers - 1)])
-        self.linears.append(nn.Linear(size, output_size))
-        self.activation = activation
-        self.output_activation = output_activation
-
-    def forward(self, x):
-        for i in range(self.n_layers):
-            x = self.activation(self.linears[i](x))
-        mean = self.output_activation(self.linears[self.n_layers](x))
-        return mean
+        return r_a, r_b, preactivations
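
Note (not part of the commit): HashNet's constructor is not shown in this diff, so the layer construction below is an assumption. It only sketches how a ComplexHashNet-style stack threads the (real, imaginary) pair through HashLinear layers, mirroring the updated call `r_a, r_b = layer(r_a, r_b, time)`, and it assumes the HashLinear class added in psp_layer.py above is in scope:

import torch
import torch.nn as nn

class TinyComplexHashNet(nn.Module):
    # Illustrative only: the real HashNet/ComplexHashNet constructor is not in this diff.
    def __init__(self, input_dim, output_dim, layer_size, n_layers, period, layer_cls):
        super().__init__()
        dims = [input_dim] + [layer_size] * n_layers + [output_dim]
        self.layers = nn.ModuleList(
            [layer_cls(dims[i], dims[i + 1], period) for i in range(len(dims) - 1)]
        )
        self.activation = torch.tanh

    def forward(self, x, time):
        r_a, r_b = x, torch.zeros_like(x)      # treat the observation as a purely real input
        for layer_i, layer in enumerate(self.layers):
            if layer_i > 0:
                r_a = self.activation(r_a)
                r_b = self.activation(r_b)
            r_a, r_b = layer(r_a, r_b, time)   # same call shape as the updated line above
        return r_a, r_b

# usage sketch: net = TinyComplexHashNet(11, 3, 64, 2, 2, HashLinear)
#               mean_a, _ = net(torch.randn(32, 11), time=0)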

cs285pkg/cs285/infrastructure/pytorch_util.py

Lines changed: 0 additions & 1 deletion
@@ -2,7 +2,6 @@
 
 import torch
 from torch import nn
-from cs285.infrastructure.psp_layer import BinaryHashLinear
 
 Activation = Union[str, nn.Module]
 

cs285pkg/cs285/infrastructure/rl_trainer.py

Lines changed: 11 additions & 25 deletions
@@ -129,6 +129,7 @@ def run_training_loop(
 
         for itr in range(n_iter):
             print("\n\n********** Iteration %i ************" % itr)
+            self.agent.update_time(0)
 
             # decide if videos should be rendered/logged at this iteration
             if (
@@ -239,11 +240,12 @@ def run_second_task_loop(
             # log/save
             if self.logvideo or self.logmetrics:
                 # perform logging
+                self.agent.update_time(0)
                 print("\nBeginning logging procedure...")
                 self.perform_logging(
                     itr, paths, eval_policy, train_video_paths, train_logs
                 )
-
+                self.agent.update_time(1)
             if self.params["save_params"]:
                 self.agent.save(
                     "{}/agent_itr_{}.pt".format(self.params["logdir"], itr)
@@ -328,30 +330,10 @@ def perform_logging(self, itr, paths, eval_policy, train_video_paths, all_logs):
         eval_paths, eval_envsteps_this_batch = utils.sample_trajectories(
            self.env, eval_policy, self.params["eval_batch_size"], self.params["ep_len"]
         )
-
-        # save eval rollouts as videos in tensorboard event file
-        if self.logvideo and train_video_paths != None:
-            print("\nCollecting video rollouts eval")
-            eval_video_paths = utils.sample_n_trajectories(
-                self.env, eval_policy, MAX_NVIDEO, MAX_VIDEO_LEN, True
-            )
-
-            # save train/eval videos
-            print("\nSaving train rollouts as videos...")
-            self.logger.log_paths_as_videos(
-                train_video_paths,
-                itr,
-                fps=self.fps,
-                max_videos_to_save=MAX_NVIDEO,
-                video_title="train_rollouts",
-            )
-            self.logger.log_paths_as_videos(
-                eval_video_paths,
-                itr,
-                fps=self.fps,
-                max_videos_to_save=MAX_NVIDEO,
-                video_title="eval_rollouts",
-            )
+        self.agent.update_time(1)
+        eval_paths2, eval_envsteps_this_batch2 = utils.sample_trajectories(
+            self.env2, eval_policy, self.params["eval_batch_size"], self.params["ep_len"]
+        )
 
         #######################
 
@@ -360,18 +342,22 @@ def perform_logging(self, itr, paths, eval_policy, train_video_paths, all_logs):
         # returns, for logging
         train_returns = [path["reward"].sum() for path in paths]
         eval_returns = [eval_path["reward"].sum() for eval_path in eval_paths]
+        eval_returns2 = [eval_path["reward"].sum() for eval_path in eval_paths2]
 
         # episode lengths, for logging
         train_ep_lens = [len(path["reward"]) for path in paths]
         eval_ep_lens = [len(eval_path["reward"]) for eval_path in eval_paths]
+        eval_ep_lens2 = [len(eval_path["reward"]) for eval_path in eval_paths2]
 
         # decide what to log
         logs = OrderedDict()
         logs["Eval_AverageReturn"] = np.mean(eval_returns)
+        logs["Eval_AverageReturn2"] = np.mean(eval_returns2)
         logs["Eval_StdReturn"] = np.std(eval_returns)
         logs["Eval_MaxReturn"] = np.max(eval_returns)
         logs["Eval_MinReturn"] = np.min(eval_returns)
         logs["Eval_AverageEpLen"] = np.mean(eval_ep_lens)
+        logs["Eval_AverageEpLen2"] = np.mean(eval_ep_lens2)
 
         logs["Train_AverageReturn"] = np.mean(train_returns)
         logs["Train_StdReturn"] = np.std(train_returns)

cs285pkg/cs285/policies/MLP_policy.py

Lines changed: 0 additions & 4 deletions
@@ -11,9 +11,6 @@
 from cs285.infrastructure import pytorch_util as ptu
 from cs285.policies.base_policy import BasePolicy
 from cs285.infrastructure.utils import normalize
-from cs285.infrastructure.psp_net import RealHashNet
-from cs285.infrastructure.psp_layer import BinaryHashLinear
-
 
 
 class MLPPolicy(BasePolicy, nn.Module, metaclass=abc.ABCMeta):
@@ -60,7 +57,6 @@ def __init__(
             n_layers=self.n_layers,
             size=self.size,
         )
-        # self.mean_net = RealHashNet(self.ob_dim, self.ac_dim, self.size, torch.tanh, self.n_layers, 2, 'hash', BinaryHashLinear)
         self.logstd = nn.Parameter(
             torch.zeros(self.ac_dim, dtype=torch.float32, device=ptu.device)
         )

cs285pkg/cs285/policies/PSPPolicy.py

Lines changed: 7 additions & 16 deletions
@@ -11,8 +11,7 @@
 from cs285.infrastructure import pytorch_util as ptu
 from cs285.policies.base_policy import BasePolicy
 from cs285.infrastructure.utils import normalize
-from cs285.infrastructure.psp_net import RealHashNet, ComplexHashNet, MLP
-from cs285.infrastructure.psp_layer import *
+from cs285.infrastructure.psp_net import RealHashNet, BinaryHashLinear, ComplexHashNet, HashLinear
 
 
 class PSPPolicy(BasePolicy, nn.Module, metaclass=abc.ABCMeta):
@@ -40,13 +39,9 @@ def __init__(
         self.learning_rate = learning_rate
         self.training = training
         self.period = period
-        self.mean_net = MLP(self.ob_dim, self.ac_dim, n_layers, size, nn.Tanh(), nn.Identity())
-        self.mean_net2 = ptu.build_mlp(
-            input_size=self.ob_dim,
-            output_size=self.ac_dim,
-            n_layers=self.n_layers,
-            size=self.size,
-        )
+        # self.mean_net = RealHashNet(self.ob_dim, self.ac_dim, self.size, torch.tanh, self.n_layers, self.period, 'hash', BinaryHashLinear)
+        self.mean_net = ComplexHashNet(self.ob_dim, self.ac_dim, self.size, torch.tanh, self.n_layers, self.period, 'hash',
+                                       HashLinear)
         self.logstd = nn.Parameter(
             torch.zeros(self.ac_dim, dtype=torch.float32, device=ptu.device)
         )
@@ -56,7 +51,6 @@ def __init__(
             itertools.chain([self.logstd], self.mean_net.parameters()),
             self.learning_rate,
         )
-        self.a = self.mean_net.parameters()
 
     def update_time(self, time):
         self.time = time
@@ -86,10 +80,8 @@ def update(self, observations, acs_na, adv_n=None, acs_labels_na=None,
               qvals=None):
         observations = ptu.from_numpy(observations)
         actions = ptu.from_numpy(acs_na)
-        adv_n = ptu.from_numpy(adv_n)
-
         action_distribution = self(observations)
-        loss = - action_distribution.log_prob(actions) * adv_n
+        loss = -action_distribution.log_prob(actions) * ptu.from_numpy(adv_n)
         loss = loss.mean()
 
         self.optimizer.zero_grad()
@@ -106,12 +98,11 @@ def update(self, observations, acs_na, adv_n=None, acs_labels_na=None,
         # return more flexible objects, such as a
         # `torch.distributions.Distribution` object. It's up to you!
     def forward(self, observation: torch.FloatTensor):
-        batch_mean = self.mean_net(observation)
-        # batch_mean2 = self.mean_net2(observation)
+        batch_mean = self.mean_net(observation, self.time)[0]
         scale_tril = torch.diag(torch.exp(self.logstd))
         batch_dim = batch_mean.shape[0]
         batch_scale_tril = scale_tril.repeat(batch_dim, 1, 1)
         action_distribution = distributions.MultivariateNormal(
             batch_mean, scale_tril=batch_scale_tril,
         )
-        return action_distribution
+        return action_distribution
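
Note (not part of the commit): a hypothetical usage sketch of the updated PSPPolicy, mirroring the commented-out construction that was removed from pg_agent.py above. The dimensions are made up, and it assumes ptu's device has already been initialized:

import torch
from cs285.policies.PSPPolicy import PSPPolicy

# ac_dim=3, ob_dim=11, n_layers=2, size=64 are illustrative values only
policy = PSPPolicy(3, 11, 2, 64,
                   period=2,              # two tasks superposed in one ComplexHashNet
                   discrete=False,
                   learning_rate=5e-3,
                   nn_baseline=False)

policy.update_time(0)                     # task 0: forward() selects context key 0
dist0 = policy(torch.randn(32, 11))
policy.update_time(1)                     # task 1: same weights, context key 1
dist1 = policy(torch.randn(32, 11))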
Binary file not shown.
Binary file not shown (1.88 KB).
