Skip to content

Commit

Permalink
Support brainunit
Browse files Browse the repository at this point in the history
  • Loading branch information
Routhleck committed Jan 19, 2025
1 parent a08ad48 commit 7b08581
Show file tree
Hide file tree
Showing 4 changed files with 65 additions and 5 deletions.
51 changes: 50 additions & 1 deletion brainpy/_src/math/object_transform/tests/test_variable.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,12 @@
import brainpy.math as bm
import brainunit as u
import jax.numpy as jnp
from functools import partial
import unittest


class TestVar(unittest.TestCase):
def test1(self):
def test_ndarray(self):
class A(bm.BrainPyObject):
def __init__(self):
super().__init__()
Expand Down Expand Up @@ -33,6 +36,8 @@ def fff(self):

print()
a = A()
temp = a.f1()
print(temp)
self.assertTrue(bm.all(a.f1() == 2.))
self.assertTrue(len(a.f1._dyn_vars) == 2)
print(a.f2())
Expand All @@ -46,6 +51,50 @@ def fff(self):

bm.clear_buffer_memory()

def test_state(self):
class B(bm.BrainPyObject):
def __init__(self):
super().__init__()
self.a = bm.Variable([0.,] * u.mV)
self.f1 = bm.jit(self.f)
self.f2 = bm.jit(self.ff)
self.f3 = bm.jit(self.fff)

def f(self):
ones_fun = partial(u.math.ones,unit=u.mV)
b = self.tracing_variable('b', ones_fun, (1,))
self.a += (b * 2)
return self.a.value

def ff(self):
self.b += 1. * u.mV

def fff(self):
self.f()
self.ff()
self.b *= self.a.value.mantissa
return self.b.value

print()
f_jit = bm.jit(B().f)
f_jit()
self.assertTrue(len(f_jit._dyn_vars) == 2)

print()
b = B()
self.assertTrue(u.math.all(b.f1() == [2.,] * u.mV))
self.assertTrue(len(b.f1._dyn_vars) == 2)
print(b.f2())
self.assertTrue(len(b.f2._dyn_vars) == 1)

print()
b = B()
print()
self.assertTrue(u.math.allclose(b.f3(), 4. * u.mV))
self.assertTrue(len(b.f3._dyn_vars) == 2)

bm.clear_buffer_memory()




16 changes: 13 additions & 3 deletions brainpy/_src/math/object_transform/variables.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,8 @@
from jax.tree_util import register_pytree_node_class

from brainpy._src.math.ndarray import Array
from brainstate import State
from brainunit import Quantity
from brainpy._src.math.sharding import BATCH_AXIS
from brainpy.errors import MathError

Expand Down Expand Up @@ -220,7 +222,7 @@ def __add__(self, other: dict):


@register_pytree_node_class
class Variable(Array):
class Variable(Array, State):
"""The pointer to specify the dynamical variable.
Initializing an instance of ``Variable`` by two ways:
Expand Down Expand Up @@ -250,7 +252,8 @@ def __init__(
batch_axis: int = None,
*,
axis_names: Optional[Sequence[str]] = None,
ready_to_trace: bool = None
ready_to_trace: bool = None,
state_mode: bool = False,
):
if isinstance(value_or_size, int):
value = jnp.zeros(value_or_size, dtype=dtype)
Expand All @@ -259,7 +262,14 @@ def __init__(
else:
value = value_or_size

super().__init__(value, dtype=dtype)
if isinstance(value, Quantity):
state_mode = True

if state_mode:
State.__init__(self, value, dtype=dtype)
self._value = value
else:
Array.__init__(self, value, dtype=dtype)

# check batch axis
if isinstance(value, Variable):
Expand Down
1 change: 1 addition & 0 deletions requirements-dev.txt
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ pathos
braintaichi
numba
brainstate
brainunit
braintools
setuptools

Expand Down
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,7 +57,7 @@
author_email='[email protected]',
packages=packages,
python_requires='>=3.9',
install_requires=['numpy>=1.15', 'jax>=0.4.13', 'tqdm'],
install_requires=['numpy>=1.15', 'jax>=0.4.13', 'tqdm', 'brainstate', 'brainunit'],
url='https://github.com/brainpy/BrainPy',
project_urls={
"Bug Tracker": "https://github.com/brainpy/BrainPy/issues",
Expand Down

0 comments on commit 7b08581

Please sign in to comment.