Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
31 commits
Select commit Hold shift + click to select a range
874e659
dev: add state_io functionality
BillHuang2001 Aug 6, 2024
ad65d96
Code cleaning
BillHuang2001 Aug 7, 2024
0942606
dev: register closure_values as part of the pytree container
BillHuang2001 Aug 7, 2024
3a37ba0
Code cleaning
BillHuang2001 Aug 8, 2024
980cbf9
dev: merge closures in the state with use_state
BillHuang2001 Aug 8, 2024
b8b7248
dev: allow manually executing callbacks without clearing the list
BillHuang2001 Aug 8, 2024
86781ed
dev: monitor use the state_io api
BillHuang2001 Aug 8, 2024
195c5e6
fix: test_monitors
BillHuang2001 Aug 8, 2024
0e2b447
Refactor hooks in the workflow
BillHuang2001 Aug 20, 2024
c826503
Allow parallel monitor
BillHuang2001 Aug 20, 2024
516e070
Make monitor stateful to move most part in the monitor into jit context
BillHuang2001 Aug 20, 2024
dd6038f
Fix typo
BillHuang2001 Aug 22, 2024
7bf3b69
dev: Parallel monitor
BillHuang2001 Sep 4, 2024
1f8bcbb
dev: Remove old monitor code
BillHuang2001 Sep 4, 2024
61582e5
dev: use new monitor code
BillHuang2001 Sep 4, 2024
2a98c7f
test: use the new EvalMonitor and PopMonitor
BillHuang2001 Sep 4, 2024
cd75a2f
dev: DistrbutedPipeline use the new monitor api
BillHuang2001 Sep 4, 2024
c97a41c
dev: standalone clear callbacks method
BillHuang2001 Sep 9, 2024
1bee54c
dev: clear callbacks in the inner state
BillHuang2001 Sep 9, 2024
ce5e6cd
dev: add parallel_init and parallel_step as a sugar
BillHuang2001 Sep 9, 2024
7e2476b
fix: get_best behavior under vmap transforms
BillHuang2001 Sep 9, 2024
bcd30f6
test: update test to use new api
BillHuang2001 Sep 9, 2024
ccdb748
test: lock new target value due to monitor being a stateful causing r…
BillHuang2001 Sep 9, 2024
aa12324
dev: soft delete ray distributed workflow
BillHuang2001 Sep 9, 2024
407d5e8
dev: allow stateful policy
BillHuang2001 Sep 9, 2024
ea09c93
dev: allow stateful nested in list
BillHuang2001 Sep 9, 2024
da03a6c
dev: add custom replace method
BillHuang2001 Sep 23, 2024
8f43f5c
dev: remove parallel step for now
BillHuang2001 Oct 23, 2024
89b56b4
dev: change default dataclass argument
BillHuang2001 Oct 23, 2024
1bef2d8
dev: EvalMonitor use dataclass
BillHuang2001 Oct 23, 2024
7ed05a3
dev: check if both algorithm and problem are dataclasses
BillHuang2001 Oct 23, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
58 changes: 33 additions & 25 deletions src/evox/algorithms/so/pso_variants/pso.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,38 +5,30 @@
# Link: https://ieeexplore.ieee.org/document/494215
# --------------------------------------------------------------------------------------

from functools import partial
from typing import Optional

import jax
import jax.numpy as jnp
import copy

from evox import Algorithm, State, dataclass, pytree_field
from evox.utils import *
from evox import Algorithm, State, jit_class


