From 7b08581ca4eab7abb827fb9fc9a25beb8d64ee11 Mon Sep 17 00:00:00 2001 From: routhleck Date: Sun, 19 Jan 2025 20:45:33 +0800 Subject: [PATCH] Support brainunit --- .../object_transform/tests/test_variable.py | 51 ++++++++++++++++++- .../_src/math/object_transform/variables.py | 16 ++++-- requirements-dev.txt | 1 + setup.py | 2 +- 4 files changed, 65 insertions(+), 5 deletions(-) diff --git a/brainpy/_src/math/object_transform/tests/test_variable.py b/brainpy/_src/math/object_transform/tests/test_variable.py index ddf7c8d22..3482a07a3 100644 --- a/brainpy/_src/math/object_transform/tests/test_variable.py +++ b/brainpy/_src/math/object_transform/tests/test_variable.py @@ -1,9 +1,12 @@ import brainpy.math as bm +import brainunit as u +import jax.numpy as jnp +from functools import partial import unittest class TestVar(unittest.TestCase): - def test1(self): + def test_ndarray(self): class A(bm.BrainPyObject): def __init__(self): super().__init__() @@ -33,6 +36,8 @@ def fff(self): print() a = A() + temp = a.f1() + print(temp) self.assertTrue(bm.all(a.f1() == 2.)) self.assertTrue(len(a.f1._dyn_vars) == 2) print(a.f2()) @@ -46,6 +51,50 @@ def fff(self): bm.clear_buffer_memory() + def test_state(self): + class B(bm.BrainPyObject): + def __init__(self): + super().__init__() + self.a = bm.Variable([0.,] * u.mV) + self.f1 = bm.jit(self.f) + self.f2 = bm.jit(self.ff) + self.f3 = bm.jit(self.fff) + + def f(self): + ones_fun = partial(u.math.ones,unit=u.mV) + b = self.tracing_variable('b', ones_fun, (1,)) + self.a += (b * 2) + return self.a.value + + def ff(self): + self.b += 1. * u.mV + + def fff(self): + self.f() + self.ff() + self.b *= self.a.value.mantissa + return self.b.value + + print() + f_jit = bm.jit(B().f) + f_jit() + self.assertTrue(len(f_jit._dyn_vars) == 2) + + print() + b = B() + self.assertTrue(u.math.all(b.f1() == [2.,] * u.mV)) + self.assertTrue(len(b.f1._dyn_vars) == 2) + print(b.f2()) + self.assertTrue(len(b.f2._dyn_vars) == 1) + + print() + b = B() + print() + self.assertTrue(u.math.allclose(b.f3(), 4. * u.mV)) + self.assertTrue(len(b.f3._dyn_vars) == 2) + + bm.clear_buffer_memory() + diff --git a/brainpy/_src/math/object_transform/variables.py b/brainpy/_src/math/object_transform/variables.py index b7babae8d..2988986bf 100644 --- a/brainpy/_src/math/object_transform/variables.py +++ b/brainpy/_src/math/object_transform/variables.py @@ -7,6 +7,8 @@ from jax.tree_util import register_pytree_node_class from brainpy._src.math.ndarray import Array +from brainstate import State +from brainunit import Quantity from brainpy._src.math.sharding import BATCH_AXIS from brainpy.errors import MathError @@ -220,7 +222,7 @@ def __add__(self, other: dict): @register_pytree_node_class -class Variable(Array): +class Variable(Array, State): """The pointer to specify the dynamical variable. Initializing an instance of ``Variable`` by two ways: @@ -250,7 +252,8 @@ def __init__( batch_axis: int = None, *, axis_names: Optional[Sequence[str]] = None, - ready_to_trace: bool = None + ready_to_trace: bool = None, + state_mode: bool = False, ): if isinstance(value_or_size, int): value = jnp.zeros(value_or_size, dtype=dtype) @@ -259,7 +262,14 @@ def __init__( else: value = value_or_size - super().__init__(value, dtype=dtype) + if isinstance(value, Quantity): + state_mode = True + + if state_mode: + State.__init__(self, value, dtype=dtype) + self._value = value + else: + Array.__init__(self, value, dtype=dtype) # check batch axis if isinstance(value, Variable): diff --git a/requirements-dev.txt b/requirements-dev.txt index eb6e5a552..dd05923b1 100644 --- a/requirements-dev.txt +++ b/requirements-dev.txt @@ -8,6 +8,7 @@ pathos braintaichi numba brainstate +brainunit braintools setuptools diff --git a/setup.py b/setup.py index e76727d70..86ee8d13e 100644 --- a/setup.py +++ b/setup.py @@ -57,7 +57,7 @@ author_email='chao.brain@qq.com', packages=packages, python_requires='>=3.9', - install_requires=['numpy>=1.15', 'jax>=0.4.13', 'tqdm'], + install_requires=['numpy>=1.15', 'jax>=0.4.13', 'tqdm', 'brainstate', 'brainunit'], url='https://github.com/brainpy/BrainPy', project_urls={ "Bug Tracker": "https://github.com/brainpy/BrainPy/issues",