Add probability masking to space.sample #1310
Merged: pseudo-rnd-thoughts merged 18 commits into Farama-Foundation:main from pseudo-rnd-thoughts:probability_mask on Feb 21, 2025
Commits:
713aa42 hook fixed styling (mariojerez)
94fbd34 Updated invalid probability tests so that they catch the assertion er… (mariojerez)
967a061 Corrected test_invalid_probability_mask tests and corrected issue wit… (mariojerez)
8839e03 reformatted comment (mariojerez)
af94493 Added probability to sample method of box, dict, sequence, and tuple (mariojerez)
d29f055 Fixed error in message that would have shown in documentation (mariojerez)
0f6f7b2 Improved documentation for discrete and space (mariojerez)
3005c92 Added probability mask to graph space (mariojerez)
c876115 Added probability mask to remaining spaces and refactored code to imp… (mariojerez)
dc64293 Finished up editing sample methods. Added tests. (mariojerez)
765442a Added and improved tests for box, discrete, graph, multi-discrete, oneof (mariojerez)
5fcabe4 Wrote sample method tests for Sequence space (mariojerez)
0aae7ac finalized tests and made a small correction in documentation (mariojerez)
4bd7fe5 Update probability mask implementation (pseudo-rnd-thoughts)
2c5db93 Fixed tests for NumPy <2.0 (pseudo-rnd-thoughts)
d45cc5f Merge branch 'Farama-Foundation:main' into probability_mask (pseudo-rnd-thoughts)
386c5f4 Remove mujoco-py from `docs/requirements.txt` (pseudo-rnd-thoughts)
bb16e65 Code review by mariojerez (pseudo-rnd-thoughts)
@@ -59,19 +59,29 @@ def is_np_flattenable(self):
         """Checks whether this space can be flattened to a :class:`spaces.Box`."""
         return True

-    def sample(self, mask: MaskNDArray | None = None) -> NDArray[np.int8]:
+    def sample(
+        self, mask: MaskNDArray | None = None, probability: MaskNDArray | None = None
+    ) -> NDArray[np.int8]:
         """Generates a single random sample from this space.

         A sample is drawn by independent, fair coin tosses (one toss per binary variable of the space).

         Args:
-            mask: An optional np.ndarray to mask samples with expected shape of ``space.shape``.
-                For ``mask == 0`` then the samples will be ``0`` and ``mask == 1` then random samples will be generated.
+            mask: An optional ``np.ndarray`` to mask samples with expected shape of ``space.shape``.
+                For ``mask == 0`` then the samples will be ``0``, for a ``mask == 1`` then the samples will be ``1``.
+                For random samples, using a mask value of ``2``.
+                The expected mask shape is the space shape and mask dtype is ``np.int8``.
+            probability: An optional ``np.ndarray`` to mask samples with expected shape of space.shape where each element
+                represents the probability of the corresponding sample element being a 1.
+                The expected mask shape is the space shape and mask dtype is ``np.float64``.

         Returns:
             Sampled values from space
         """
+        if mask is not None and probability is not None:
+            raise ValueError(
+                f"Only one of `mask` or `probability` can be provided, actual values: mask={mask}, probability={probability}"
+            )
         if mask is not None:
             assert isinstance(
                 mask, np.ndarray

@@ -91,8 +101,25 @@ def sample(self, mask: MaskNDArray | None = None) -> NDArray[np.int8]:
                 self.np_random.integers(low=0, high=2, size=self.n, dtype=self.dtype),
                 mask.astype(self.dtype),
             )
+        elif probability is not None:
+            assert isinstance(
+                probability, np.ndarray
+            ), f"The expected type of the probability is np.ndarray, actual type: {type(probability)}"
+            assert (
+                probability.dtype == np.float64
+            ), f"The expected dtype of the probability is np.float64, actual dtype: {probability.dtype}"
+            assert (
+                probability.shape == self.shape
+            ), f"The expected shape of the probability is {self.shape}, actual shape: {probability.shape}"
+            assert np.all(
+                np.logical_and(probability >= 0, probability <= 1)
+            ), f"All values of the sample probability should be between 0 and 1, actual values: {probability}"

-        return self.np_random.integers(low=0, high=2, size=self.n, dtype=self.dtype)
+            return (self.np_random.random(size=self.shape) <= probability).astype(
+                self.dtype
+            )
+        else:
+            return self.np_random.integers(low=0, high=2, size=self.n, dtype=self.dtype)

     def contains(self, x: Any) -> bool:
         """Return boolean specifying if x is a valid member of this space."""

Review comment on this hunk: Missing tests for probability sampling in MultiBinary.

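The three sampling branches in the diff (fixed mask, per-element probability, plain fair coin toss) can be sketched without the library. This is a minimal NumPy-only illustration, not the code merged in the PR: `sample_multi_binary` is a hypothetical stand-in for the method, the mask-value check is an assumption (the corresponding lines are not shown in the diff), and the sketch compares with a strict `<` where the diff uses `<=`, so that a probability of exactly `0.0` can never yield a `1` (`Generator.random` draws from `[0, 1)`).

```python
import numpy as np


def sample_multi_binary(rng, shape, mask=None, probability=None):
    """Hypothetical stand-in for the sampling logic shown in the diff."""
    if mask is not None and probability is not None:
        raise ValueError("Only one of `mask` or `probability` can be provided")

    if mask is not None:
        # Mask semantics: 0 -> forced 0, 1 -> forced 1, 2 -> fair coin toss.
        assert mask.dtype == np.int8 and mask.shape == shape
        assert np.all(np.isin(mask, (0, 1, 2)))  # assumed check, not from the diff
        coin = rng.integers(low=0, high=2, size=shape, dtype=np.int8)
        return np.where(mask == 2, coin, mask)

    if probability is not None:
        # Each element is the probability that the matching sample element is 1.
        assert probability.dtype == np.float64 and probability.shape == shape
        assert np.all((probability >= 0) & (probability <= 1))
        # Strict `<` (unlike the diff's `<=`): p == 0.0 never produces a 1.
        return (rng.random(size=shape) < probability).astype(np.int8)

    # Neither argument given: independent fair coin tosses.
    return rng.integers(low=0, high=2, size=shape, dtype=np.int8)


rng = np.random.default_rng(0)
shape = (4,)
masked = sample_multi_binary(rng, shape, mask=np.array([0, 1, 2, 2], dtype=np.int8))
weighted = sample_multi_binary(rng, shape, probability=np.array([0.0, 1.0, 0.3, 0.9]))
```

In this sketch the first two elements of `masked` are always `0` and `1`, and the first two elements of `weighted` are always `0` and `1`; the remaining elements vary with the seed.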
I think that the Args descriptions are a little confusing. I think something like this would make it more clear.
mask: An optional ``np.ndarray`` to mask samples with expected shape of ``space.shape``.
    When an element in ``mask`` is ``0``, then the corresponding element in the sample will be ``0``.
    When an element in ``mask`` is ``1``, then the corresponding element in the sample will be ``1``.
    When an element in ``mask`` is ``2``, then the corresponding element in the sample will be randomly sampled.
    The expected mask shape is the space shape and mask dtype is ``np.int8``.
probability: An optional ``np.ndarray`` to mask samples with expected shape of ``space.shape``
    where each element represents the probability of the corresponding sample element being ``1``.
    The expected mask shape is the space shape and mask dtype is ``np.float64``.
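The `probability` wording above ("the probability of the corresponding sample element being 1") can be checked empirically by comparing observed frequencies against the requested probabilities. The following is a minimal NumPy-only sketch of such a check, not a test from the PR; the seed, the probability values, and the number of draws are illustrative assumptions, and the comparison uses a strict `<` so that `0.0` and `1.0` behave deterministically.

```python
import numpy as np

rng = np.random.default_rng(42)
probability = np.array([0.0, 0.25, 1.0], dtype=np.float64)
n_draws = 20_000

# One row per draw; each column is one binary variable of the space.
# Broadcasting compares every row of uniforms against `probability`.
draws = (rng.random(size=(n_draws, probability.size)) < probability).astype(np.int8)

# Fraction of draws in which each element came out as 1.
frequency = draws.mean(axis=0)
print(frequency)
```

With this many draws, the middle frequency should land close to `0.25`, while the first and last columns are exactly all zeros and all ones.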