Brax now supports Single-layer vmap (HPO Problem) #212

Merged · 2 commits · Feb 11, 2025
4 changes: 2 additions & 2 deletions src/evox/operators/crossover/differential_evolution.py
@@ -25,11 +25,11 @@ def DE_differential_sum(
     select_len = num_diff_vectors.unsqueeze(1) * 2 + 1
     rand_indices = torch.randint(0, pop_size, (pop_size, diff_padding_num), device=device)
-    rand_indices = torch.where(rand_indices == index.unsqueeze(1), torch.tensor(pop_size - 1, device=device), rand_indices)
+    rand_indices = torch.where(rand_indices == index.unsqueeze(1), pop_size - 1, rand_indices)

     pop_permute = population[rand_indices]
     mask = torch.arange(diff_padding_num, device=device).unsqueeze(0) < select_len
-    pop_permute_padding = torch.where(mask.unsqueeze(2), pop_permute, torch.zeros_like(pop_permute))
+    pop_permute_padding = torch.where(mask.unsqueeze(2), pop_permute, 0)

     diff_vectors = pop_permute_padding[:, 1:]
     difference_sum = diff_vectors[:, 0::2].sum(dim=1) - diff_vectors[:, 1::2].sum(dim=1)
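Context for the change above: `torch.where` accepts Python scalars for its value arguments and broadcasts them with the right device and dtype, so the explicit `torch.tensor(pop_size - 1, device=device)` and `torch.zeros_like(pop_permute)` allocations are unnecessary. A minimal, self-contained sketch with made-up sizes:

```python
import torch

pop_size, diff_padding_num = 8, 5
index = torch.arange(pop_size)
rand_indices = torch.randint(0, pop_size, (pop_size, diff_padding_num))

# The scalar fill value broadcasts automatically, so no
# torch.tensor(pop_size - 1, device=device) wrapper is required.
no_self = torch.where(rand_indices == index.unsqueeze(1), pop_size - 1, rand_indices)
print(no_self.shape)  # torch.Size([8, 5])
```

Presumably the scalar form is also friendlier to vmap and tracing, since it avoids materializing a fresh constant tensor on every call.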
32 changes: 23 additions & 9 deletions src/evox/problems/neuroevolution/brax.py
@@ -11,7 +11,7 @@
 from brax import envs
 from brax.io import html, image

-from ...core import Problem, jit_class
+from ...core import Problem, _vmap_fix, jit_class, vmap_impl
 from .utils import get_vmap_model_state_forward


@@ -70,7 +70,8 @@ def __init__(
         The initial key is obtained from `torch.random.get_rng_state()`.

         ## Warning
-        This problem does NOT support HPO wrapper (`problems.hpo_wrapper.HPOProblemWrapper`), i.e., the workflow containing this problem CANNOT be vmapped.
+        This problem does NOT support the HPO wrapper (`problems.hpo_wrapper.HPOProblemWrapper`) out of the box, i.e., the workflow containing this problem CANNOT be vmapped directly.
+        *However*, by setting `pop_size` to the product of the inner population size and the outer population size, you can still use this problem in an HPO workflow.

         ## Examples
         >>> from evox import problems
@@ -137,13 +138,7 @@ def __init__(
         self.rotate_key = rotate_key
         self.reduce_fn = reduce_fn

-    def evaluate(self, pop_params: Dict[str, nn.Parameter]) -> torch.Tensor:
-        """Evaluate the final rewards of a population (batch) of model parameters.
-
-        :param pop_params: A dictionary of parameters where each key is a parameter name and each value is a tensor of shape (batch_size, *param_shape) representing the batched parameters of batched models.
-
-        :return: A tensor of shape (batch_size,) containing the reward of each sample in the population.
-        """
+    def _normal_evaluate(self, pop_params: Dict[str, nn.Parameter]) -> torch.Tensor:
         # Unpack parameters and buffers
         state_params = {self._param_to_state_key_map[key]: value for key, value in pop_params.items()}
         model_state = dict(self._vmap_model_buffers)
@@ -157,6 +152,25 @@ def evaluate(self, pop_params: Dict[str, nn.Parameter]) -> torch.Tensor:
         # Return
         return rewards

+    def evaluate(self, pop_params: Dict[str, nn.Parameter]) -> torch.Tensor:
+        """Evaluate the final rewards of a population (batch) of model parameters.
+
+        :param pop_params: A dictionary of parameters where each key is a parameter name and each value is a tensor of shape (batch_size, *param_shape) representing the batched parameters of batched models.
+
+        :return: A tensor of shape (batch_size,) containing the reward of each sample in the population.
+        """
+        return self._normal_evaluate(pop_params)
+
+    @vmap_impl(evaluate)
+    def _vmap_evaluate(self, pop_params: Dict[str, nn.Parameter]) -> torch.Tensor:
+        _, vmap_dim, vmap_size = _vmap_fix.unwrap_batch_tensor(list(pop_params.values())[0])
+        assert vmap_dim == (0,)
+        vmap_size = vmap_size[0]
+        pop_params = {k: _vmap_fix.unwrap_batch_tensor(v)[0].view(vmap_size * v.size(0), *v.size()[1:]) for k, v in pop_params.items()}
+        flat_rewards = self._normal_evaluate(pop_params)
+        rewards = flat_rewards.view(vmap_size, flat_rewards.size(0) // vmap_size, *flat_rewards.size()[1:])
+        return _vmap_fix.wrap_batch_tensor(rewards, vmap_dim)
+
     def _model_forward(
         self, model_state: Dict[str, torch.Tensor], obs: torch.Tensor, record_trajectory: bool = False
     ) -> Tuple[Dict[str, torch.Tensor], torch.Tensor]:
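The new `@vmap_impl(evaluate)` path is the substance of the PR title: when `evaluate` is called under a single layer of vmap (e.g., from an HPO wrapper), the batched parameter tensors are unwrapped, the vmap dimension is folded into the ordinary population dimension, the plain evaluator runs once, and the rewards are reshaped and wrapped back up. The same reshape trick in plain PyTorch, with a stand-in for `_normal_evaluate` (names and sizes below are illustrative, not EvoX API):

```python
import torch

def normal_evaluate(flat_pop: torch.Tensor) -> torch.Tensor:
    # Stand-in for the problem's _normal_evaluate: one reward per candidate row.
    return flat_pop.sum(dim=-1)

outer, inner, dim = 4, 8, 16          # vmap (HPO) size, population size, #params
pop = torch.randn(outer, inner, dim)  # physical layout behind the batched tensor

# Fold the vmap dimension into the batch, evaluate once, then split it back out,
# mirroring _vmap_evaluate's view(...) before and after _normal_evaluate.
flat_rewards = normal_evaluate(pop.view(outer * inner, dim))
rewards = flat_rewards.view(outer, inner)
print(rewards.shape)  # torch.Size([4, 8])
```

Note the `assert vmap_dim == (0,)` in the real implementation: only a single vmap layer, batched along dim 0, is supported, which is exactly the "single-layer vmap" of the title.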
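From the user side, the updated docstring warning implies a sizing rule: build the Brax problem with `pop_size` equal to the product of the outer (HPO) and inner population sizes, since the vmap layer is flattened into that one batch. A small bookkeeping sketch (sizes hypothetical; the actual wrapper and problem constructors are omitted):

```python
import torch

outer_pop = 4    # hyperparameter configurations, i.e. the vmapped layer
inner_pop = 32   # population size of each inner workflow
pop_size = outer_pop * inner_pop  # value the Brax problem should be built with

# Rewards come back as one flat batch and regroup per HPO instance:
flat_rewards = torch.randn(pop_size)  # placeholder for the Brax rollouts
per_instance = flat_rewards.view(outer_pop, inner_pop)
print(per_instance.shape)  # torch.Size([4, 32])
```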