16 | 16 | from fnmatch import fnmatch
17 | 17 | from queue import Queue
18 | 18 | from packaging.version import Version
19 |    | -from huggingface_hub import snapshot_download
   | 19 | +from huggingface_hub import snapshot_download, list_repo_files
20 | 20 | import onnxruntime_genai as og
21 | 21 | import onnxruntime_genai.models.builder as model_builder
22 | 22 | from turnkeyml.state import State
@@ -245,7 +245,7 @@ class OgaLoad(FirstTool):
245 | 245 | Models on Hugging Face that follow the "amd/**-onnx-ryzen-strix" pattern
246 | 246 | Local models for cpu, igpu, or npu:
247 | 247 | The specified checkpoint is converted to a local path, via mapping to lower case
248 |     | - and replacing '/' with '_'. If this model already exists in the 'models' folderr
    | 248 | + and replacing '/' with '_'. If this model already exists in the 'models' folder
249 | 249 | of the lemonade cache and if it has a subfolder <device>-<dtype>, then this model
250 | 250 | will be used. If the --force flag is used and the model is built with model_builder,
251 | 251 | then it will be rebuilt.
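As context for the docstring above, here is a minimal sketch of the lookup it describes. The helper name, cache layout, and arguments are illustrative assumptions, not code from this repo:

```python
import os

def find_prebuilt_model(cache_dir, checkpoint, device, dtype):
    """Hypothetical helper: locate a previously built model in the lemonade cache."""
    # Map e.g. "amd/Model-Name" -> "amd_model-name" (lower case, '/' -> '_'),
    # then look for a <device>-<dtype> subfolder under the cache's 'models' folder.
    local_name = checkpoint.lower().replace("/", "_")
    candidate = os.path.join(cache_dir, "models", local_name, f"{device}-{dtype}")
    return candidate if os.path.isdir(candidate) else None
```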
@@ -398,8 +398,16 @@ def run(
398 | 398 | + "."
399 | 399 | )
400 | 400 |
    | 401 | + # Check whether the model is a safetensors checkpoint or a pre-exported
    | 402 | + # ONNX model
    | 403 | + # Note: This approach only supports ONNX models where the ONNX files are in the
    | 404 | + # Huggingface repo root. This does not support the case where the ONNX files
    | 405 | + # are in a nested directory within the repo.
    | 406 | + model_files = list_repo_files(repo_id=checkpoint)
    | 407 | + onnx_model = any([filename.endswith(".onnx") for filename in model_files])
    | 408 | +
401 | 409 | # Download the model from HF
402 |     | - if device == "npu" or device == "hybrid":
    | 410 | + if onnx_model:
403 | 411 |
404 | 412 | # NPU models on HF are ready to go and HF does its own caching
405 | 413 | full_model_path = snapshot_download(
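This hunk keys the download path off the repo contents rather than the device. A standalone sketch of that check, using only the huggingface_hub calls visible in the diff; the checkpoint id is a placeholder, and the safetensors fallback is covered after the next hunk:

```python
from huggingface_hub import list_repo_files, snapshot_download

checkpoint = "amd/example-model-onnx-ryzen-strix"  # placeholder repo id

# If any *.onnx file is listed in the repo, treat it as a pre-exported ONNX model.
# Per the note added in the hunk above, the loader still expects the ONNX files
# to live at the repo root.
model_files = list_repo_files(repo_id=checkpoint)
onnx_model = any(filename.endswith(".onnx") for filename in model_files)

if onnx_model:
    # Pre-exported ONNX repos are downloaded as-is; huggingface_hub handles
    # caching and returns the local snapshot path.
    full_model_path = snapshot_download(repo_id=checkpoint)
```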
@@ -474,7 +482,7 @@ def run(
474 | 482 | os.makedirs(os.path.dirname(dst_dll), exist_ok=True)
475 | 483 | shutil.copy2(src_dll, dst_dll)
476 | 484 | else:
477 |     | - # device is 'cpu' or 'igpu'
    | 485 | + # checkpoint is safetensors, so we need to run it through model_builder
478 | 486 |
479 | 487 | # Use model_builder to download model and convert to ONNX
480 | 488 | printing.log_info(f"Building {checkpoint} for {device} using {dtype}")
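For the safetensors branch, a hedged sketch of what the builder call could look like. The create_model signature follows the onnxruntime-genai model builder as recalled, and the placeholders and execution-provider mapping are assumptions, not the exact call in this file:

```python
import onnxruntime_genai.models.builder as model_builder

# Placeholder values standing in for the tool's state; not the real variables.
checkpoint = "microsoft/Phi-3-mini-4k-instruct"  # hypothetical safetensors repo
device, dtype = "cpu", "int4"
output_dir, cache_dir = "build/phi3-cpu-int4", "hf-cache"

# Assumed mapping from lemonade device names to builder execution providers.
execution_provider = {"cpu": "cpu", "igpu": "dml"}[device]

# Assumed entry point: create_model(model_name, input_path, output_dir, precision,
# execution_provider, cache_dir). An empty input_path means "download from
# Hugging Face" rather than converting a local checkpoint.
model_builder.create_model(checkpoint, "", output_dir, dtype, execution_provider, cache_dir)
```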