Skip to content

Commit b809b69

Browse files
authored
Bug fix: Enable fast to override quantize json (#1377)
* Bug fix: Enable fast to override quantize json
* collapse conditional
1 parent 8dccc5a commit b809b69

File tree

1 file changed

+5
-5
lines changed

1 file changed

+5
-5
lines changed

torchchat/cli/cli.py

+5 -5
Original file line number | Diff line number | Diff line change
@@ -533,16 +533,16 @@ def arg_init(args):
533533
# Localized import to minimize expensive imports
534534
from torchchat.utils.build_utils import get_device_str
535535

536-
if args.device is None or args.device == "fast":
536+
if args.device is None:
537537
args.device = get_device_str(
538538
args.quantize.get("executor", {}).get("accelerator", default_device)
539539
)
540540
else:
541+
args.device = get_device_str(args.device)
541542
executor_handler = args.quantize.get("executor", None)
542-
if executor_handler:
543-
if executor_handler["accelerator"] != args.device:
544-
print('overriding json-specified device {executor_handler["accelerator"]} with cli device {args.device}')
545-
executor_handler["accelerator"] = args.device
543+
if executor_handler and executor_handler["accelerator"] != args.device:
544+
print(f'overriding json-specified device {executor_handler["accelerator"]} with cli device {args.device}')
545+
executor_handler["accelerator"] = args.device
546546

547547
if "mps" in args.device:
548548
if getattr(args, "compile", False) or getattr(args, "compile_prefill", False):

0 commit comments

Comments
 (0)