Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add probability masking to space.sample #1310

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
18 commits
Select commit Hold shift + click to select a range
713aa42
hook fixed styling
mariojerez Jan 15, 2025
94fbd34
Updated invalid probability tests so that they catch the assertion er…
mariojerez Jan 15, 2025
967a061
Corrected test_invalid_probability_mask tests and corrected issue wit…
mariojerez Jan 15, 2025
8839e03
reformatted comment
mariojerez Jan 15, 2025
af94493
Added probability to sample method of box, dict, sequence, and tuple
mariojerez Jan 17, 2025
d29f055
Fixed error in message that would have shown in documentation
mariojerez Jan 17, 2025
0f6f7b2
Improved documentation for discrete and space
mariojerez Jan 17, 2025
3005c92
Added probability mask to graph space
mariojerez Jan 17, 2025
c876115
Added probability mask to remaining spaces and refactored code to imp…
mariojerez Jan 18, 2025
dc64293
Finished up editing sample methods. Added tests.
mariojerez Jan 20, 2025
765442a
Added and improved tests for box, discrete, graph, multi-discrete, oneof
mariojerez Jan 21, 2025
5fcabe4
Wrote sample method tests for Sequence space
mariojerez Jan 22, 2025
0aae7ac
finalized tests and made a small correction in documentation
mariojerez Jan 22, 2025
4bd7fe5
Update probability mask implementation
pseudo-rnd-thoughts Feb 17, 2025
2c5db93
Fixed tests for NumPy <2.0
pseudo-rnd-thoughts Feb 18, 2025
d45cc5f
Merge branch 'Farama-Foundation:main' into probability_mask
pseudo-rnd-thoughts Feb 18, 2025
386c5f4
Remove mujoco-py from `docs/requirements.txt`
pseudo-rnd-thoughts Feb 19, 2025
bb16e65
Code review by mariojerez
pseudo-rnd-thoughts Feb 21, 2025
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 6 additions & 1 deletion gymnasium/spaces/box.py
Original file line number Diff line number Diff line change
Expand Up @@ -342,7 +342,7 @@ def is_bounded(self, manner: str = "both") -> bool:
f"manner is not in {{'below', 'above', 'both'}}, actual value: {manner}"
)

def sample(self, mask: None = None) -> NDArray[Any]:
def sample(self, mask: None = None, probability: None = None) -> NDArray[Any]:
r"""Generates a single random sample inside the Box.

In creating a sample of the box, each coordinate is sampled (independently) from a distribution
Expand All @@ -355,6 +355,7 @@ def sample(self, mask: None = None) -> NDArray[Any]:

Args:
mask: A mask for sampling values from the Box space, currently unsupported.
probability: A probability mask for sampling values from the Box space, currently unsupported.

Returns:
A sampled value from the Box
Expand All @@ -363,6 +364,10 @@ def sample(self, mask: None = None) -> NDArray[Any]:
raise gym.error.Error(
f"Box.sample cannot be provided a mask, actual value: {mask}"
)
elif probability is not None:
raise gym.error.Error(
f"Box.sample cannot be provided a probability mask, actual value: {probability}"
)

high = self.high if self.dtype.kind == "f" else self.high.astype("int64") + 1
sample = np.empty(self.shape)
Expand Down
32 changes: 27 additions & 5 deletions gymnasium/spaces/dict.py
Original file line number Diff line number Diff line change
Expand Up @@ -149,27 +149,49 @@ def seed(self, seed: int | dict[str, Any] | None = None) -> dict[str, int]:
f"Expected seed type: dict, int or None, actual type: {type(seed)}"
)

def sample(self, mask: dict[str, Any] | None = None) -> dict[str, Any]:
def sample(
self,
mask: dict[str, Any] | None = None,
probability: dict[str, Any] | None = None,
) -> dict[str, Any]:
"""Generates a single random sample from this space.

The sample is an ordered dictionary of independent samples from the constituent spaces.

Args:
mask: An optional mask for each of the subspaces, expects the same keys as the space
probability: An optional probability mask for each of the subspaces, expects the same keys as the space

