12
12
from brainpy import math as bm
13
13
from brainpy ._src import connect , initialize as init
14
14
from brainpy ._src .context import share
15
- from brainpy ._src .dependency_check import import_taichi , import_braintaichi
15
+ from brainpy ._src .dependency_check import import_taichi , import_braintaichi , raise_braintaichi_not_found
16
16
from brainpy ._src .dnn .base import Layer
17
17
from brainpy ._src .mixin import SupportOnline , SupportOffline , SupportSTDP
18
18
from brainpy .check import is_initializer
@@ -241,7 +241,7 @@ def update(self, x):
241
241
return x
242
242
243
243
244
- if ti is not None :
244
+ if ti is not None and bti is not None :
245
245
246
246
# @numba.njit(nogil=True, fastmath=True, parallel=False)
247
247
# def _cpu_dense_on_post(weight, spike, trace, w_min, w_max, out_w):
@@ -321,7 +321,7 @@ def _dense_on_pre(
321
321
322
322
def dense_on_pre (weight , spike , trace , w_min , w_max ):
323
323
if dense_on_pre_prim is None :
324
- raise PackageMissingError . by_purpose ( 'taichi' , 'custom operators' )
324
+ raise_braintaichi_not_found ( )
325
325
326
326
if w_min is None :
327
327
w_min = - np .inf
@@ -341,7 +341,7 @@ def dense_on_pre(weight, spike, trace, w_min, w_max):
341
341
342
342
def dense_on_post (weight , spike , trace , w_min , w_max ):
343
343
if dense_on_post_prim is None :
344
- raise PackageMissingError . by_purpose ( 'taichi' , 'custom operators' )
344
+ raise_braintaichi_not_found ( )
345
345
346
346
if w_min is None :
347
347
w_min = - np .inf
@@ -728,7 +728,7 @@ def _batch_csrmv(self, x):
728
728
transpose = self .transpose )
729
729
730
730
731
- if ti is not None :
731
+ if ti is not None and bti is not None :
732
732
@ti .kernel
733
733
def _csr_on_pre_update (
734
734
old_w : ti .types .ndarray (ndim = 1 ), # vector with shape of (num_syn)
@@ -852,7 +852,7 @@ def _csc_on_post_update(
852
852
853
853
def csr_on_pre_update (w , indices , indptr , spike , trace , w_min = None , w_max = None ):
854
854
if csr_on_pre_update_prim is None :
855
- raise PackageMissingError . by_purpose ( 'taichi' , 'customized operators' )
855
+ raise_braintaichi_not_found ( )
856
856
857
857
if w_min is None :
858
858
w_min = - np .inf
@@ -874,7 +874,7 @@ def csr_on_pre_update(w, indices, indptr, spike, trace, w_min=None, w_max=None):
874
874
875
875
def coo_on_pre_update (w , pre_ids , post_ids , spike , trace , w_min = None , w_max = None ):
876
876
if coo_on_pre_update_prim is None :
877
- raise PackageMissingError . by_purpose ( 'taichi' , 'customized operators' )
877
+ raise_braintaichi_not_found ( )
878
878
879
879
if w_min is None :
880
880
w_min = - np .inf
@@ -897,7 +897,7 @@ def coo_on_pre_update(w, pre_ids, post_ids, spike, trace, w_min=None, w_max=None
897
897
898
898
def csc_on_post_update (w , post_ids , indptr , w_ids , post_spike , pre_trace , w_min = None , w_max = None ):
899
899
if csc_on_post_update_prim is None :
900
- raise PackageMissingError . by_purpose ( 'taichi' , 'customized operators' )
900
+ raise_braintaichi_not_found ( )
901
901
902
902
if w_min is None :
903
903
w_min = - np .inf
0 commit comments