Skip to content

Commit 1fd68cd

Browse files
committed
[Feature] Add deterministic_sample to masked categorical
ghstack-source-id: a6506ac99a83816ee3a1dbf44ce2795889310cdd Pull Request resolved: #2708
1 parent 7b8d5bb commit 1fd68cd

File tree

1 file changed

+4
-0
lines changed

1 file changed

+4
-0
lines changed

torchrl/modules/distributions/discrete.py

+4
Original file line numberDiff line numberDiff line change
@@ -319,6 +319,10 @@ def _mask_logits(
319319
logits.masked_fill_(padding_mask, neg_inf)
320320
return logits
321321

322+
@property
323+
def deterministic_sample(self):
324+
return self.mode
325+
322326

323327
class MaskedOneHotCategorical(MaskedCategorical):
324328
"""MaskedCategorical distribution.

0 commit comments

Comments
 (0)