Returns:
A dictionary with the same key and sampled values from :attr:`self.spaces`
"""
if mask is not None:
if mask is not None and probability is not None:
raise ValueError(
f"Only one of `mask` or `probability` can be provided, actual values: mask={mask}, probability={probability}"
)
elif mask is not None:
assert isinstance(
mask, dict
), f"Expects mask to be a dict, actual type: {type(mask)}"
), f"Expected sample mask to be a dict, actual type: {type(mask)}"
assert (
mask.keys() == self.spaces.keys()
), f"Expect mask keys to be same as space keys, mask keys: {mask.keys()}, space keys: {self.spaces.keys()}"
), f"Expected sample mask keys to be same as space keys, mask keys: {mask.keys()}, space keys: {self.spaces.keys()}"

return {k: space.sample(mask=mask[k]) for k, space in self.spaces.items()}
elif probability is not None:
assert isinstance(
probability, dict
), f"Expected sample probability mask to be a dict, actual type: {type(probability)}"
assert (
probability.keys() == self.spaces.keys()
), f"Expected sample probability mask keys to be same as space keys, mask keys: {probability.keys()}, space keys: {self.spaces.keys()}"

return {k: space.sample() for k, space in self.spaces.items()}
return {
k: space.sample(probability=probability[k])
for k, space in self.spaces.items()
}
else:
return {k: space.sample() for k, space in self.spaces.items()}

def contains(self, x: Any) -> bool:
"""Return boolean specifying if x is a valid member of this space."""
Expand Down
55 changes: 47 additions & 8 deletions gymnasium/spaces/discrete.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,12 @@ class Discrete(Space[np.int64]):
>>> observation_space = Discrete(3, start=-1, seed=42) # {-1, 0, 1}
>>> observation_space.sample()
np.int64(-1)
>>> observation_space.sample(mask=np.array([0,0,1], dtype=np.int8))
np.int64(1)
>>> observation_space.sample(probability=np.array([0,0,1], dtype=np.float64))
np.int64(1)
>>> observation_space.sample(probability=np.array([0,0.3,0.7], dtype=np.float64))
np.int64(1)
"""

def __init__(
Expand Down Expand Up @@ -56,41 +62,74 @@ def is_np_flattenable(self):
"""Checks whether this space can be flattened to a :class:`spaces.Box`."""
return True

def sample(self, mask: MaskNDArray | None = None) -> np.int64:
def sample(
self, mask: MaskNDArray | None = None, probability: MaskNDArray | None = None
) -> np.int64:
"""Generates a single random sample from this space.

A sample will be chosen uniformly at random with the mask if provided
A sample will be chosen uniformly at random with the mask if provided, or it will be chosen according to a specified probability distribution if the probability mask is provided.

Args:
mask: An optional mask for if an action can be selected.
Expected `np.ndarray` of shape ``(n,)`` and dtype ``np.int8`` where ``1`` represents valid actions and ``0`` invalid / infeasible actions.
If there are no possible actions (i.e. ``np.all(mask == 0)``) then ``space.start`` will be returned.
probability: An optional probability mask describing the probability of each action being selected.
Expected `np.ndarray` of shape ``(n,)`` and dtype ``np.float64`` where each value is in the range ``[0, 1]`` and the sum of all values is 1.
If the values do not sum to 1, an exception will be thrown.

Returns:
A sampled integer from the space
"""
if mask is not None:
if mask is not None and probability is not None:
raise ValueError(
f"Only one of `mask` or `probability` can be provided, actual values: mask={mask}, probability={probability}"
)
# binary mask sampling
elif mask is not None:
assert isinstance(
mask, np.ndarray
), f"The expected type of the mask is np.ndarray, actual type: {type(mask)}"
), f"The expected type of the sample mask is np.ndarray, actual type: {type(mask)}"
assert (
mask.dtype == np.int8
), f"The expected dtype of the mask is np.int8, actual dtype: {mask.dtype}"
), f"The expected dtype of the sample mask is np.int8, actual dtype: {mask.dtype}"
assert mask.shape == (
self.n,
), f"The expected shape of the mask is {(self.n,)}, actual shape: {mask.shape}"
), f"The expected shape of the sample mask is {(int(self.n),)}, actual shape: {mask.shape}"

