diff --git a/examples/qwen/convert_checkpoint.py b/examples/qwen/convert_checkpoint.py index 4734e836e..954646519 100644 --- a/examples/qwen/convert_checkpoint.py +++ b/examples/qwen/convert_checkpoint.py @@ -201,7 +201,8 @@ def args_to_build_options(args): 'embedding_sharding_dim': args.embedding_sharding_dim, 'share_embedding_table': args.use_embedding_sharing, 'disable_weight_only_quant_plugin': - args.disable_weight_only_quant_plugin + args.disable_weight_only_quant_plugin, + 'load_model_on_cpu': args.load_model_on_cpu, } @@ -232,6 +233,7 @@ def convert_and_save_hf(args): dtype=args.dtype, mapping=mapping, quant_config=quant_config, + device='cpu' if args.load_model_on_cpu else 'cuda', calib_dataset=args.calib_dataset, **override_fields) else: