Skip to content

Commit 8f2d13c

Browse files
authored
Fix setting fp16 dtype in AnimateDiff convert script. (huggingface#7127)
* update * update
1 parent fcfa270 commit 8f2d13c

File tree

1 file changed

+4
-1
lines changed

1 file changed

+4
-1
lines changed

scripts/convert_animatediff_motion_module_to_diffusers.py

+4-1
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,7 @@ def get_args():
3030
parser.add_argument("--output_path", type=str, required=True)
3131
parser.add_argument("--use_motion_mid_block", action="store_true")
3232
parser.add_argument("--motion_max_seq_length", type=int, default=32)
33+
parser.add_argument("--save_fp16", action="store_true")
3334

3435
return parser.parse_args()
3536

@@ -48,4 +49,6 @@ def get_args():
4849
# skip loading position embeddings
4950
adapter.load_state_dict(conv_state_dict, strict=False)
5051
adapter.save_pretrained(args.output_path)
51-
adapter.save_pretrained(args.output_path, variant="fp16", torch_dtype=torch.float16)
52+
53+
if args.save_fp16:
54+
adapter.to(torch.float16).save_pretrained(args.output_path, variant="fp16")

0 commit comments

Comments
 (0)