Skip to content

Commit ed1eb92

Browse files
committed
Extract function and normalize images before calculating dataset statistics
1 parent b502366 commit ed1eb92

File tree

1 file changed

+11
-4
lines changed

1 file changed

+11
-4
lines changed

Diff for: src/reward_preprocessing/trainers/supervised_trainer.py

+11-4
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,13 @@
1313
import wandb
1414

1515

16+
def _normalize_obs(obs: th.Tensor) -> th.Tensor:
17+
"""Normalize by dividing by 255, if obs is uint8, otherwise no change."""
18+
if obs.dtype == th.uint8: # Observations saved as int => Normalize to [0, 1]
19+
obs = obs.float() / 255.0
20+
return obs
21+
22+
1623
class SupervisedTrainer(base.BaseImitationAlgorithm):
1724
"""Learns from demonstrations (transitions / trajectories) using supervised
1825
learning. Has some overlap with base.DemonstrationAlgorithm, but does not train a
@@ -227,10 +234,8 @@ def _data_dict_to_model_args_and_target(
227234
done = data_dict["dones"].to(device)
228235
target = data_dict["rews"].to(device)
229236

230-
if obs.dtype == th.uint8: # Observations saved as int => Normalize to [0, 1]
231-
obs = obs.float() / 255.0
232-
if next_obs.dtype == th.uint8:
233-
next_obs = next_obs.float() / 255.0
237+
obs = _normalize_obs(obs)
238+
next_obs = _normalize_obs(next_obs)
234239

235240
if isinstance(self.reward_net.action_space, spaces.Discrete):
236241
num_actions = self.reward_net.action_space.n
@@ -269,6 +274,8 @@ def _record_dataset_stats(self, key: str, dataloader: data.DataLoader) -> None:
269274
dones_count = 0
270275
for batch_idx, data_dict in enumerate(dataloader):
271276
obs = data_dict["obs"]
277+
obs = _normalize_obs(obs)
278+
272279
rew = data_dict["rews"]
273280
act = data_dict["acts"]
274281
done = data_dict["dones"]

0 commit comments

Comments
 (0)