@@ -13,6 +13,13 @@
 import wandb
 
 
+def _normalize_obs(obs: th.Tensor) -> th.Tensor:
+    """Normalize by dividing by 255 if obs is uint8; otherwise no change."""
+    if obs.dtype == th.uint8:  # Observations saved as int => Normalize to [0, 1]
+        obs = obs.float() / 255.0
+    return obs
+
+
 class SupervisedTrainer(base.BaseImitationAlgorithm):
     """Learns from demonstrations (transitions / trajectories) using supervised
     learning. Has some overlap with base.DemonstrationAlgorithm, but does not train a
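
A quick smoke test of the new helper (a minimal sketch, assuming `_normalize_obs` from the hunk above is in scope; shapes and values are illustrative, not from the PR):

```python
import torch as th

# uint8 observations (e.g. image frames) are rescaled to [0, 1] floats.
uint8_obs = th.randint(0, 256, (4, 84, 84), dtype=th.uint8)
normalized = _normalize_obs(uint8_obs)  # helper from the hunk above
assert normalized.dtype == th.float32
assert normalized.min() >= 0.0 and normalized.max() <= 1.0

# Non-uint8 observations pass through unchanged (same tensor object).
float_obs = th.randn(4, 8)
assert _normalize_obs(float_obs) is float_obs
```
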
@@ -227,10 +234,8 @@ def _data_dict_to_model_args_and_target(
         done = data_dict["dones"].to(device)
         target = data_dict["rews"].to(device)
 
-        if obs.dtype == th.uint8:  # Observations saved as int => Normalize to [0, 1]
-            obs = obs.float() / 255.0
-        if next_obs.dtype == th.uint8:
-            next_obs = next_obs.float() / 255.0
+        obs = _normalize_obs(obs)
+        next_obs = _normalize_obs(next_obs)
 
         if isinstance(self.reward_net.action_space, spaces.Discrete):
             num_actions = self.reward_net.action_space.n
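
Both observation tensors now share one normalization path, and the `spaces.Discrete` branch looks up `num_actions`, presumably to one-hot encode the actions. A sketch of that batch preparation (the one-hot step and all names here are assumptions for illustration, not code from the PR; `_normalize_obs` is the helper above):

```python
import torch as th
import torch.nn.functional as F

device = th.device("cpu")
# Illustrative batch; in the trainer this comes from the DataLoader.
data_dict = {
    "obs": th.randint(0, 256, (32, 4, 84, 84), dtype=th.uint8),
    "acts": th.randint(0, 6, (32,)),  # 6 = assumed discrete action count
}

obs = _normalize_obs(data_dict["obs"].to(device))
# One-hot encoding of discrete actions is inferred from the num_actions
# lookup in the hunk above, not taken verbatim from the PR.
acts_one_hot = F.one_hot(data_dict["acts"].to(device), num_classes=6).float()
assert obs.dtype == th.float32 and acts_one_hot.shape == (32, 6)
```
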
@@ -269,6 +274,8 @@ def _record_dataset_stats(self, key: str, dataloader: data.DataLoader) -> None:
         dones_count = 0
         for batch_idx, data_dict in enumerate(dataloader):
             obs = data_dict["obs"]
+            obs = _normalize_obs(obs)
+
             rew = data_dict["rews"]
             act = data_dict["acts"]
             done = data_dict["dones"]
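
With the helper also applied in `_record_dataset_stats`, the recorded statistics are computed on the same [0, 1] scale the model trains on. A hypothetical aggregate of that flavor (the function name and choice of statistics are illustrative, not from the PR):

```python
import torch as th

def _dataset_obs_stats(dataloader):
    """Illustrative only: mean/min/max of normalized observations."""
    total, count = 0.0, 0
    obs_min, obs_max = float("inf"), float("-inf")
    for data_dict in dataloader:
        obs = _normalize_obs(data_dict["obs"])  # helper from the first hunk
        total += obs.sum().item()
        count += obs.numel()
        obs_min = min(obs_min, obs.min().item())
        obs_max = max(obs_max, obs.max().item())
    return {"obs_mean": total / count, "obs_min": obs_min, "obs_max": obs_max}
```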