Skip to content

Commit 589a27c

Browse files
authored
update (#701)
* update * update
1 parent e54b7ff commit 589a27c

34 files changed

+1738
-10565
lines changed

brainpy/_src/dependency_check.py

+33-183
Original file line numberDiff line numberDiff line change
@@ -1,81 +1,46 @@
1+
import importlib.util
12
import os
23
import sys
34

4-
from jax.lib import xla_client
5-
65
__all__ = [
7-
'import_taichi',
8-
'raise_taichi_not_found',
9-
'import_braintaichi',
10-
'raise_braintaichi_not_found',
11-
'import_numba',
12-
'raise_numba_not_found',
13-
'import_cupy',
14-
'import_cupy_jit',
15-
'raise_cupy_not_found',
16-
'import_brainpylib_cpu_ops',
17-
'import_brainpylib_gpu_ops',
6+
'import_taichi',
7+
'import_braintaichi',
8+
'raise_braintaichi_not_found',
189
]
1910

20-
_minimal_brainpylib_version = '0.2.6'
21-
_minimal_taichi_version = (1, 7, 2)
22-
23-
numba = None
2411
taichi = None
2512
braintaichi = None
26-
cupy = None
27-
cupy_jit = None
28-
brainpylib_cpu_ops = None
29-
brainpylib_gpu_ops = None
30-
31-
taichi_install_info = (f'We need taichi>={_minimal_taichi_version}. '
32-
f'Currently you can install taichi=={_minimal_taichi_version} by pip . \n'
33-
'> pip install taichi -U')
34-
numba_install_info = ('We need numba. Please install numba by pip . \n'
35-
'> pip install numba')
36-
cupy_install_info = ('We need cupy. Please install cupy by pip . \n'
37-
'For CUDA v11.2 ~ 11.8 > pip install cupy-cuda11x\n'
38-
'For CUDA v12.x > pip install cupy-cuda12x\n')
3913
braintaichi_install_info = ('We need braintaichi. Please install braintaichi by pip . \n'
4014
'> pip install braintaichi -U')
4115

42-
4316
os.environ["TI_LOG_LEVEL"] = "error"
4417

4518

4619
def import_taichi(error_if_not_found=True):
47-
"""Internal API to import taichi.
20+
"""Internal API to import taichi.
4821
49-
If taichi is not found, it will raise a ModuleNotFoundError if error_if_not_found is True,
50-
otherwise it will return None.
51-
"""
52-
global taichi
53-
if taichi is None:
54-
with open(os.devnull, 'w') as devnull:
55-
old_stdout = sys.stdout
56-
sys.stdout = devnull
57-
try:
58-
import taichi as taichi # noqa
59-
except ModuleNotFoundError:
60-
if error_if_not_found:
61-
raise raise_taichi_not_found()
62-
finally:
63-
sys.stdout = old_stdout
22+
If taichi is not found, it will raise a ModuleNotFoundError if error_if_not_found is True,
23+
otherwise it will return None.
24+
"""
25+
global taichi
26+
if taichi is None:
27+
if importlib.util.find_spec('taichi') is not None:
28+
with open(os.devnull, 'w') as devnull:
29+
old_stdout = sys.stdout
30+
sys.stdout = devnull
31+
try:
32+
import taichi as taichi # noqa
33+
except ModuleNotFoundError as e:
34+
if error_if_not_found:
35+
raise e
36+
finally:
37+
sys.stdout = old_stdout
38+
else:
39+
taichi = None
6440

65-
if taichi is None:
66-
return None
67-
taichi_version = taichi.__version__[0] * 10000 + taichi.__version__[1] * 100 + taichi.__version__[2]
68-
minimal_taichi_version = _minimal_taichi_version[0] * 10000 + _minimal_taichi_version[1] * 100 + \
69-
_minimal_taichi_version[2]
70-
if taichi_version >= minimal_taichi_version:
7141
return taichi
72-
else:
73-
raise ModuleNotFoundError(taichi_install_info)
7442

7543

76-
def raise_taichi_not_found(*args, **kwargs):
77-
raise ModuleNotFoundError(taichi_install_info)
78-
7944
def import_braintaichi(error_if_not_found=True):
8045
"""Internal API to import braintaichi.
8146
@@ -84,133 +49,18 @@ def import_braintaichi(error_if_not_found=True):
8449
"""
8550
global braintaichi
8651
if braintaichi is None:
87-
try:
88-
import braintaichi as braintaichi
89-
except ModuleNotFoundError:
90-
if error_if_not_found:
91-
raise_braintaichi_not_found()
52+
if importlib.util.find_spec('braintaichi') is not None:
53+
try:
54+
import braintaichi as braintaichi
55+
except ModuleNotFoundError:
56+
if error_if_not_found:
57+
raise_braintaichi_not_found()
58+
else:
59+
braintaichi = None
9260
else:
93-
return None
61+
braintaichi = None
9462
return braintaichi
9563

64+
9665
def raise_braintaichi_not_found():
9766
raise ModuleNotFoundError(braintaichi_install_info)
98-
99-
100-
def import_numba(error_if_not_found=True):
101-
"""
102-
Internal API to import numba.
103-
104-
If numba is not found, it will raise a ModuleNotFoundError if error_if_not_found is True,
105-
otherwise it will return None.
106-
"""
107-
global numba
108-
if numba is None:
109-
try:
110-
import numba as numba
111-
except ModuleNotFoundError:
112-
if error_if_not_found:
113-
raise_numba_not_found()
114-
else:
115-
return None
116-
return numba
117-
118-
119-
def raise_numba_not_found():
120-
raise ModuleNotFoundError(numba_install_info)
121-
122-
123-
def import_cupy(error_if_not_found=True):
124-
"""
125-
Internal API to import cupy.
126-
127-
If cupy is not found, it will raise a ModuleNotFoundError if error_if_not_found is True,
128-
otherwise it will return None.
129-
"""
130-
global cupy
131-
if cupy is None:
132-
try:
133-
import cupy as cupy
134-
except ModuleNotFoundError:
135-
if error_if_not_found:
136-
raise_cupy_not_found()
137-
else:
138-
return None
139-
return cupy
140-
141-
142-
def import_cupy_jit(error_if_not_found=True):
143-
"""
144-
Internal API to import cupy.
145-
146-
If cupy is not found, it will raise a ModuleNotFoundError if error_if_not_found is True,
147-
otherwise it will return None.
148-
"""
149-
global cupy_jit
150-
if cupy_jit is None:
151-
try:
152-
from cupyx import jit as cupy_jit
153-
except ModuleNotFoundError:
154-
if error_if_not_found:
155-
raise_cupy_not_found()
156-
else:
157-
return None
158-
return cupy_jit
159-
160-
161-
def raise_cupy_not_found():
162-
raise ModuleNotFoundError(cupy_install_info)
163-
164-
165-
def is_brainpylib_gpu_installed():
166-
return False if brainpylib_gpu_ops is None else True
167-
168-
169-
def import_brainpylib_cpu_ops():
170-
"""
171-
Internal API to import brainpylib cpu_ops.
172-
"""
173-
global brainpylib_cpu_ops
174-
if brainpylib_cpu_ops is None:
175-
try:
176-
from brainpylib import cpu_ops as brainpylib_cpu_ops
177-
178-
for _name, _value in brainpylib_cpu_ops.registrations().items():
179-
xla_client.register_custom_call_target(_name, _value, platform="cpu")
180-
181-
import brainpylib
182-
if brainpylib.__version__ < _minimal_brainpylib_version:
183-
raise SystemError(f'This version of brainpy needs brainpylib >= {_minimal_brainpylib_version}.')
184-
if hasattr(brainpylib, 'check_brainpy_version'):
185-
brainpylib.check_brainpy_version()
186-
187-
except ImportError:
188-
raise ImportError('Please install brainpylib. \n'
189-
'See https://brainpy.readthedocs.io for installation instructions.')
190-
191-
return brainpylib_cpu_ops
192-
193-
194-
def import_brainpylib_gpu_ops():
195-
"""
196-
Internal API to import brainpylib gpu_ops.
197-
"""
198-
global brainpylib_gpu_ops
199-
if brainpylib_gpu_ops is None:
200-
try:
201-
from brainpylib import gpu_ops as brainpylib_gpu_ops
202-
203-
for _name, _value in brainpylib_gpu_ops.registrations().items():
204-
xla_client.register_custom_call_target(_name, _value, platform="gpu")
205-
206-
import brainpylib
207-
if brainpylib.__version__ < _minimal_brainpylib_version:
208-
raise SystemError(f'This version of brainpy needs brainpylib >= {_minimal_brainpylib_version}.')
209-
if hasattr(brainpylib, 'check_brainpy_version'):
210-
brainpylib.check_brainpy_version()
211-
212-
except ImportError:
213-
raise ImportError('Please install GPU version of brainpylib. \n'
214-
'See https://brainpy.readthedocs.io for installation instructions.')
215-
216-
return brainpylib_gpu_ops

0 commit comments

Comments
 (0)