-
-
Notifications
You must be signed in to change notification settings - Fork 934
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Add probability masking to space.sample
#1310
Changes from 17 commits
713aa42
94fbd34
967a061
8839e03
af94493
d29f055
0f6f7b2
3005c92
c876115
dc64293
765442a
5fcabe4
0aae7ac
4bd7fe5
2c5db93
d45cc5f
386c5f4
bb16e65
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -59,19 +59,28 @@ def is_np_flattenable(self): | |
"""Checks whether this space can be flattened to a :class:`spaces.Box`.""" | ||
return True | ||
|
||
def sample(self, mask: MaskNDArray | None = None) -> NDArray[np.int8]: | ||
def sample( | ||
self, mask: MaskNDArray | None = None, probability: None = None | ||
) -> NDArray[np.int8]: | ||
"""Generates a single random sample from this space. | ||
|
||
A sample is drawn by independent, fair coin tosses (one toss per binary variable of the space). | ||
|
||
Args: | ||
mask: An optional np.ndarray to mask samples with expected shape of ``space.shape``. | ||
For ``mask == 0`` then the samples will be ``0`` and ``mask == 1`` then random samples will be generated. | |
mask: An optional ``np.ndarray`` to mask samples with expected shape of ``space.shape``. | ||
Where ``mask == 0`` the sample will be ``0``, and where ``mask == 1`` the sample will be ``1``. | |
For random samples, use a mask value of ``2``. | |
The expected mask shape is the space shape and mask dtype is ``np.int8``. | ||
probability: An optional ``np.ndarray`` of sampling probabilities with expected shape of ``space.shape``, where each element | |
represents the probability of sampling ``1`` for that element; the expected dtype is ``np.float64``. | |
|
||
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. I think that the Args descriptions are a little confusing. I think something like this would make them clearer. mask: An optional probability: An optional |
||
Returns: | ||
Sampled values from space | ||
""" | ||
if mask is not None and probability is not None: | ||
raise ValueError( | ||
"Only one of `mask` or `probability` can be provided, and `probability` is currently unsupported" | ||
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Should remove "and |
||
) | ||
if mask is not None: | ||
assert isinstance( | ||
mask, np.ndarray | ||
|
@@ -91,8 +100,25 @@ def sample(self, mask: MaskNDArray | None = None) -> NDArray[np.int8]: | |
self.np_random.integers(low=0, high=2, size=self.n, dtype=self.dtype), | ||
mask.astype(self.dtype), | ||
) | ||
elif probability is not None: | ||
assert isinstance( | ||
probability, np.ndarray | ||
), f"The expected type of the probability is np.ndarray, actual type: {type(probability)}" | ||
assert ( | ||
probability.dtype == np.float64 | ||
), f"The expected dtype of the probability is np.float64, actual dtype: {probability.dtype}" | ||
assert ( | ||
probability.shape == self.shape | ||
), f"The expected shape of the probability is {self.shape}, actual shape: {probability}" | ||
assert np.all( | ||
np.logical_and(probability >= 0, probability <= 1) | ||
), f"All values of the sample probability should be between 0 and 1, actual values: {probability}" | ||
|
||
return self.np_random.integers(low=0, high=2, size=self.n, dtype=self.dtype) | ||
return (self.np_random.random(size=self.shape) <= probability).astype( | ||
self.dtype | ||
) | ||
else: | ||
return self.np_random.integers(low=0, high=2, size=self.n, dtype=self.dtype) | ||
|
||
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Tests for probability sampling in MultiBinary are missing. |
||
def contains(self, x: Any) -> bool: | ||
"""Return boolean specifying if x is a valid member of this space.""" | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I think the signature should be:
mask: MaskNDArray | None = None, probability: MaskNDArray | None = None