diff --git a/animatediff/context.py b/animatediff/context.py index aa369b3..b3a61fd 100644 --- a/animatediff/context.py +++ b/animatediff/context.py @@ -616,6 +616,8 @@ def generate_context_visualization(context_opts: ContextOptionsGroup, model: Mod if start_step is None: start_step = 0 # use this in case start_step is provided, to display accurate step + if steps is None: + steps = len(sigmas) for i, t in enumerate(sigmas): # make context_opts reflect current step/sigma @@ -674,5 +676,5 @@ def generate_context_visualization(context_opts: ContextOptionsGroup, model: Mod repeat_count += 1 images = torch.stack(all_imgs) - images = images.movedim(1, -1) + images = images.movedim(1, -1).to(torch.float32) return images