
Commit e1844a1

Author: Omegastick
Commit message: Update to PyTorch 1.3
Parent: 1832c60

File tree

6 files changed (+11 -11 lines)


.gitmodules (+2 -2)

@@ -1,7 +1,7 @@
-[submodule "lib/spdlog"]
+[submodule "example/lib/spdlog"]
 	path = example/lib/spdlog
 	url = git@github.com:gabime/spdlog.git
-[submodule "lib/msgpack-c"]
+[submodule "example/lib/msgpack-c"]
 	path = example/lib/msgpack-c
 	url = git@github.com:msgpack/msgpack-c.git
 [submodule "example/lib/libzmq"]

example/gym_client.cpp (+3 -3)

@@ -182,7 +182,7 @@ int main(int argc, char *argv[])
                                          storage.get_masks()[step]);
         }
         auto actions_tensor = act_result[1].cpu().to(torch::kFloat);
-        float *actions_array = actions_tensor.data<float>();
+        float *actions_array = actions_tensor.data_ptr<float>();
         std::vector<std::vector<float>> actions(num_envs);
         for (int i = 0; i < num_envs; ++i)
         {
@@ -218,7 +218,7 @@ int main(int argc, char *argv[])
             returns_rms->update(returns);
             reward_tensor = torch::clamp(reward_tensor / torch::sqrt(returns_rms->get_variance() + 1e-8),
                                          -reward_clip_value, reward_clip_value);
-            rewards = std::vector<float>(reward_tensor.data<float>(), reward_tensor.data<float>() + reward_tensor.numel());
+            rewards = std::vector<float>(reward_tensor.data_ptr<float>(), reward_tensor.data_ptr<float>() + reward_tensor.numel());
             real_rewards = flatten_vector(step_result->real_reward);
             dones_vec = step_result->done;
         }
@@ -233,7 +233,7 @@ int main(int argc, char *argv[])
             returns_rms->update(returns);
             reward_tensor = torch::clamp(reward_tensor / torch::sqrt(returns_rms->get_variance() + 1e-8),
                                          -reward_clip_value, reward_clip_value);
-            rewards = std::vector<float>(reward_tensor.data<float>(), reward_tensor.data<float>() + reward_tensor.numel());
+            rewards = std::vector<float>(reward_tensor.data_ptr<float>(), reward_tensor.data_ptr<float>() + reward_tensor.numel());
             real_rewards = flatten_vector(step_result->real_reward);
             dones_vec = step_result->done;
         }
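
The three gym_client.cpp changes above all follow the same pattern applied throughout this commit: PyTorch 1.3 deprecates Tensor::data<T>() in favour of Tensor::data_ptr<T>() for obtaining a raw pointer to a tensor's storage. A minimal standalone sketch of that pattern (the tensor and vector names are illustrative, not taken from the repository):

```cpp
// Copy a CPU float tensor into a std::vector using the PyTorch 1.3 accessor.
#include <torch/torch.h>

#include <iostream>
#include <vector>

int main()
{
    // Stand-in for the reward/action tensors used in gym_client.cpp.
    torch::Tensor reward_tensor = torch::rand({8});

    // data_ptr<float>() returns a raw pointer to the first element; the
    // tensor should be contiguous, on the CPU, and hold 32-bit floats.
    std::vector<float> rewards(reward_tensor.data_ptr<float>(),
                               reward_tensor.data_ptr<float>() + reward_tensor.numel());

    std::cout << "Copied " << rewards.size() << " floats" << std::endl;
}
```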

src/algorithms/a2c.cpp (+1 -1)

@@ -36,7 +36,7 @@ A2C::A2C(Policy &policy,
 std::vector<UpdateDatum> A2C::update(RolloutStorage &rollouts, float decay_level)
 {
     // Decay learning rate
-    optimizer->options.learning_rate_ = original_learning_rate * decay_level;
+    optimizer->options.learning_rate(original_learning_rate * decay_level);
 
     // Prep work
     auto full_obs_shape = rollouts.get_observations().sizes();

src/algorithms/ppo.cpp (+1 -1)

@@ -45,7 +45,7 @@ std::vector<UpdateDatum> PPO::update(RolloutStorage &rollouts, float decay_level
 {
     // Decay lr and clip parameter
     float clip_param = original_clip_param * decay_level;
-    optimizer->options.learning_rate_ = original_learning_rate * decay_level;
+    optimizer->options.learning_rate(original_learning_rate * decay_level);
 
     // Calculate advantages
     auto returns = rollouts.get_returns();
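
Both src/algorithms/a2c.cpp and src/algorithms/ppo.cpp swap an assignment to the public learning_rate_ field for a call to the learning_rate() setter, which is how the PyTorch 1.3 C++ API exposes optimizer options. A rough standalone sketch of the decay pattern, assuming a torch::optim::Adam optimizer like the one these algorithms hold (the model, rates, and loop length below are made up for illustration):

```cpp
// Linearly decay an optimizer's learning rate, PyTorch 1.3 C++ API style.
#include <torch/torch.h>

int main()
{
    // Hypothetical model and optimizer; the repository wires these up elsewhere.
    torch::nn::Linear model(4, 2);
    torch::optim::Adam optimizer(model->parameters(),
                                 torch::optim::AdamOptions(1e-3));

    const double original_learning_rate = 1e-3;
    const int num_updates = 10;
    for (int update = 0; update < num_updates; ++update)
    {
        // decay_level mirrors the argument passed to A2C::update / PPO::update.
        double decay_level = 1.0 - static_cast<double>(update) / num_updates;

        // 1.3-style setter; pre-1.3 code assigned to options.learning_rate_ instead.
        optimizer.options.learning_rate(original_learning_rate * decay_level);

        // ... compute losses and call optimizer.step() here ...
    }
}
```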

src/model/nn_base.cpp (+2 -2)

@@ -81,8 +81,8 @@ std::vector<torch::Tensor> NNBase::forward_gru(torch::Tensor x,
     // has_zeros = [0] + has_zeros + [timesteps]
     has_zeros = has_zeros.contiguous().to(torch::kInt);
     std::vector<int> has_zeros_vec(
-        has_zeros.data<int>(),
-        has_zeros.data<int>() + has_zeros.numel());
+        has_zeros.data_ptr<int>(),
+        has_zeros.data_ptr<int>() + has_zeros.numel());
     has_zeros_vec.insert(has_zeros_vec.begin(), {0});
     has_zeros_vec.push_back(timesteps);

src/observation_normalizer.cpp (+2 -2)

@@ -76,13 +76,13 @@ torch::Tensor ObservationNormalizerImpl::process_observation(torch::Tensor obser
 std::vector<float> ObservationNormalizerImpl::get_mean() const
 {
     auto mean = rms->get_mean();
-    return std::vector<float>(mean.data<float>(), mean.data<float>() + mean.numel());
+    return std::vector<float>(mean.data_ptr<float>(), mean.data_ptr<float>() + mean.numel());
 }
 
 std::vector<float> ObservationNormalizerImpl::get_variance() const
 {
     auto variance = rms->get_variance();
-    return std::vector<float>(variance.data<float>(), variance.data<float>() + variance.numel());
+    return std::vector<float>(variance.data_ptr<float>(), variance.data_ptr<float>() + variance.numel());
 }
 
 void ObservationNormalizerImpl::update(torch::Tensor observations)
