1
+ import importlib .util
1
2
import os
2
3
import sys
3
4
4
- from jax .lib import xla_client
5
-
6
5
__all__ = [
7
- 'import_taichi' ,
8
- 'raise_taichi_not_found' ,
9
- 'import_braintaichi' ,
10
- 'raise_braintaichi_not_found' ,
11
- 'import_numba' ,
12
- 'raise_numba_not_found' ,
13
- 'import_cupy' ,
14
- 'import_cupy_jit' ,
15
- 'raise_cupy_not_found' ,
16
- 'import_brainpylib_cpu_ops' ,
17
- 'import_brainpylib_gpu_ops' ,
6
+ 'import_taichi' ,
7
+ 'import_braintaichi' ,
8
+ 'raise_braintaichi_not_found' ,
18
9
]
19
10
20
- _minimal_brainpylib_version = '0.2.6'
21
- _minimal_taichi_version = (1 , 7 , 2 )
22
-
23
- numba = None
24
11
taichi = None
25
12
braintaichi = None
26
- cupy = None
27
- cupy_jit = None
28
- brainpylib_cpu_ops = None
29
- brainpylib_gpu_ops = None
30
-
31
- taichi_install_info = (f'We need taichi>={ _minimal_taichi_version } . '
32
- f'Currently you can install taichi=={ _minimal_taichi_version } by pip . \n '
33
- '> pip install taichi -U' )
34
- numba_install_info = ('We need numba. Please install numba by pip . \n '
35
- '> pip install numba' )
36
- cupy_install_info = ('We need cupy. Please install cupy by pip . \n '
37
- 'For CUDA v11.2 ~ 11.8 > pip install cupy-cuda11x\n '
38
- 'For CUDA v12.x > pip install cupy-cuda12x\n ' )
39
13
braintaichi_install_info = ('We need braintaichi. Please install braintaichi by pip . \n '
40
14
'> pip install braintaichi -U' )
41
15
42
-
43
16
os .environ ["TI_LOG_LEVEL" ] = "error"
44
17
45
18
46
19
def import_taichi (error_if_not_found = True ):
47
- """Internal API to import taichi.
20
+ """Internal API to import taichi.
48
21
49
- If taichi is not found, it will raise a ModuleNotFoundError if error_if_not_found is True,
50
- otherwise it will return None.
51
- """
52
- global taichi
53
- if taichi is None :
54
- with open (os .devnull , 'w' ) as devnull :
55
- old_stdout = sys .stdout
56
- sys .stdout = devnull
57
- try :
58
- import taichi as taichi # noqa
59
- except ModuleNotFoundError :
60
- if error_if_not_found :
61
- raise raise_taichi_not_found ()
62
- finally :
63
- sys .stdout = old_stdout
22
+ If taichi is not found, it will raise a ModuleNotFoundError if error_if_not_found is True,
23
+ otherwise it will return None.
24
+ """
25
+ global taichi
26
+ if taichi is None :
27
+ if importlib .util .find_spec ('taichi' ) is not None :
28
+ with open (os .devnull , 'w' ) as devnull :
29
+ old_stdout = sys .stdout
30
+ sys .stdout = devnull
31
+ try :
32
+ import taichi as taichi # noqa
33
+ except ModuleNotFoundError as e :
34
+ if error_if_not_found :
35
+ raise e
36
+ finally :
37
+ sys .stdout = old_stdout
38
+ else :
39
+ taichi = None
64
40
65
- if taichi is None :
66
- return None
67
- taichi_version = taichi .__version__ [0 ] * 10000 + taichi .__version__ [1 ] * 100 + taichi .__version__ [2 ]
68
- minimal_taichi_version = _minimal_taichi_version [0 ] * 10000 + _minimal_taichi_version [1 ] * 100 + \
69
- _minimal_taichi_version [2 ]
70
- if taichi_version >= minimal_taichi_version :
71
41
return taichi
72
- else :
73
- raise ModuleNotFoundError (taichi_install_info )
74
42
75
43
76
- def raise_taichi_not_found (* args , ** kwargs ):
77
- raise ModuleNotFoundError (taichi_install_info )
78
-
79
44
def import_braintaichi (error_if_not_found = True ):
80
45
"""Internal API to import braintaichi.
81
46
@@ -84,133 +49,18 @@ def import_braintaichi(error_if_not_found=True):
84
49
"""
85
50
global braintaichi
86
51
if braintaichi is None :
87
- try :
88
- import braintaichi as braintaichi
89
- except ModuleNotFoundError :
90
- if error_if_not_found :
91
- raise_braintaichi_not_found ()
52
+ if importlib .util .find_spec ('braintaichi' ) is not None :
53
+ try :
54
+ import braintaichi as braintaichi
55
+ except ModuleNotFoundError :
56
+ if error_if_not_found :
57
+ raise_braintaichi_not_found ()
58
+ else :
59
+ braintaichi = None
92
60
else :
93
- return None
61
+ braintaichi = None
94
62
return braintaichi
95
63
64
+
96
65
def raise_braintaichi_not_found ():
97
66
raise ModuleNotFoundError (braintaichi_install_info )
98
-
99
-
100
- def import_numba (error_if_not_found = True ):
101
- """
102
- Internal API to import numba.
103
-
104
- If numba is not found, it will raise a ModuleNotFoundError if error_if_not_found is True,
105
- otherwise it will return None.
106
- """
107
- global numba
108
- if numba is None :
109
- try :
110
- import numba as numba
111
- except ModuleNotFoundError :
112
- if error_if_not_found :
113
- raise_numba_not_found ()
114
- else :
115
- return None
116
- return numba
117
-
118
-
119
- def raise_numba_not_found ():
120
- raise ModuleNotFoundError (numba_install_info )
121
-
122
-
123
- def import_cupy (error_if_not_found = True ):
124
- """
125
- Internal API to import cupy.
126
-
127
- If cupy is not found, it will raise a ModuleNotFoundError if error_if_not_found is True,
128
- otherwise it will return None.
129
- """
130
- global cupy
131
- if cupy is None :
132
- try :
133
- import cupy as cupy
134
- except ModuleNotFoundError :
135
- if error_if_not_found :
136
- raise_cupy_not_found ()
137
- else :
138
- return None
139
- return cupy
140
-
141
-
142
- def import_cupy_jit (error_if_not_found = True ):
143
- """
144
- Internal API to import cupy.
145
-
146
- If cupy is not found, it will raise a ModuleNotFoundError if error_if_not_found is True,
147
- otherwise it will return None.
148
- """
149
- global cupy_jit
150
- if cupy_jit is None :
151
- try :
152
- from cupyx import jit as cupy_jit
153
- except ModuleNotFoundError :
154
- if error_if_not_found :
155
- raise_cupy_not_found ()
156
- else :
157
- return None
158
- return cupy_jit
159
-
160
-
161
- def raise_cupy_not_found ():
162
- raise ModuleNotFoundError (cupy_install_info )
163
-
164
-
165
- def is_brainpylib_gpu_installed ():
166
- return False if brainpylib_gpu_ops is None else True
167
-
168
-
169
- def import_brainpylib_cpu_ops ():
170
- """
171
- Internal API to import brainpylib cpu_ops.
172
- """
173
- global brainpylib_cpu_ops
174
- if brainpylib_cpu_ops is None :
175
- try :
176
- from brainpylib import cpu_ops as brainpylib_cpu_ops
177
-
178
- for _name , _value in brainpylib_cpu_ops .registrations ().items ():
179
- xla_client .register_custom_call_target (_name , _value , platform = "cpu" )
180
-
181
- import brainpylib
182
- if brainpylib .__version__ < _minimal_brainpylib_version :
183
- raise SystemError (f'This version of brainpy needs brainpylib >= { _minimal_brainpylib_version } .' )
184
- if hasattr (brainpylib , 'check_brainpy_version' ):
185
- brainpylib .check_brainpy_version ()
186
-
187
- except ImportError :
188
- raise ImportError ('Please install brainpylib. \n '
189
- 'See https://brainpy.readthedocs.io for installation instructions.' )
190
-
191
- return brainpylib_cpu_ops
192
-
193
-
194
- def import_brainpylib_gpu_ops ():
195
- """
196
- Internal API to import brainpylib gpu_ops.
197
- """
198
- global brainpylib_gpu_ops
199
- if brainpylib_gpu_ops is None :
200
- try :
201
- from brainpylib import gpu_ops as brainpylib_gpu_ops
202
-
203
- for _name , _value in brainpylib_gpu_ops .registrations ().items ():
204
- xla_client .register_custom_call_target (_name , _value , platform = "gpu" )
205
-
206
- import brainpylib
207
- if brainpylib .__version__ < _minimal_brainpylib_version :
208
- raise SystemError (f'This version of brainpy needs brainpylib >= { _minimal_brainpylib_version } .' )
209
- if hasattr (brainpylib , 'check_brainpy_version' ):
210
- brainpylib .check_brainpy_version ()
211
-
212
- except ImportError :
213
- raise ImportError ('Please install GPU version of brainpylib. \n '
214
- 'See https://brainpy.readthedocs.io for installation instructions.' )
215
-
216
- return brainpylib_gpu_ops
0 commit comments