Skip to content

Commit 49d9897

Browse files
committed
[Feature] Add deterministic_sample to masked categorical
ghstack-source-id: d34fcf9b44d7a7c60dbde80b0835189f990ef226 Pull Request resolved: #2708
1 parent d425777 commit 49d9897

File tree

1 file changed

+4
-0
lines changed

1 file changed

+4
-0
lines changed

torchrl/modules/distributions/discrete.py

+4
Original file line numberDiff line numberDiff line change
@@ -319,6 +319,10 @@ def _mask_logits(
319319
logits.masked_fill_(padding_mask, neg_inf)
320320
return logits
321321

322+
@property
323+
def deterministic_sample(self):
324+
return self.mode
325+
322326

323327
class MaskedOneHotCategorical(MaskedCategorical):
324328
"""MaskedCategorical distribution.

0 commit comments

Comments
 (0)