@jit_class
@dataclass
class PSO(Algorithm):
def __init__(
self,
lb,
ub,
pop_size,
inertia_weight=0.6,
cognitive_coefficient=2.5,
social_coefficient=0.8,
mean=None,
stdev=None,
):
self.dim = lb.shape[0]
self.lb = lb
self.ub = ub
self.pop_size = pop_size
self.w = inertia_weight
self.phi_p = cognitive_coefficient
self.phi_g = social_coefficient
self.mean = mean
self.stdev = stdev
dim: jax.Array = pytree_field(static=True, init=False)
lb: jax.Array
ub: jax.Array
pop_size: jax.Array = pytree_field(static=True)
w: jax.Array = pytree_field(default=0.6)
phi_p: jax.Array = pytree_field(default=2.5)
phi_g: jax.Array = pytree_field(default=0.8)
mean: Optional[jax.Array] = pytree_field(default=None)
stdev: Optional[jax.Array] = pytree_field(default=None)
bound_method: str = pytree_field(static=True, default="clip")

def __post_init__(self):
self.set_frozen_attr("dim", self.lb.shape[0])

def setup(self, key):
state_key, init_pop_key, init_v_key = jax.random.split(key, 3)
Expand Down Expand Up @@ -95,7 +87,23 @@ def tell(self, state, fitness):
+ self.phi_g * rg * (global_best_location - state.population)
)
population = state.population + velocity
population = jnp.clip(population, self.lb, self.ub)

if self.bound_method == "clip":
population = jnp.clip(population, self.lb, self.ub)
elif self.bound_method == "reflect":
lower_bound_violation = population < self.lb
upper_bound_violation = population > self.ub

population = jnp.where(
lower_bound_violation, 2 * self.lb - population, population
)
population = jnp.where(
upper_bound_violation, 2 * self.ub - population, population
)

velocity = jnp.where(
lower_bound_violation | upper_bound_violation, -velocity, velocity
)

return state.replace(
population=population,
Expand Down
70 changes: 60 additions & 10 deletions src/evox/core/module.py
Original file line number Diff line number Diff line change
Expand Up @@ -75,7 +75,9 @@ def wrapper(self, state: State, *args, **kwargs):
new_state,
)

state = state.replace_by_path(path, new_state)
state = state.replace_by_path(
path, new_state.clear_callbacks()
).prepend_closure(new_state)

if aux is None:
return state
Expand Down Expand Up @@ -148,6 +150,10 @@ class Stateful:

The ``init`` method will automatically call the ``setup`` of the current module
and recursively call ``setup`` methods of all submodules.

Currently, there are two special metadata that can be used to control the behavior of the module initialization:
- ``stack``: If set to True, the module will be initialized multiple times, and the states will be stacked together.
- ``nested``: If set to True, the a list of modules, that is [module1, module2, ...], will be iterated and initialized.
"""

def __init__(self):
Expand All @@ -174,10 +180,16 @@ def setup(self, key: jax.Array) -> State:
return State()

def _recursive_init(
self, key: jax.Array, node_id: int, module_name: str, no_state: bool
self,
key: jax.Array,
node_id: int,
module_name: str,
no_state: bool,
re_init: bool,
) -> Tuple[State, int]:
object.__setattr__(self, "_node_id", node_id)
object.__setattr__(self, "_module_name", module_name)
if not re_init:
object.__setattr__(self, "_node_id", node_id)
object.__setattr__(self, "_module_name", module_name)

if not no_state:
child_states = {}
Expand All @@ -197,6 +209,15 @@ def _recursive_init(

if isinstance(attr, Stateful):
submodules.append(SubmoduleInfo(field.name, attr, field.metadata))

# handle "nested" field
if field.metadata.get("nested", False):
for idx, nested_module in enumerate(attr):
submodules.append(
SubmoduleInfo(
field.name + str(idx), nested_module, field.metadata
)
)
else:
for attr_name in vars(self):
attr = getattr(self, attr_name)
Expand All @@ -211,24 +232,27 @@ def _recursive_init(
else:
key, subkey = jax.random.split(key)

# handle "StackAnnotation"
# handle "Stack"
# attr should be a list, or tuple of modules
if metadata.get("stack", False):
num_copies = len(attr)
subkeys = jax.random.split(subkey, num_copies)
current_node_id = node_id
_, node_id = attr._recursive_init(None, node_id + 1, attr_name, True)
_, node_id = attr._recursive_init(
None, node_id + 1, attr_name, True, re_init
)
submodule_state, _node_id = jax.vmap(
partial(
Stateful._recursive_init,
node_id=current_node_id + 1,
module_name=attr_name,
no_state=no_state,
re_init=re_init,
)
)(attr, subkeys)
else:
submodule_state, node_id = attr._recursive_init(
subkey, node_id + 1, attr_name, no_state
subkey, node_id + 1, attr_name, no_state, re_init
)

if not no_state:
Expand All @@ -246,10 +270,12 @@ def _recursive_init(

self_state._set_state_id_mut(self._node_id)._set_child_states_mut(
child_states
),
)
return self_state, node_id

def init(self, key: jax.Array = None, no_state: bool = False) -> State:
def init(
self, key: jax.Array = None, no_state: bool = False, re_init: bool = False
) -> State:
"""Initialize this module and all submodules

