@@ -29,6 +29,9 @@ def config():
2929 # Only evaluate test loss on 4 batches when you're in the middle of a train epoch.
3030 # Set to None to evaluate on the whole test set.
3131 test_subset_within_epoch = 4
32+ # flag to train classification for whether reward is 0 or not, rather than
33+ # regression.
34+ classify = False
3235 # use adversarial training. below are configs to be set if adversarial is set to
3336 # True. for details, see documentation of SupervisedTrainer in
3437 # trainers/supervised_trainer.py
@@ -103,19 +106,33 @@ def make_trainer(
103106 limit_samples : int ,
104107 test_subset_within_epoch : Optional [int ],
105108 opt_kwargs : Optional [Mapping [str , Any ]],
109+ classify : bool ,
106110 adversarial : bool ,
107111 start_epoch : Optional [int ],
108112 nonsense_reward : Optional [float ],
109113 num_acts : Optional [int ],
110114 vis_frac_per_epoch : Optional [float ],
111115 gradient_clip_percentile : Optional [float ],
116+ device : str ,
112117 debugging : Mapping ,
113118) -> SupervisedTrainer :
114119 if not adversarial :
115- # MSE loss with mean reduction (the default)
116- # Mean reduction means every batch affects model updates the same, regardless of
117- # batch_size.
118- loss_fn = th .nn .MSELoss ()
120+ if not classify :
121+ # MSE loss with mean reduction (the default)
122+ # Mean reduction means every batch affects model updates the same,
123+ # regardless of batch_size.
124+ loss_fn = th .nn .MSELoss ()
125+ else :
126+ # loss function takes outputs (interpreted as log-probability reward is
127+ # zero), reward, and computes the cross-entropy loss.
128+ def loss_fn (input , target ):
129+ if len (input .shape ) == 1 :
130+ input = input [:, None ]
131+ zeros = th .zeros (input .shape ).to (device )
132+ log_probs = th .cat ((input , zeros ), dim = 1 )
133+ target_classes = (target != 0 ).long ()
134+ return th .nn .CrossEntropyLoss ()(log_probs , target_classes )
135+
119136 else :
120137 # Huber loss with mean reduction
121138 # When the prediction is within a distance of sqrt(3) of the regression target,
0 commit comments