Skip to content

Commit 2bb2f9e

Browse files
authored
[dnn] Improve error handling in dnn/linear module (#704)
Improve error handling in `dnn/linear` module
1 parent a354bf2 commit 2bb2f9e

File tree

3 files changed

+9
-33
lines changed

3 files changed

+9
-33
lines changed

brainpy/_src/dnn/linear.py

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@
1212
from brainpy import math as bm
1313
from brainpy._src import connect, initialize as init
1414
from brainpy._src.context import share
15-
from brainpy._src.dependency_check import import_taichi, import_braintaichi
15+
from brainpy._src.dependency_check import import_taichi, import_braintaichi, raise_braintaichi_not_found
1616
from brainpy._src.dnn.base import Layer
1717
from brainpy._src.mixin import SupportOnline, SupportOffline, SupportSTDP
1818
from brainpy.check import is_initializer
@@ -241,7 +241,7 @@ def update(self, x):
241241
return x
242242

243243

244-
if ti is not None:
244+
if ti is not None and bti is not None:
245245

246246
# @numba.njit(nogil=True, fastmath=True, parallel=False)
247247
# def _cpu_dense_on_post(weight, spike, trace, w_min, w_max, out_w):
@@ -321,7 +321,7 @@ def _dense_on_pre(
321321

322322
def dense_on_pre(weight, spike, trace, w_min, w_max):
323323
if dense_on_pre_prim is None:
324-
raise PackageMissingError.by_purpose('taichi', 'custom operators')
324+
raise_braintaichi_not_found()
325325

326326
if w_min is None:
327327
w_min = -np.inf
@@ -341,7 +341,7 @@ def dense_on_pre(weight, spike, trace, w_min, w_max):
341341

342342
def dense_on_post(weight, spike, trace, w_min, w_max):
343343
if dense_on_post_prim is None:
344-
raise PackageMissingError.by_purpose('taichi', 'custom operators')
344+
raise_braintaichi_not_found()
345345

346346
if w_min is None:
347347
w_min = -np.inf
@@ -728,7 +728,7 @@ def _batch_csrmv(self, x):
728728
transpose=self.transpose)
729729

730730

731-
if ti is not None:
731+
if ti is not None and bti is not None:
732732
@ti.kernel
733733
def _csr_on_pre_update(
734734
old_w: ti.types.ndarray(ndim=1), # vector with shape of (num_syn)
@@ -852,7 +852,7 @@ def _csc_on_post_update(
852852

853853
def csr_on_pre_update(w, indices, indptr, spike, trace, w_min=None, w_max=None):
854854
if csr_on_pre_update_prim is None:
855-
raise PackageMissingError.by_purpose('taichi', 'customized operators')
855+
raise_braintaichi_not_found()
856856

857857
if w_min is None:
858858
w_min = -np.inf
@@ -874,7 +874,7 @@ def csr_on_pre_update(w, indices, indptr, spike, trace, w_min=None, w_max=None):
874874

875875
def coo_on_pre_update(w, pre_ids, post_ids, spike, trace, w_min=None, w_max=None):
876876
if coo_on_pre_update_prim is None:
877-
raise PackageMissingError.by_purpose('taichi', 'customized operators')
877+
raise_braintaichi_not_found()
878878

879879
if w_min is None:
880880
w_min = -np.inf
@@ -897,7 +897,7 @@ def coo_on_pre_update(w, pre_ids, post_ids, spike, trace, w_min=None, w_max=None
897897

898898
def csc_on_post_update(w, post_ids, indptr, w_ids, post_spike, pre_trace, w_min=None, w_max=None):
899899
if csc_on_post_update_prim is None:
900-
raise PackageMissingError.by_purpose('taichi', 'customized operators')
900+
raise_braintaichi_not_found()
901901

902902
if w_min is None:
903903
w_min = -np.inf

docker/requirements.txt

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,10 +6,10 @@ jax
66
jaxlib
77
scipy>=1.1.0
88
brainpy
9-
brainpylib
109
brainpy_datasets
1110
h5py
1211
pathos
12+
braintaichi
1313

1414
# test requirements
1515
pytest

docs/quickstart/installation.rst

Lines changed: 0 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -36,7 +36,6 @@ To install brainpy with minimum requirements (only depends on ``jax``), you can
3636
3737
# or
3838
39-
pip install brainpy[cuda11_mini] -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html # for CUDA 11.0
4039
pip install brainpy[cuda12_mini] -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html # for CUDA 12.0
4140
4241
# or
@@ -64,7 +63,6 @@ To install a GPU-only version of BrainPy, you can run
6463
.. code-block:: bash
6564
6665
pip install brainpy[cuda12] -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html # for CUDA 12.0
67-
pip install brainpy[cuda11] -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html # for CUDA 11.0
6866
6967
7068
@@ -79,25 +77,3 @@ you can run the following in your cloud TPU VM:
7977
pip install brainpy[tpu] -f https://storage.googleapis.com/jax-releases/libtpu_releases.html # for google TPU
8078
8179
82-
83-
``brainpylib``
84-
--------------
85-
86-
87-
``brainpylib`` defines a set of useful operators for building and simulating spiking neural networks.
88-
89-
90-
To install the ``brainpylib`` package on CPU devices, you can run
91-
92-
.. code-block:: bash
93-
94-
pip install brainpylib
95-
96-
97-
To install the ``brainpylib`` package on CUDA (Linux only), you can run
98-
99-
100-
.. code-block:: bash
101-
102-
pip install brainpylib
103-

0 commit comments

Comments
 (0)