find_pytorch.py
# SPDX-License-Identifier: LGPL-3.0-or-later
import importlib
import os
import platform
import site
from functools import (
    lru_cache,
)
from importlib.machinery import (
    FileFinder,
)
from importlib.util import (
    find_spec,
)
from pathlib import (
    Path,
)
from sysconfig import (
    get_path,
)
from typing import (
    Optional,
    Union,
)

from packaging.specifiers import (
    SpecifierSet,
)
from packaging.version import (
    Version,
)


@lru_cache
def find_pytorch() -> tuple[Optional[str], list[str]]:
    """Find PyTorch library.

    Tries to find PyTorch in the order of:

    1. Environment variable `PYTORCH_ROOT` if set
    2. The current Python environment.
    3. user site packages directory if enabled
    4. system site packages directory (purelib)

    Considering the default PyTorch package still uses the old CXX11 ABI, we
    cannot install it automatically.

    Returns
    -------
    str, optional
        PyTorch library path if found.
    list of str
        PyTorch requirement if not found. Empty if found.
    """
    if os.environ.get("DP_ENABLE_PYTORCH", "0") == "0":
        return None, []
    requires = []
    pt_spec = None

    if (pt_spec is None or not pt_spec) and os.environ.get("PYTORCH_ROOT") is not None:
        site_packages = Path(os.environ.get("PYTORCH_ROOT")).parent.absolute()
        pt_spec = FileFinder(str(site_packages)).find_spec("torch")

    # get pytorch spec
    # note: isolated build will not work for backend
    if pt_spec is None or not pt_spec:
        pt_spec = find_spec("torch")

    if not pt_spec and site.ENABLE_USER_SITE:
        # first search PT from user site-packages before global site-packages
        site_packages = site.getusersitepackages()
        if site_packages:
            pt_spec = FileFinder(site_packages).find_spec("torch")

    if not pt_spec:
        # purelib gets site-packages path
        site_packages = get_path("purelib")
        if site_packages:
            pt_spec = FileFinder(site_packages).find_spec("torch")

    # get install dir from spec
    try:
        pt_install_dir = pt_spec.submodule_search_locations[0]  # type: ignore
        # AttributeError if pt_spec is None
        # TypeError if submodule_search_locations are None
        # IndexError if submodule_search_locations is an empty list
    except (AttributeError, TypeError, IndexError):
        pt_install_dir = None
        requires.extend(get_pt_requirement()["torch"])
    return pt_install_dir, requires
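
# Illustrative sketch (not part of the original module) of how a caller might
# consume the return value of find_pytorch(); the variable names are hypothetical.
#
#   pt_install_dir, pt_requires = find_pytorch()
#   if pt_install_dir is None:
#       # PyTorch was not found (or DP_ENABLE_PYTORCH is "0"); pt_requires holds
#       # the pinned "torch" requirement, or is empty when PT support is disabled.
#       ...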


@lru_cache
def get_pt_requirement(pt_version: str = "") -> dict:
    """Get PyTorch requirement when PT is not installed.

    If pt_version is not given and the environment variable `PYTORCH_VERSION` is set, use it as the requirement.

    Parameters
    ----------
    pt_version : str, optional
        PT version

    Returns
    -------
    dict
        PyTorch requirement.
    """
    if pt_version is None:
        return {"torch": []}
    if (
        os.environ.get("CIBUILDWHEEL", "0") == "1"
        and platform.system() == "Linux"
        and platform.machine() == "x86_64"
    ):
        cuda_version = os.environ.get("CUDA_VERSION", "12.2")
        if cuda_version == "" or cuda_version in SpecifierSet(">=12,<13"):
            # CUDA 12.2, cudnn 9
            pt_version = "2.6.0"
        elif cuda_version in SpecifierSet(">=11,<12"):
            # CUDA 11.8, cudnn 8
            pt_version = "2.3.1"
        else:
            raise RuntimeError("Unsupported CUDA version") from None
    if pt_version == "":
        pt_version = os.environ.get("PYTORCH_VERSION", "")

    return {
        "torch": [
            # uv has different local version behaviors, i.e. `==2.3.1` cannot match `==2.3.1+cpu`
            # https://github.com/astral-sh/uv/blob/main/PIP_COMPATIBILITY.md#local-version-identifiers
            # luckily, .* (prefix matching) defined in PEP 440 can match any local version
            # https://peps.python.org/pep-0440/#version-matching
            f"torch=={Version(pt_version).base_version}.*"
            if pt_version != ""
            # https://github.com/pytorch/pytorch/commit/7e0c26d4d80d6602aed95cb680dfc09c9ce533bc
            else "torch>=2.1.0"
        ],
    }
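
# Illustrative examples of the mapping returned by get_pt_requirement() (a sketch
# only; the exact pin depends on CIBUILDWHEEL, CUDA_VERSION, and PYTORCH_VERSION):
#
#   get_pt_requirement("2.3.1+cpu")  # -> {"torch": ["torch==2.3.1.*"]}
#   get_pt_requirement()             # -> {"torch": ["torch>=2.1.0"]} when no
#                                    #    version is pinned anywhere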


@lru_cache
def get_pt_version(pt_path: Optional[Union[str, Path]]) -> str:
    """Get PT version from a PT Python library path.

    Parameters
    ----------
    pt_path : str or Path
        PT Python library path

    Returns
    -------
    str
        version
    """
    if pt_path is None or pt_path == "":
        return ""
    version_file = Path(pt_path) / "version.py"
    spec = importlib.util.spec_from_file_location("torch.version", version_file)
    module = importlib.util.module_from_spec(spec)
    spec.loader.exec_module(module)
    return module.__version__
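

if __name__ == "__main__":
    # Minimal smoke-test sketch, not part of the original module: report what the
    # helpers above resolve to in the current environment. With DP_ENABLE_PYTORCH
    # unset or "0", find_pytorch() returns (None, []).
    pt_dir, pt_requires = find_pytorch()
    print("PyTorch install dir:", pt_dir)
    print("Extra build requirements:", pt_requires)
    if pt_dir is not None:
        print("PyTorch version:", get_pt_version(pt_dir))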