valid_action_mask = mask == 1
assert np.all(
np.logical_or(mask == 0, valid_action_mask)
), f"All values of a mask should be 0 or 1, actual values: {mask}"
), f"All values of the sample mask should be 0 or 1, actual values: {mask}"

if np.any(valid_action_mask):
return self.start + self.np_random.choice(
np.where(valid_action_mask)[0]
)
else:
return self.start
# probability mask sampling
elif probability is not None:
assert isinstance(
probability, np.ndarray
), f"The expected type of the sample probability is np.ndarray, actual type: {type(probability)}"
assert (
probability.dtype == np.float64
), f"The expected dtype of the sample probability is np.float64, actual dtype: {probability.dtype}"
assert probability.shape == (
self.n,
), f"The expected shape of the sample probability is {(int(self.n),)}, actual shape: {probability.shape}"

return self.start + self.np_random.integers(self.n)
assert np.all(
np.logical_and(probability >= 0, probability <= 1)
), f"All values of the sample probability should be between 0 and 1, actual values: {probability}"
assert np.isclose(
np.sum(probability), 1
), f"The sum of the sample probability should be equal to 1, actual sum: {np.sum(probability)}"

return self.start + self.np_random.choice(np.arange(self.n), p=probability)
# uniform sampling
else:
return self.start + self.np_random.integers(self.n)

