8
8
published_version.json file
9
9
"""
10
10
11
- import json
12
- import copy
13
11
import argparse
12
+ import copy
13
+ import json
14
+ from enum import Enum
14
15
from pathlib import Path
15
16
from typing import Dict
16
- from enum import Enum
17
17
18
- BASE_DIR = Path (__file__ ).parent .parent
18
+ BASE_DIR = Path (__file__ ).parent .parent
19
+
19
20
20
21
class OperatingSystem (Enum ):
21
22
LINUX : str = "linux"
22
23
WINDOWS : str = "windows"
23
24
MACOS : str = "macos"
24
25
26
+
25
27
PRE_CXX11_ABI = "pre-cxx11"
26
28
CXX11_ABI = "cxx11-abi"
27
29
DEBUG = "debug"
@@ -38,29 +40,30 @@ class OperatingSystem(Enum):
38
40
"cuda.x" : ("cuda" , "11.8" ),
39
41
"cuda.y" : ("cuda" , "12.1" ),
40
42
"cuda.z" : ("cuda" , "12.4" ),
41
- "rocm5.x" : ("rocm" , "6.0" )
42
- },
43
+ "rocm5.x" : ("rocm" , "6.0" ),
44
+ },
43
45
"release" : {
44
46
"accnone" : ("cpu" , "" ),
45
47
"cuda.x" : ("cuda" , "11.8" ),
46
48
"cuda.y" : ("cuda" , "12.1" ),
47
49
"cuda.z" : ("cuda" , "12.4" ),
48
- "rocm5.x" : ("rocm" , "6.0" )
49
- }
50
- }
50
+ "rocm5.x" : ("rocm" , "6.0" ),
51
+ },
52
+ }
51
53
52
54
# Initialize arch version to default values
53
55
# these default values will be overwritten by
54
56
# extracted values from the release marix
55
57
acc_arch_ver_map = acc_arch_ver_default
56
58
57
59
LIBTORCH_DWNL_INSTR = {
58
- PRE_CXX11_ABI : "Download here (Pre-cxx11 ABI):" ,
59
- CXX11_ABI : "Download here (cxx11 ABI):" ,
60
- RELEASE : "Download here (Release version):" ,
61
- DEBUG : "Download here (Debug version):" ,
62
- MACOS : "Download arm64 libtorch here (ROCm and CUDA are not supported):" ,
63
- }
60
+ PRE_CXX11_ABI : "Download here (Pre-cxx11 ABI):" ,
61
+ CXX11_ABI : "Download here (cxx11 ABI):" ,
62
+ RELEASE : "Download here (Release version):" ,
63
+ DEBUG : "Download here (Debug version):" ,
64
+ MACOS : "Download arm64 libtorch here (ROCm and CUDA are not supported):" ,
65
+ }
66
+
64
67
65
68
def load_json_from_basedir (filename : str ):
66
69
try :
@@ -71,32 +74,39 @@ def load_json_from_basedir(filename: str):
71
74
except json .JSONDecodeError as exc :
72
75
raise ImportError (f"Invalid JSON { filename } " ) from exc
73
76
77
+
74
78
def read_published_versions ():
75
79
return load_json_from_basedir ("published_versions.json" )
76
80
81
+
77
82
def write_published_versions (versions ):
78
83
with open (BASE_DIR / "published_versions.json" , "w" ) as outfile :
79
84
json .dump (versions , outfile , indent = 2 )
80
85
86
+
81
87
def read_matrix_for_os (osys : OperatingSystem , channel : str ):
82
88
jsonfile = load_json_from_basedir (f"{ osys .value } _{ channel } _matrix.json" )
83
89
return jsonfile ["include" ]
84
90
91
+
85
92
def read_quick_start_module_template ():
86
93
with open (BASE_DIR / "_includes" / "quick-start-module.js" ) as fptr :
87
94
return fptr .read ()
88
95
96
+
89
97
def get_package_type (pkg_key : str , os_key : OperatingSystem ) -> str :
90
98
if pkg_key != "pip" :
91
99
return pkg_key
92
100
return "manywheel" if os_key == OperatingSystem .LINUX .value else "wheel"
93
101
102
+
94
103
def get_gpu_info (acc_key , instr , acc_arch_map ):
95
104
gpu_arch_type , gpu_arch_version = acc_arch_map [acc_key ]
96
105
if DEFAULT in instr :
97
106
gpu_arch_type , gpu_arch_version = acc_arch_map ["accnone" ]
98
107
return (gpu_arch_type , gpu_arch_version )
99
108
109
+
100
110
# This method is used for generating new published_versions.json file
101
111
# It will modify versions json object with installation instructions
102
112
# Provided by generate install matrix Github Workflow, stored in release_matrix
@@ -109,42 +119,62 @@ def update_versions(versions, release_matrix, release_version):
109
119
if release_version != "nightly" :
110
120
version = release_matrix [OperatingSystem .LINUX .value ][0 ]["stable_version" ]
111
121
if version not in versions ["versions" ]:
112
- versions ["versions" ][version ] = copy .deepcopy (versions ["versions" ][template ])
122
+ versions ["versions" ][version ] = copy .deepcopy (
123
+ versions ["versions" ][template ]
124
+ )
113
125
versions ["latest_stable" ] = version
114
126
115
127
# Perform update of the json file from release matrix
116
128
for os_key , os_vers in versions ["versions" ][version ].items ():
117
129
for pkg_key , pkg_vers in os_vers .items ():
118
130
for acc_key , instr in pkg_vers .items ():
119
131
package_type = get_package_type (pkg_key , os_key )
120
- gpu_arch_type , gpu_arch_version = get_gpu_info (acc_key , instr , acc_arch_map )
132
+ gpu_arch_type , gpu_arch_version = get_gpu_info (
133
+ acc_key , instr , acc_arch_map
134
+ )
121
135
122
136
pkg_arch_matrix = [
123
- x for x in release_matrix [os_key ]
124
- if (x ["package_type" ], x ["gpu_arch_type" ], x ["gpu_arch_version" ]) ==
125
- (package_type , gpu_arch_type , gpu_arch_version )
126
- ]
137
+ x
138
+ for x in release_matrix [os_key ]
139
+ if (x ["package_type" ], x ["gpu_arch_type" ], x ["gpu_arch_version" ])
140
+ == (package_type , gpu_arch_type , gpu_arch_version )
141
+ ]
127
142
128
143
if pkg_arch_matrix :
129
144
if package_type != "libtorch" :
130
145
instr ["command" ] = pkg_arch_matrix [0 ]["installation" ]
131
146
else :
132
147
if os_key == OperatingSystem .LINUX .value :
133
148
rel_entry_dict = {
134
- x ["devtoolset" ]: x ["installation" ] for x in pkg_arch_matrix
149
+ x ["devtoolset" ]: x ["installation" ]
150
+ for x in pkg_arch_matrix
135
151
if x ["libtorch_variant" ] == "shared-with-deps"
136
- }
152
+ }
137
153
if instr ["versions" ] is not None :
138
154
for ver in [PRE_CXX11_ABI , CXX11_ABI ]:
139
- instr ["versions" ][LIBTORCH_DWNL_INSTR [ver ]] = rel_entry_dict [ver ]
155
+ if gpu_arch_type == "rocm" and ver == PRE_CXX11_ABI :
156
+ continue
157
+ else :
158
+ instr ["versions" ][LIBTORCH_DWNL_INSTR [ver ]] = (
159
+ rel_entry_dict [ver ]
160
+ )
161
+
140
162
elif os_key == OperatingSystem .WINDOWS .value :
141
- rel_entry_dict = {x ["libtorch_config" ]: x ["installation" ] for x in pkg_arch_matrix }
163
+ rel_entry_dict = {
164
+ x ["libtorch_config" ]: x ["installation" ]
165
+ for x in pkg_arch_matrix
166
+ }
142
167
if instr ["versions" ] is not None :
143
168
for ver in [RELEASE , DEBUG ]:
144
- instr ["versions" ][LIBTORCH_DWNL_INSTR [ver ]] = rel_entry_dict [ver ]
169
+ instr ["versions" ][LIBTORCH_DWNL_INSTR [ver ]] = (
170
+ rel_entry_dict [ver ]
171
+ )
145
172
elif os_key == OperatingSystem .MACOS .value :
146
173
if instr ["versions" ] is not None :
147
- instr ["versions" ][LIBTORCH_DWNL_INSTR [MACOS ]] = pkg_arch_matrix [0 ]["installation" ]
174
+ instr ["versions" ][LIBTORCH_DWNL_INSTR [MACOS ]] = (
175
+ pkg_arch_matrix [0 ]["installation" ]
176
+ )
177
+
148
178
149
179
# This method is used for generating new quick-start-module.js
150
180
# from the versions json object
@@ -158,21 +188,25 @@ def gen_install_matrix(versions) -> Dict[str, str]:
158
188
for os_key , os_vers in versions ["versions" ][ver_key ].items ():
159
189
for pkg_key , pkg_vers in os_vers .items ():
160
190
for acc_key , instr in pkg_vers .items ():
161
- extra_key = ' python' if pkg_key != ' libtorch' else ' cplusplus'
191
+ extra_key = " python" if pkg_key != " libtorch" else " cplusplus"
162
192
key = f"{ ver } ,{ pkg_key } ,{ os_key } ,{ acc_key } ,{ extra_key } "
163
193
note = instr ["note" ]
164
194
lines = [note ] if note is not None else []
165
195
if pkg_key == "libtorch" :
166
196
ivers = instr ["versions" ]
167
197
if ivers is not None :
168
- lines += [f"{ lab } <br /><a href='{ val } '>{ val } </a>" for (lab , val ) in ivers .items ()]
198
+ lines += [
199
+ f"{ lab } <br /><a href='{ val } '>{ val } </a>"
200
+ for (lab , val ) in ivers .items ()
201
+ ]
169
202
else :
170
203
command = instr ["command" ]
171
204
if command is not None :
172
205
lines .append (command )
173
206
result [key ] = "<br />" .join (lines )
174
207
return result
175
208
209
+
176
210
# This method is used for extracting two latest verisons of cuda and
177
211
# last verion of rocm. It will modify the acc_arch_ver_map object used
178
212
# to update getting started page.
@@ -195,7 +229,7 @@ def gen_ver_list(chan, gpu_arch_type):
195
229
196
230
def main ():
197
231
parser = argparse .ArgumentParser ()
198
- parser .add_argument (' --autogenerate' , dest = ' autogenerate' , action = ' store_true' )
232
+ parser .add_argument (" --autogenerate" , dest = " autogenerate" , action = " store_true" )
199
233
parser .set_defaults (autogenerate = True )
200
234
201
235
options = parser .parse_args ()
@@ -217,8 +251,11 @@ def main():
217
251
template = read_quick_start_module_template ()
218
252
versions_str = json .dumps (gen_install_matrix (versions ))
219
253
template = template .replace ("{{ installMatrix }}" , versions_str )
220
- template = template .replace ("{{ VERSION }}" , f"\" Stable ({ versions ['latest_stable' ]} )\" " )
254
+ template = template .replace (
255
+ "{{ VERSION }}" , f"\" Stable ({ versions ['latest_stable' ]} )\" "
256
+ )
221
257
print (template .replace ("{{ ACC ARCH MAP }}" , json .dumps (acc_arch_ver_map )))
222
258
259
+
223
260
if __name__ == "__main__" :
224
261
main ()
0 commit comments