Skip to content

Commit 3f7993d

Browse files
committed
[Feature] Add deterministic_sample to masked categorical
ghstack-source-id: e5046c9 Pull Request resolved: #2708
1 parent e0a78e4 commit 3f7993d

File tree

1 file changed

+4
-0
lines changed

1 file changed

+4
-0
lines changed

torchrl/modules/distributions/discrete.py

+4
Original file line numberDiff line numberDiff line change
@@ -319,6 +319,10 @@ def _mask_logits(
319319
logits.masked_fill_(padding_mask, neg_inf)
320320
return logits
321321

322+
@property
323+
def deterministic_sample(self):
324+
return self.mode
325+
322326

323327
class MaskedOneHotCategorical(MaskedCategorical):
324328
"""MaskedCategorical distribution.

0 commit comments

Comments
 (0)