def contains(self, x: Any) -> bool:
"""Return boolean specifying if x is a valid member of this space."""
Expand Down
39 changes: 30 additions & 9 deletions gymnasium/spaces/graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -183,6 +183,12 @@ def sample(
NDArray[Any] | tuple[Any, ...] | None,
]
) = None,
probability: None | (
tuple[
NDArray[Any] | tuple[Any, ...] | None,
NDArray[Any] | tuple[Any, ...] | None,
]
) = None,
num_nodes: int = 10,
num_edges: int | None = None,
) -> GraphInstance:
Expand All @@ -192,6 +198,9 @@ def sample(
mask: An optional tuple of optional node and edge mask that is only possible with Discrete spaces
(Box spaces don't support sample masks).
If no ``num_edges`` is provided then the ``edge_mask`` is multiplied by the number of edges
probability: An optional tuple of optional node and edge probability mask that is only possible with Discrete spaces
(Box spaces don't support sample probability masks).
If no ``num_edges`` is provided then the ``edge_mask`` is multiplied by the number of edges
num_nodes: The number of nodes that will be sampled, the default is `10` nodes
num_edges: An optional number of edges, otherwise, a random number between `0` and :math:`num_nodes^2`

Expand All @@ -202,10 +211,18 @@ def sample(
num_nodes > 0
), f"The number of nodes is expected to be greater than 0, actual value: {num_nodes}"

if mask is not None:
if mask is not None and probability is not None:
raise ValueError(
f"Only one of `mask` or `probability` can be provided, actual values: mask={mask}, probability={probability}"
)
elif mask is not None:
node_space_mask, edge_space_mask = mask
mask_type = "mask"
elif probability is not None:
node_space_mask, edge_space_mask = probability
mask_type = "probability"
else:
node_space_mask, edge_space_mask = None, None
node_space_mask = edge_space_mask = mask_type = None

# we only have edges when we have at least 2 nodes
if num_edges is None:
Expand All @@ -228,15 +245,19 @@ def sample(
assert num_edges is not None

sampled_node_space = self._generate_sample_space(self.node_space, num_nodes)
assert sampled_node_space is not None
sampled_edge_space = self._generate_sample_space(self.edge_space, num_edges)

assert sampled_node_space is not None
sampled_nodes = sampled_node_space.sample(node_space_mask)
sampled_edges = (
sampled_edge_space.sample(edge_space_mask)
if sampled_edge_space is not None
else None
)
if mask_type is not None:
node_sample_kwargs = {mask_type: node_space_mask}
edge_sample_kwargs = {mask_type: edge_space_mask}
else:
node_sample_kwargs = edge_sample_kwargs = {}

sampled_nodes = sampled_node_space.sample(**node_sample_kwargs)
sampled_edges = None
if sampled_edge_space is not None:
sampled_edges = sampled_edge_space.sample(**edge_sample_kwargs)

sampled_edge_links = None
if sampled_edges is not None and num_edges > 0:
Expand Down
35 changes: 31 additions & 4 deletions gymnasium/spaces/multi_binary.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,19 +59,29 @@ def is_np_flattenable(self):
"""Checks whether this space can be flattened to a :class:`spaces.Box`."""
return True

def sample(self, mask: MaskNDArray | None = None) -> NDArray[np.int8]:
def sample(
self, mask: MaskNDArray | None = None, probability: MaskNDArray | None = None
) -> NDArray[np.int8]:
"""Generates a single random sample from this space.

A sample is drawn by independent, fair coin tosses (one toss per binary variable of the space).

Args:
mask: An optional np.ndarray to mask samples with expected shape of ``space.shape``.
For ``mask == 0`` then the samples will be ``0`` and ``mask == 1` then random samples will be generated.
mask: An optional ``np.ndarray`` to mask samples with expected shape of ``space.shape``.
For ``mask == 0`` then the samples will be ``0``, for a ``mask == 1`` then the samples will be ``1``.
For random samples, use a mask value of ``2``.
The expected mask shape is the space shape and mask dtype is ``np.int8``.
probability: An optional ``np.ndarray`` to mask samples with expected shape of space.shape where each element
represents the probability of the corresponding sample element being a 1.
The expected mask shape is the space shape and mask dtype is ``np.float64``.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think that the Args descriptions are a little confusing. I think something like this would make it more clear.

mask: An optional np.ndarray to mask samples with expected shape of space.shape.
When an element in mask is 0, then the corresponding element in the sample will be 0.
When an element in mask is 1, then the corresponding element in the sample will be 1.
When an element in mask is 2, then the corresponding element in the sample will be randomly sampled.
The expected mask shape is the space shape and mask dtype is np.int8.

probability: An optional np.ndarray to mask samples with expected shape of space.shape where each element represents the probability of the corresponding sample element being 1.
The expected mask shape is the space shape and mask dtype is np.float64.

Returns:
Sampled values from space
"""
if mask is not None and probability is not None:
raise ValueError(
f"Only one of `mask` or `probability` can be provided, actual values: mask={mask}, probability={probability}"
)
if mask is not None:
assert isinstance(
mask, np.ndarray
Expand All @@ -91,8 +101,25 @@ def sample(self, mask: MaskNDArray | None = None) -> NDArray[np.int8]:
self.np_random.integers(low=0, high=2, size=self.n, dtype=self.dtype),
mask.astype(self.dtype),
)
elif probability is not None:
assert isinstance(
probability, np.ndarray
), f"The expected type of the probability is np.ndarray, actual type: {type(probability)}"
assert (
probability.dtype == np.float64
), f"The expected dtype of the probability is np.float64, actual dtype: {probability.dtype}"
assert (
probability.shape == self.shape
), f"The expected shape of the probability is {self.shape}, actual shape: {probability}"
assert np.all(
np.logical_and(probability >= 0, probability <= 1)
), f"All values of the sample probability should be between 0 and 1, actual values: {probability}"

return self.np_random.integers(low=0, high=2, size=self.n, dtype=self.dtype)
return (self.np_random.random(size=self.shape) <= probability).astype(
self.dtype
)
else:
return self.np_random.integers(low=0, high=2, size=self.n, dtype=self.dtype)

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Missing tests for probability sampling in MultiBinary

def contains(self, x: Any) -> bool:
"""Return boolean specifying if x is a valid member of this space."""
Expand Down
Loading
Loading