@@ -42,7 +42,9 @@ def from_pretrained(
42
42
Recipe choices:
43
43
- hf-cpu: Huggingface Transformers implementation for CPU with max-perf settings
44
44
- hf-dgpu: Huggingface Transformers implementation on dGPU (via device="cuda")
45
- - oga-dml: DirectML implementation for iGPU based on onnxruntime-genai
45
+ - oga-cpu: CPU implementation based on onnxruntime-genai
46
+ - oga-dml: DirectML implementation for iGPU based on onnxruntime-genai-directml
47
+ - oga-hybird: AMD Ryzen AI Hybrid implementation based on onnxruntime-genai
46
48
47
49
Returns:
48
50
- model: LLM instance with a generate() method that invokes the recipe
@@ -89,21 +91,28 @@ def from_pretrained(
89
91
90
92
# Make sure the user chose a supported runtime, e.g., oga-cpu
91
93
user_backend = recipe .split ("oga-" )[1 ]
92
- supported_backends = ["cpu" , "igpu" , "npu" , "hybrid" , "cuda" ]
94
+ supported_backends = ["cpu" , "igpu" , "npu" , "hybrid" ]
93
95
supported_recipes = [f"oga-{ backend } " for backend in supported_backends ]
94
96
if recipe not in supported_recipes :
95
97
raise NotSupported (
96
98
"Selected OGA recipe is not supported. "
97
99
f"The supported OGA recipes are: { supported_recipes } "
98
100
)
99
101
102
+ backend_to_dtype = {
103
+ "cpu" : "fp32" ,
104
+ "igpu" : "fp16" ,
105
+ "hybrid" : "int4" ,
106
+ "npu" : "int4" ,
107
+ }
108
+
100
109
state = _make_state (recipe , checkpoint )
101
110
102
111
state = oga .OgaLoad ().run (
103
112
state ,
104
113
input = checkpoint ,
105
114
device = user_backend ,
106
- dtype = "int4" ,
115
+ dtype = backend_to_dtype [ user_backend ] ,
107
116
)
108
117
109
118
return state .model , state .tokenizer
0 commit comments