Commit c8e1eaa

Add cuda support when loading local onnx model

Signed-off-by: David Fan <[email protected]>
1 parent 63f11db commit c8e1eaa

2 files changed: +20 -4 lines changed

setup.py (+17 -2)
@@ -33,8 +33,6 @@
         "invoke>=2.0.0",
         "onnx>=1.11.0",
         "onnxmltools==1.10.0",
-        "onnxruntime >=1.10.1;platform_system=='Linux'",
-        "onnxruntime-directml>=1.19.0;platform_system=='Windows'",
         "torch>=1.12.1",
         "pyyaml>=5.4",
         "typeguard>=2.3.13",
@@ -49,6 +47,8 @@
         ],
         extras_require={
             "llm": [
+                "onnxruntime >=1.10.1;platform_system=='Linux'",
+                "onnxruntime-directml>=1.19.0;platform_system=='Windows'",
                 "tqdm",
                 "torch>=2.0.0",
                 "transformers",
@@ -60,6 +60,8 @@
                 "uvicorn[standard]",
             ],
             "llm-oga-dml": [
+                "onnxruntime >=1.10.1;platform_system=='Linux'",
+                "onnxruntime-directml>=1.19.0;platform_system=='Windows'",
                 "onnxruntime-genai-directml==0.4.0",
                 "tqdm",
                 "torch>=2.0.0,<2.4",
@@ -71,6 +73,19 @@
                 "fastapi",
                 "uvicorn[standard]",
             ],
+            "llm-oga-cuda": [
+                "onnxruntime-gpu>=1.19.0",
+                "onnxruntime-genai-cuda==0.4.0",
+                "tqdm",
+                "torch>=2.0.0,<2.4",
+                "transformers<4.45.0",
+                "accelerate",
+                "py-cpuinfo",
+                "sentencepiece",
+                "datasets",
+                "fastapi",
+                "uvicorn[standard]",
+            ],
             "llm-oga-npu": [
                 "transformers",
                 "torch",

src/turnkeyml/llm/tools/ort_genai/oga.py (+3 -2)
@@ -35,7 +35,7 @@
 oga_model_builder_cache_path = "model_builder"
 
 # Mapping from processor to execution provider, used in pathnames and by model_builder
-execution_providers = {"cpu": "cpu", "npu": "npu", "igpu": "dml"}
+execution_providers = {"cpu": "cpu", "npu": "npu", "igpu": "dml", "cuda": "cuda"}
 
 
 class OrtGenaiTokenizer(TokenizerAdapter):
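Per the comment above, the mapping's values are the execution-provider strings used in cache pathnames and by model_builder; the new entry simply routes the "cuda" device to the "cuda" provider. A hypothetical illustration of the lookup (the cache-path layout below is invented for this sketch):

oga_model_builder_cache_path = "model_builder"
execution_providers = {"cpu": "cpu", "npu": "npu", "igpu": "dml", "cuda": "cuda"}

device = "cuda"                    # the newly supported choice
ep = execution_providers[device]   # -> "cuda"
cache_subdir = f"{oga_model_builder_cache_path}/{ep}/int4"
print(cache_subdir)                # model_builder/cuda/int4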
@@ -248,7 +248,7 @@ def parser(add_help: bool = True) -> argparse.ArgumentParser:
     parser.add_argument(
         "-d",
         "--device",
-        choices=["igpu", "npu", "cpu"],
+        choices=["igpu", "npu", "cpu", "cuda"],
         default="igpu",
         help="Which device to load the model on to (default: igpu)",
     )
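The new device is plain argparse plumbing; a self-contained sketch of the flag exactly as defined above:

import argparse

parser = argparse.ArgumentParser()
parser.add_argument(
    "-d",
    "--device",
    choices=["igpu", "npu", "cpu", "cuda"],
    default="igpu",
    help="Which device to load the model on to (default: igpu)",
)

args = parser.parse_args(["--device", "cuda"])
assert args.device == "cuda"
# Anything outside the choices list, e.g. --device tpu, exits with an error.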
@@ -312,6 +312,7 @@ def run(
         "cpu": {"int4": "*/*", "fp32": "*/*"},
         "igpu": {"int4": "*/*", "fp16": "*/*"},
         "npu": {"int4": "amd/**-onnx-ryzen-strix"},
+        "cuda": {"int4": "*/*", "fp16": "*/*"},
     }
     hf_supported = (
         device in hf_supported_models
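The per-device entries read as glob-style patterns over Hugging Face "org/name" checkpoints, keyed by dtype; "cuda" accepts the same int4 and fp16 wildcards as "igpu". A hedged sketch of applying them with fnmatch (an assumption; the actual matching code is outside this hunk, and is_supported is a name invented here):

from fnmatch import fnmatch

hf_supported_models = {
    "cpu": {"int4": "*/*", "fp32": "*/*"},
    "igpu": {"int4": "*/*", "fp16": "*/*"},
    "npu": {"int4": "amd/**-onnx-ryzen-strix"},
    "cuda": {"int4": "*/*", "fp16": "*/*"},
}

def is_supported(device: str, dtype: str, checkpoint: str) -> bool:
    pattern = hf_supported_models.get(device, {}).get(dtype)
    return pattern is not None and fnmatch(checkpoint, pattern)

print(is_supported("cuda", "fp16", "microsoft/Phi-3-mini-4k-instruct"))  # True
print(is_supported("cuda", "fp32", "microsoft/Phi-3-mini-4k-instruct"))  # False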
