Skip to content

Commit b805839

Browse files
authored
Update OGA LEAP recipes (#289)
1 parent d43e50f commit b805839

File tree

2 files changed

+13
-4
lines changed

2 files changed

+13
-4
lines changed

src/lemonade/leap.py

Lines changed: 12 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -42,7 +42,9 @@ def from_pretrained(
4242
Recipe choices:
4343
- hf-cpu: Huggingface Transformers implementation for CPU with max-perf settings
4444
- hf-dgpu: Huggingface Transformers implementation on dGPU (via device="cuda")
45-
- oga-dml: DirectML implementation for iGPU based on onnxruntime-genai
45+
- oga-cpu: CPU implementation based on onnxruntime-genai
46+
- oga-dml: DirectML implementation for iGPU based on onnxruntime-genai-directml
47+
- oga-hybird: AMD Ryzen AI Hybrid implementation based on onnxruntime-genai
4648
4749
Returns:
4850
- model: LLM instance with a generate() method that invokes the recipe
@@ -89,21 +91,28 @@ def from_pretrained(
8991

9092
# Make sure the user chose a supported runtime, e.g., oga-cpu
9193
user_backend = recipe.split("oga-")[1]
92-
supported_backends = ["cpu", "igpu", "npu", "hybrid", "cuda"]
94+
supported_backends = ["cpu", "igpu", "npu", "hybrid"]
9395
supported_recipes = [f"oga-{backend}" for backend in supported_backends]
9496
if recipe not in supported_recipes:
9597
raise NotSupported(
9698
"Selected OGA recipe is not supported. "
9799
f"The supported OGA recipes are: {supported_recipes}"
98100
)
99101

102+
backend_to_dtype = {
103+
"cpu": "fp32",
104+
"igpu": "fp16",
105+
"hybrid": "int4",
106+
"npu": "int4",
107+
}
108+
100109
state = _make_state(recipe, checkpoint)
101110

102111
state = oga.OgaLoad().run(
103112
state,
104113
input=checkpoint,
105114
device=user_backend,
106-
dtype="int4",
115+
dtype=backend_to_dtype[user_backend],
107116
)
108117

109118
return state.model, state.tokenizer

src/turnkeyml/version.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1 +1 @@
1-
__version__ = "5.1.0"
1+
__version__ = "5.1.1"

0 commit comments

Comments
 (0)