Commit b01fefe

Merge pull request #47 from HumanCompatibleAI/reward-classification
Reward classification
2 parents d9746eb + 324bb29

2 files changed: +22 -4 lines changed

src/reward_preprocessing/scripts/common/supervised.py

Lines changed: 21 additions & 4 deletions
@@ -29,6 +29,9 @@ def config():
     # Only evaluate test loss on 4 batches when you're in the middle of a train epoch.
     # Set to None to evaluate on the whole test set.
     test_subset_within_epoch = 4
+    # flag to train classification for whether reward is 0 or not, rather than
+    # regression.
+    classify = False
     # use adversarial training. below are configs to be set if adversarial is set to
     # True. for details, see documentation of SupervisedTrainer in
     # trainers/supervised_trainer.py
@@ -103,19 +106,33 @@ def make_trainer(
     limit_samples: int,
     test_subset_within_epoch: Optional[int],
     opt_kwargs: Optional[Mapping[str, Any]],
+    classify: bool,
     adversarial: bool,
     start_epoch: Optional[int],
     nonsense_reward: Optional[float],
     num_acts: Optional[int],
     vis_frac_per_epoch: Optional[float],
     gradient_clip_percentile: Optional[float],
+    device: str,
     debugging: Mapping,
 ) -> SupervisedTrainer:
     if not adversarial:
-        # MSE loss with mean reduction (the default)
-        # Mean reduction means every batch affects model updates the same, regardless of
-        # batch_size.
-        loss_fn = th.nn.MSELoss()
+        if not classify:
+            # MSE loss with mean reduction (the default)
+            # Mean reduction means every batch affects model updates the same,
+            # regardless of batch_size.
+            loss_fn = th.nn.MSELoss()
+        else:
+            # loss function takes outputs (interpreted as log-probability reward is
+            # zero), reward, and computes the cross-entropy loss.
+            def loss_fn(input, target):
+                if len(input.shape) == 1:
+                    input = input[:, None]
+                zeros = th.zeros(input.shape).to(device)
+                log_probs = th.cat((input, zeros), dim=1)
+                target_classes = (target != 0).long()
+                return th.nn.CrossEntropyLoss()(log_probs, target_classes)
+
     else:
         # Huber loss with mean reduction
         # When the prediction is within a distance of sqrt(3) of the regression target,
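
For readers skimming the diff, here is a minimal standalone sketch of the new classification loss. The tensor names (raw_output, rewards) and example values are illustrative, not taken from the repo. The model's single output per transition is treated as the logit of the "reward is zero" class, a constant zero logit stands in for the "reward is nonzero" class, and the pair is scored with CrossEntropyLoss against labels derived from whether the ground-truth reward is nonzero.

import torch as th

device = "cpu"  # the diff threads this in from the training script

def classify_loss(input, target):
    # Single score per transition -> column of logits for class 0
    # ("reward is zero"); a constant 0 logit represents class 1 ("nonzero").
    if len(input.shape) == 1:
        input = input[:, None]
    zeros = th.zeros(input.shape).to(device)
    log_probs = th.cat((input, zeros), dim=1)   # shape (batch, 2)
    target_classes = (target != 0).long()       # 0 = zero reward, 1 = nonzero
    return th.nn.CrossEntropyLoss()(log_probs, target_classes)

raw_output = th.randn(4)                        # e.g. reward-net outputs
rewards = th.tensor([0.0, 1.0, 0.0, -2.5])      # ground-truth rewards
loss = classify_loss(raw_output, rewards)

Since softmax over the logit pair (s, 0) gives sigmoid(s) as the predicted probability that the reward is zero, this two-class cross-entropy is equivalent to a binary cross-entropy on the single score, so no second output head is needed.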

src/reward_preprocessing/scripts/train_regression.py

Lines changed: 1 addition & 0 deletions
@@ -64,6 +64,7 @@ def train_regression(supervised, checkpoint_epoch_interval: int):  # From ingred
         model=model,
         custom_logger=custom_logger,
         num_acts=num_acts,
+        device=device,
     )

     trainer.log_data_stats()
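
The one-line change above exists because the new loss builds zero logits on the fly; those must live on the same device as the model outputs, so the script now passes its device string through to make_trainer. A sketch of the failure mode this avoids (variable names are made up, not from the repo):

import torch as th

if th.cuda.is_available():
    preds = th.randn(4, 1, device="cuda")   # model outputs on the GPU
    zeros = th.zeros(preds.shape)           # defaults to CPU
    # th.cat((preds, zeros), dim=1) would raise a device-mismatch RuntimeError;
    # moving the zeros first, as the new loss_fn does, keeps the cat valid:
    log_probs = th.cat((preds, zeros.to("cuda")), dim=1)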
