@@ -42,7 +42,9 @@ def from_pretrained(
4242 Recipe choices:
4343 - hf-cpu: Huggingface Transformers implementation for CPU with max-perf settings
4444 - hf-dgpu: Huggingface Transformers implementation on dGPU (via device="cuda")
45- - oga-dml: DirectML implementation for iGPU based on onnxruntime-genai
45+ - oga-cpu: CPU implementation based on onnxruntime-genai
46+ - oga-dml: DirectML implementation for iGPU based on onnxruntime-genai-directml
47+ - oga-hybird: AMD Ryzen AI Hybrid implementation based on onnxruntime-genai
4648
4749 Returns:
4850 - model: LLM instance with a generate() method that invokes the recipe
@@ -89,21 +91,28 @@ def from_pretrained(
8991
9092 # Make sure the user chose a supported runtime, e.g., oga-cpu
9193 user_backend = recipe .split ("oga-" )[1 ]
92- supported_backends = ["cpu" , "igpu" , "npu" , "hybrid" , "cuda" ]
94+ supported_backends = ["cpu" , "igpu" , "npu" , "hybrid" ]
9395 supported_recipes = [f"oga-{ backend } " for backend in supported_backends ]
9496 if recipe not in supported_recipes :
9597 raise NotSupported (
9698 "Selected OGA recipe is not supported. "
9799 f"The supported OGA recipes are: { supported_recipes } "
98100 )
99101
102+ backend_to_dtype = {
103+ "cpu" : "fp32" ,
104+ "igpu" : "fp16" ,
105+ "hybrid" : "int4" ,
106+ "npu" : "int4" ,
107+ }
108+
100109 state = _make_state (recipe , checkpoint )
101110
102111 state = oga .OgaLoad ().run (
103112 state ,
104113 input = checkpoint ,
105114 device = user_backend ,
106- dtype = "int4" ,
115+ dtype = backend_to_dtype [ user_backend ] ,
107116 )
108117
109118 return state .model , state .tokenizer
0 commit comments