Skip to content

Commit 37ce310

Browse files
author
Omegastick
committed
Improve const-correctness
1 parent 5ee613e commit 37ce310

File tree

4 files changed

+10
-10
lines changed

4 files changed

+10
-10
lines changed

include/cpprl/model/policy.h

+4-4
Original file line numberDiff line numberDiff line change
@@ -33,17 +33,17 @@ class PolicyImpl : public nn::Module
3333

3434
std::vector<torch::Tensor> act(torch::Tensor inputs,
3535
torch::Tensor rnn_hxs,
36-
torch::Tensor masks);
36+
torch::Tensor masks) const;
3737
std::vector<torch::Tensor> evaluate_actions(torch::Tensor inputs,
3838
torch::Tensor rnn_hxs,
3939
torch::Tensor masks,
40-
torch::Tensor actions);
40+
torch::Tensor actions) const;
4141
torch::Tensor get_probs(torch::Tensor inputs,
4242
torch::Tensor rnn_hxs,
43-
torch::Tensor masks);
43+
torch::Tensor masks) const;
4444
torch::Tensor get_values(torch::Tensor inputs,
4545
torch::Tensor rnn_hxs,
46-
torch::Tensor masks);
46+
torch::Tensor masks) const;
4747
void update_observation_normalizer(torch::Tensor observations);
4848

4949
inline bool is_recurrent() const { return base->is_recurrent(); }

include/cpprl/observation_normalizer.h

+1-1
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,7 @@ class ObservationNormalizerImpl : public torch::nn::Module
2323
float clip = 10.);
2424
explicit ObservationNormalizerImpl(const std::vector<ObservationNormalizer> &others);
2525

26-
torch::Tensor process_observation(torch::Tensor observation);
26+
torch::Tensor process_observation(torch::Tensor observation) const;
2727
std::vector<float> get_mean() const;
2828
std::vector<float> get_variance() const;
2929
void update(torch::Tensor observations);

src/model/policy.cpp

+4-4
Original file line numberDiff line numberDiff line change
@@ -51,7 +51,7 @@ PolicyImpl::PolicyImpl(ActionSpace action_space,
5151

5252
std::vector<torch::Tensor> PolicyImpl::act(torch::Tensor inputs,
5353
torch::Tensor rnn_hxs,
54-
torch::Tensor masks)
54+
torch::Tensor masks) const
5555
{
5656
if (observation_normalizer)
5757
{
@@ -83,7 +83,7 @@ std::vector<torch::Tensor> PolicyImpl::act(torch::Tensor inputs,
8383
std::vector<torch::Tensor> PolicyImpl::evaluate_actions(torch::Tensor inputs,
8484
torch::Tensor rnn_hxs,
8585
torch::Tensor masks,
86-
torch::Tensor actions)
86+
torch::Tensor actions) const
8787
{
8888
if (observation_normalizer)
8989
{
@@ -116,7 +116,7 @@ std::vector<torch::Tensor> PolicyImpl::evaluate_actions(torch::Tensor inputs,
116116

117117
torch::Tensor PolicyImpl::get_probs(torch::Tensor inputs,
118118
torch::Tensor rnn_hxs,
119-
torch::Tensor masks)
119+
torch::Tensor masks) const
120120
{
121121
if (observation_normalizer)
122122
{
@@ -131,7 +131,7 @@ torch::Tensor PolicyImpl::get_probs(torch::Tensor inputs,
131131

132132
torch::Tensor PolicyImpl::get_values(torch::Tensor inputs,
133133
torch::Tensor rnn_hxs,
134-
torch::Tensor masks)
134+
torch::Tensor masks) const
135135
{
136136
if (observation_normalizer)
137137
{

src/observation_normalizer.cpp

+1-1
Original file line numberDiff line numberDiff line change
@@ -66,7 +66,7 @@ ObservationNormalizerImpl::ObservationNormalizerImpl(const std::vector<Observati
6666
rms->set_count(total_count);
6767
}
6868

69-
torch::Tensor ObservationNormalizerImpl::process_observation(torch::Tensor observation)
69+
torch::Tensor ObservationNormalizerImpl::process_observation(torch::Tensor observation) const
7070
{
7171
auto normalized_obs = (observation - rms->get_mean()) /
7272
torch::sqrt(rms->get_variance() + 1e-8);

0 commit comments

Comments
 (0)