This method should not be overwritten.
Expand All @@ -264,9 +290,33 @@ def init(self, key: jax.Array = None, no_state: bool = False) -> State:
State
The state of this module and all submodules combined.
"""
state, _node_id = self._recursive_init(key, 0, None, no_state)
state, _node_id = self._recursive_init(key, 0, None, no_state, re_init)
return state

def parallel_init(
self, key: jax.Array, num_copies: int, no_state: bool = False
) -> Tuple[State, int]:
"""Initialize multiple copies of this module in parallel

This method should not be overwritten.

Parameters
----------
key
A PRNGKey.
num_copies
The number of copies to be initialized
no_state
Whether to skip the state initialization

Returns
-------
Tuple[State, int]
The state of this module and all submodules combined, and the last node_id
"""
subkeys = jax.random.split(key, num_copies)
return jax.vmap(self.init, in_axes=(0, None))(subkeys, no_state)

@classmethod
def stack(cls, stateful_objs, axis=0):
for obj in stateful_objs:
Expand Down
5 changes: 4 additions & 1 deletion src/evox/core/monitor.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,7 @@
class Monitor:
from .module import *


class Monitor(Stateful):
"""Monitor base class.
Monitors are used to monitor the evolutionary process.
They contains a set of callbacks,
Expand Down
26 changes: 21 additions & 5 deletions src/evox/core/pytree_dataclass.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,9 @@
from jax.tree_util import register_pytree_node
import copy
import dataclasses
from typing import Annotated, Any, Callable, Optional, Tuple, TypeVar, get_type_hints

from typing_extensions import (
dataclass_transform, # pytype: disable=not-supported-yet
)
from jax.tree_util import register_pytree_node
from typing_extensions import dataclass_transform # pytype: disable=not-supported-yet

from .distributed import ShardingType

Expand All @@ -19,10 +18,26 @@ def pytree_field(
return dataclasses.field(**kwargs)


def _dataclass_set_frozen_attr(self, key, value):
object.__setattr__(self, key, value)


def _dataclass_replace(self, **kwargs):
"""Add a replace method to dataclasses.
It's different from dataclasses.replace in that it doesn't call the __init__,
instead it copies the object and sets the new values.
"""
new_obj = copy.copy(self)
for key, value in kwargs.items():
object.__setattr__(new_obj, key, value)
return new_obj


def dataclass(cls, *args, **kwargs):
"""
A dataclass decorator that also registers the dataclass as a pytree node.
"""
kwargs = {"unsafe_hash": False, "eq": False, **kwargs}
cls = dataclasses.dataclass(cls, *args, **kwargs)

field_info = []
Expand Down Expand Up @@ -78,7 +93,8 @@ def unflatten(aux_data, children):
register_pytree_node(cls, flatten, unflatten)

# Add a method to set frozen attributes after init
cls.set_frozen_attr = lambda self, key, value: object.__setattr__(self, key, value)
cls.set_frozen_attr = _dataclass_set_frozen_attr
cls.replace = _dataclass_replace
return cls


Expand Down
Loading
Loading