From 67c729d448dad3fa3050e45e2888de298e8db7fd Mon Sep 17 00:00:00 2001 From: Aryan Date: Tue, 1 Oct 2024 03:31:06 +0200 Subject: [PATCH 01/61] start pyramid attention broadcast --- src/diffusers/models/attention_processor.py | 2 +- .../pipelines/cogvideo/pipeline_cogvideox.py | 6 +- .../pipelines/pyramid_broadcast_utils.py | 111 ++++++++++++++++++ 3 files changed, 117 insertions(+), 2 deletions(-) create mode 100644 src/diffusers/pipelines/pyramid_broadcast_utils.py diff --git a/src/diffusers/models/attention_processor.py b/src/diffusers/models/attention_processor.py index 9f9bc5a46e10..a207770f2f30 100644 --- a/src/diffusers/models/attention_processor.py +++ b/src/diffusers/models/attention_processor.py @@ -477,7 +477,7 @@ def forward( # For standard processors that are defined here, `**cross_attention_kwargs` is empty attn_parameters = set(inspect.signature(self.processor.__call__).parameters.keys()) - quiet_attn_parameters = {"ip_adapter_masks"} + quiet_attn_parameters = {"ip_adapter_masks", "image_rotary_emb"} unused_kwargs = [ k for k, _ in cross_attention_kwargs.items() if k not in attn_parameters and k not in quiet_attn_parameters ] diff --git a/src/diffusers/pipelines/cogvideo/pipeline_cogvideox.py b/src/diffusers/pipelines/cogvideo/pipeline_cogvideox.py index 82839ffd2c92..fe35c85dcc2f 100644 --- a/src/diffusers/pipelines/cogvideo/pipeline_cogvideox.py +++ b/src/diffusers/pipelines/cogvideo/pipeline_cogvideox.py @@ -29,6 +29,7 @@ from ...utils import logging, replace_example_docstring from ...utils.torch_utils import randn_tensor from ...video_processor import VideoProcessor +from ..pyramid_broadcast_utils import PyramidAttentionBroadcastMixin from .pipeline_output import CogVideoXPipelineOutput @@ -137,7 +138,7 @@ def retrieve_timesteps( return timesteps, num_inference_steps -class CogVideoXPipeline(DiffusionPipeline, CogVideoXLoraLoaderMixin): +class CogVideoXPipeline(DiffusionPipeline, CogVideoXLoraLoaderMixin, PyramidAttentionBroadcastMixin): r""" Pipeline for text-to-video generation using CogVideoX. @@ -605,6 +606,7 @@ def __call__( ) self._guidance_scale = guidance_scale self._attention_kwargs = attention_kwargs + self._current_timestep = None self._interrupt = False # 2. Default call parameters @@ -674,6 +676,7 @@ def __call__( if self.interrupt: continue + self._current_timestep = t latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents latent_model_input = self.scheduler.scale_model_input(latent_model_input, t) @@ -729,6 +732,7 @@ def __call__( if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0): progress_bar.update() + self._current_timestep = None if not output_type == "latent": video = self.decode_latents(latents) video = self.video_processor.postprocess_video(video=video, output_type=output_type) diff --git a/src/diffusers/pipelines/pyramid_broadcast_utils.py b/src/diffusers/pipelines/pyramid_broadcast_utils.py new file mode 100644 index 000000000000..23d3a6b25ca2 --- /dev/null +++ b/src/diffusers/pipelines/pyramid_broadcast_utils.py @@ -0,0 +1,111 @@ +# Copyright 2024 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from typing import Tuple + +import torch.nn as nn + +from ..models.attention_processor import Attention, AttentionProcessor +from ..utils import logging + + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + + +class PyramidAttentionBroadcastAttentionProcessor: + def __init__(self, pipeline, processor: AttentionProcessor) -> None: + self.pipeline = pipeline + self._original_processor = processor + self._prev_hidden_states = None + self._iteration = 0 + + def __call__(self, *args, **kwargs): + if ( + hasattr(self.pipeline, "_current_timestep") + and self.pipeline._current_timestep is not None + and self._iteration % self.pipeline._pab_skip_range != 0 + and ( + self.pipeline._pab_timestep_range[0] + < self.pipeline._current_timestep + < self.pipeline._pab_timestep_range[1] + ) + ): + # print("Using cached states:", self.pipeline._current_timestep) + hidden_states = self._prev_hidden_states + else: + hidden_states = self._original_processor(*args, **kwargs) + self._prev_hidden_states = hidden_states + + self._iteration = (self._iteration + 1) % self.pipeline.num_timesteps + + return hidden_states + + +class PyramidAttentionBroadcastMixin: + r"""Mixin class for [Pyramid Attention Broadcast](https://www.arxiv.org/abs/2408.12588).""" + + def _enable_pyramid_attention_broadcast(self) -> None: + # def is_fake_integral_match(layer_id, name): + # layer_id = layer_id.split(".")[-1] + # name = name.split(".")[-1] + # return layer_id.isnumeric() and name.isnumeric() and layer_id == name + + denoiser: nn.Module = self.transformer if hasattr(self, "transformer") else self.unet + + for name, module in denoiser.named_modules(): + if isinstance(module, Attention): + module.processor = PyramidAttentionBroadcastAttentionProcessor(self, module.processor) + + # target_modules = {} + + # for layer_id in self._pab_skip_range: + # for name, module in denoiser.named_modules(): + # if ( + # isinstance(module, Attention) + # and re.search(layer_id, name) is not None + # and not is_fake_integral_match(layer_id, name) + # ): + # target_modules[name] = module + + # for name, module in target_modules.items(): + # # TODO: make this debug + # logger.info(f"Enabling Pyramid Attention Broadcast in layer: {name}") + # module.processor = PyramidAttentionBroadcastAttentionProcessor(self, module.processor) + + def _disable_pyramid_attention_broadcast(self) -> None: + denoiser: nn.Module = self.transformer if hasattr(self, "transformer") else self.unet + for name, module in denoiser.named_modules(): + if isinstance(module, Attention) and isinstance( + module.processor, PyramidAttentionBroadcastAttentionProcessor + ): + # TODO: make this debug + logger.info(f"Disabling Pyramid Attention Broadcast in layer: {name}") + module.processor = module.processor._original_processor + + def enable_pyramid_attention_broadcast(self, skip_range: int, timestep_range: Tuple[int, int]) -> None: + if isinstance(skip_range, str): + skip_range = [skip_range] + + self._pab_skip_range = skip_range + self._pab_timestep_range = timestep_range + + self._enable_pyramid_attention_broadcast() + + def disable_pyramid_attention_broadcast(self) -> None: + self._pab_timestep_range = None + self._pab_skip_range = None + + @property + def pyramid_attention_broadcast_enabled(self): + return hasattr(self, "_pab_skip_range") and self._pab_skip_range is not None From 6d3bdb55110c50cc81ed716153e8cc47b579bfb4 Mon Sep 17 00:00:00 2001 From: Aryan Date: Thu, 3 Oct 2024 08:34:57 +0200 Subject: [PATCH 02/61] add coauthor Co-Authored-By: Xuanlei Zhao <43881818+oahzxl@users.noreply.github.com> From 373710167a3fb92102a247122195bf6a0463d269 Mon Sep 17 00:00:00 2001 From: Aryan Date: Thu, 3 Oct 2024 09:06:02 +0200 Subject: [PATCH 03/61] update --- .../pipelines/cogvideo/pipeline_cogvideox.py | 1 + .../pipeline_cogvideox_image2video.py | 7 +- .../pipeline_cogvideox_video2video.py | 7 +- .../pipelines/latte/pipeline_latte.py | 13 ++- .../pipelines/pyramid_broadcast_utils.py | 92 +++++++++++-------- 5 files changed, 76 insertions(+), 44 deletions(-) diff --git a/src/diffusers/pipelines/cogvideo/pipeline_cogvideox.py b/src/diffusers/pipelines/cogvideo/pipeline_cogvideox.py index fe35c85dcc2f..57d03fe5c78b 100644 --- a/src/diffusers/pipelines/cogvideo/pipeline_cogvideox.py +++ b/src/diffusers/pipelines/cogvideo/pipeline_cogvideox.py @@ -733,6 +733,7 @@ def __call__( progress_bar.update() self._current_timestep = None + if not output_type == "latent": video = self.decode_latents(latents) video = self.video_processor.postprocess_video(video=video, output_type=output_type) diff --git a/src/diffusers/pipelines/cogvideo/pipeline_cogvideox_image2video.py b/src/diffusers/pipelines/cogvideo/pipeline_cogvideox_image2video.py index afc11bce00d5..61198882f51b 100644 --- a/src/diffusers/pipelines/cogvideo/pipeline_cogvideox_image2video.py +++ b/src/diffusers/pipelines/cogvideo/pipeline_cogvideox_image2video.py @@ -33,6 +33,7 @@ ) from ...utils.torch_utils import randn_tensor from ...video_processor import VideoProcessor +from ..pyramid_broadcast_utils import PyramidAttentionBroadcastMixin from .pipeline_output import CogVideoXPipelineOutput @@ -152,7 +153,7 @@ def retrieve_latents( raise AttributeError("Could not access latents of provided encoder_output") -class CogVideoXImageToVideoPipeline(DiffusionPipeline): +class CogVideoXImageToVideoPipeline(DiffusionPipeline, PyramidAttentionBroadcastMixin): r""" Pipeline for image-to-video generation using CogVideoX. @@ -679,6 +680,7 @@ def __call__( negative_prompt_embeds=negative_prompt_embeds, ) self._guidance_scale = guidance_scale + self._current_timestep = None self._interrupt = False # 2. Default call parameters @@ -753,6 +755,7 @@ def __call__( if self.interrupt: continue + self._current_timestep = t latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents latent_model_input = self.scheduler.scale_model_input(latent_model_input, t) @@ -810,6 +813,8 @@ def __call__( if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0): progress_bar.update() + self._current_timestep = None + if not output_type == "latent": video = self.decode_latents(latents) video = self.video_processor.postprocess_video(video=video, output_type=output_type) diff --git a/src/diffusers/pipelines/cogvideo/pipeline_cogvideox_video2video.py b/src/diffusers/pipelines/cogvideo/pipeline_cogvideox_video2video.py index 35f8f2fa0508..149fac200b69 100644 --- a/src/diffusers/pipelines/cogvideo/pipeline_cogvideox_video2video.py +++ b/src/diffusers/pipelines/cogvideo/pipeline_cogvideox_video2video.py @@ -30,6 +30,7 @@ from ...utils import logging, replace_example_docstring from ...utils.torch_utils import randn_tensor from ...video_processor import VideoProcessor +from ..pyramid_broadcast_utils import PyramidAttentionBroadcastMixin from .pipeline_output import CogVideoXPipelineOutput @@ -159,7 +160,7 @@ def retrieve_latents( raise AttributeError("Could not access latents of provided encoder_output") -class CogVideoXVideoToVideoPipeline(DiffusionPipeline, CogVideoXLoraLoaderMixin): +class CogVideoXVideoToVideoPipeline(DiffusionPipeline, CogVideoXLoraLoaderMixin, PyramidAttentionBroadcastMixin): r""" Pipeline for video-to-video generation using CogVideoX. @@ -679,6 +680,7 @@ def __call__( ) self._guidance_scale = guidance_scale self._attention_kwargs = attention_kwargs + self._current_timestep = None self._interrupt = False # 2. Default call parameters @@ -755,6 +757,7 @@ def __call__( if self.interrupt: continue + self._current_timestep = t latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents latent_model_input = self.scheduler.scale_model_input(latent_model_input, t) @@ -810,6 +813,8 @@ def __call__( if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0): progress_bar.update() + self._current_timestep = None + if not output_type == "latent": video = self.decode_latents(latents) video = self.video_processor.postprocess_video(video=video, output_type=output_type) diff --git a/src/diffusers/pipelines/latte/pipeline_latte.py b/src/diffusers/pipelines/latte/pipeline_latte.py index d4bedf2e0e2a..5dff04d6f5c1 100644 --- a/src/diffusers/pipelines/latte/pipeline_latte.py +++ b/src/diffusers/pipelines/latte/pipeline_latte.py @@ -37,6 +37,7 @@ ) from ...utils.torch_utils import is_compiled_module, randn_tensor from ...video_processor import VideoProcessor +from ..pyramid_broadcast_utils import PyramidAttentionBroadcastMixin logger = logging.get_logger(__name__) # pylint: disable=invalid-name @@ -132,7 +133,7 @@ class LattePipelineOutput(BaseOutput): frames: torch.Tensor -class LattePipeline(DiffusionPipeline): +class LattePipeline(DiffusionPipeline, PyramidAttentionBroadcastMixin): r""" Pipeline for text-to-video generation using Latte. @@ -623,7 +624,7 @@ def __call__( clean_caption: bool = True, mask_feature: bool = True, enable_temporal_attentions: bool = True, - decode_chunk_size: Optional[int] = None, + decode_chunk_size: int = 14, ) -> Union[LattePipelineOutput, Tuple]: """ Function invoked when calling the pipeline for generation. @@ -719,6 +720,7 @@ def __call__( negative_prompt_embeds, ) self._guidance_scale = guidance_scale + self._current_timestep = None self._interrupt = False # 2. Default height and width to transformer @@ -780,6 +782,7 @@ def __call__( if self.interrupt: continue + self._current_timestep = t latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents latent_model_input = self.scheduler.scale_model_input(latent_model_input, t) @@ -836,8 +839,10 @@ def __call__( if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0): progress_bar.update() - if not output_type == "latents": - video = self.decode_latents(latents, video_length, decode_chunk_size=14) + self._current_timestep = None + + if not output_type == "latent": + video = self.decode_latents(latents, video_length, decode_chunk_size=decode_chunk_size) video = self.video_processor.postprocess_video(video=video, output_type=output_type) else: video = latents diff --git a/src/diffusers/pipelines/pyramid_broadcast_utils.py b/src/diffusers/pipelines/pyramid_broadcast_utils.py index 23d3a6b25ca2..e03c4b9dc7fa 100644 --- a/src/diffusers/pipelines/pyramid_broadcast_utils.py +++ b/src/diffusers/pipelines/pyramid_broadcast_utils.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -from typing import Tuple +from typing import Optional, Tuple import torch.nn as nn @@ -24,9 +24,14 @@ class PyramidAttentionBroadcastAttentionProcessor: - def __init__(self, pipeline, processor: AttentionProcessor) -> None: + def __init__( + self, pipeline, processor: AttentionProcessor, skip_range: int, timestep_range: Tuple[int, int] + ) -> None: self.pipeline = pipeline self._original_processor = processor + self._skip_range = skip_range + self._timestep_range = timestep_range + self._prev_hidden_states = None self._iteration = 0 @@ -34,14 +39,9 @@ def __call__(self, *args, **kwargs): if ( hasattr(self.pipeline, "_current_timestep") and self.pipeline._current_timestep is not None - and self._iteration % self.pipeline._pab_skip_range != 0 - and ( - self.pipeline._pab_timestep_range[0] - < self.pipeline._current_timestep - < self.pipeline._pab_timestep_range[1] - ) + and self._iteration % self._skip_range != 0 + and (self._timestep_range[0] < self.pipeline._current_timestep < self._timestep_range[1]) ): - # print("Using cached states:", self.pipeline._current_timestep) hidden_states = self._prev_hidden_states else: hidden_states = self._original_processor(*args, **kwargs) @@ -56,32 +56,26 @@ class PyramidAttentionBroadcastMixin: r"""Mixin class for [Pyramid Attention Broadcast](https://www.arxiv.org/abs/2408.12588).""" def _enable_pyramid_attention_broadcast(self) -> None: - # def is_fake_integral_match(layer_id, name): - # layer_id = layer_id.split(".")[-1] - # name = name.split(".")[-1] - # return layer_id.isnumeric() and name.isnumeric() and layer_id == name - denoiser: nn.Module = self.transformer if hasattr(self, "transformer") else self.unet for name, module in denoiser.named_modules(): if isinstance(module, Attention): - module.processor = PyramidAttentionBroadcastAttentionProcessor(self, module.processor) + logger.debug(f"Enabling Pyramid Attention Broadcast in layer: {name}") - # target_modules = {} + skip_range, timestep_range = None, None + if module.is_cross_attention and self._pab_cross_attn_skip_range is not None: + skip_range = self._pab_cross_attn_skip_range + timestep_range = self._pab_cross_attn_timestep_range + if not module.is_cross_attention and self._pab_spatial_attn_skip_range is not None: + skip_range = self._pab_spatial_attn_skip_range + timestep_range = self._pab_spatial_attn_timestep_range - # for layer_id in self._pab_skip_range: - # for name, module in denoiser.named_modules(): - # if ( - # isinstance(module, Attention) - # and re.search(layer_id, name) is not None - # and not is_fake_integral_match(layer_id, name) - # ): - # target_modules[name] = module + if skip_range is None: + continue - # for name, module in target_modules.items(): - # # TODO: make this debug - # logger.info(f"Enabling Pyramid Attention Broadcast in layer: {name}") - # module.processor = PyramidAttentionBroadcastAttentionProcessor(self, module.processor) + module.processor = PyramidAttentionBroadcastAttentionProcessor( + self, module.processor, skip_range, timestep_range + ) def _disable_pyramid_attention_broadcast(self) -> None: denoiser: nn.Module = self.transformer if hasattr(self, "transformer") else self.unet @@ -89,23 +83,45 @@ def _disable_pyramid_attention_broadcast(self) -> None: if isinstance(module, Attention) and isinstance( module.processor, PyramidAttentionBroadcastAttentionProcessor ): - # TODO: make this debug - logger.info(f"Disabling Pyramid Attention Broadcast in layer: {name}") + logger.debug(f"Disabling Pyramid Attention Broadcast in layer: {name}") module.processor = module.processor._original_processor - def enable_pyramid_attention_broadcast(self, skip_range: int, timestep_range: Tuple[int, int]) -> None: - if isinstance(skip_range, str): - skip_range = [skip_range] + def enable_pyramid_attention_broadcast( + self, + spatial_attn_skip_range: Optional[int] = None, + cross_attn_skip_range: Optional[int] = None, + spatial_attn_timestep_range: Optional[Tuple[int, int]] = None, + cross_attn_timestep_range: Optional[Tuple[int, int]] = None, + ) -> None: + if spatial_attn_timestep_range is None: + spatial_attn_timestep_range = (100, 800) + if cross_attn_skip_range is None: + cross_attn_timestep_range = (100, 800) + + if spatial_attn_timestep_range[0] > spatial_attn_timestep_range[1]: + raise ValueError( + "Expected `spatial_attn_timestep_range` to be a tuple of two integers, with first value lesser or equal than second. These correspond to the min and max timestep between which PAB will be applied." + ) + if cross_attn_timestep_range[0] > cross_attn_timestep_range[1]: + raise ValueError( + "Expected `cross_attn_timestep_range` to be a tuple of two integers, with first value lesser or equal than second. These correspond to the min and max timestep between which PAB will be applied." + ) - self._pab_skip_range = skip_range - self._pab_timestep_range = timestep_range + self._pab_spatial_attn_skip_range = spatial_attn_skip_range + self._pab_cross_attn_skip_range = cross_attn_skip_range + self._pab_spatial_attn_timestep_range = spatial_attn_timestep_range + self._pab_cross_attn_timestep_range = cross_attn_timestep_range + self._pab_enabled = spatial_attn_skip_range or cross_attn_skip_range self._enable_pyramid_attention_broadcast() def disable_pyramid_attention_broadcast(self) -> None: - self._pab_timestep_range = None - self._pab_skip_range = None + self._pab_spatial_attn_skip_range = None + self._pab_cross_attn_skip_range = None + self._pab_spatial_attn_timestep_range = None + self._pab_cross_attn_timestep_range = None + self._pab_enabled = False @property def pyramid_attention_broadcast_enabled(self): - return hasattr(self, "_pab_skip_range") and self._pab_skip_range is not None + return hasattr(self, "_pab_enabled") and self._pab_enabled From d5c738defee03c3302eded808f3b313f0ff0c1b1 Mon Sep 17 00:00:00 2001 From: Aryan Date: Thu, 3 Oct 2024 09:06:19 +0200 Subject: [PATCH 04/61] make style --- src/diffusers/pipelines/latte/pipeline_latte.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/diffusers/pipelines/latte/pipeline_latte.py b/src/diffusers/pipelines/latte/pipeline_latte.py index 5dff04d6f5c1..e08b42c69898 100644 --- a/src/diffusers/pipelines/latte/pipeline_latte.py +++ b/src/diffusers/pipelines/latte/pipeline_latte.py @@ -840,7 +840,7 @@ def __call__( progress_bar.update() self._current_timestep = None - + if not output_type == "latent": video = self.decode_latents(latents, video_length, decode_chunk_size=decode_chunk_size) video = self.video_processor.postprocess_video(video=video, output_type=output_type) From ae4abb14a3e2bb43e9d9cb00e18f75e4f3db06de Mon Sep 17 00:00:00 2001 From: Aryan Date: Thu, 3 Oct 2024 21:21:35 +0200 Subject: [PATCH 05/61] update --- src/diffusers/models/attention_processor.py | 2 +- .../pipelines/pyramid_broadcast_utils.py | 32 +++++++++++++++---- 2 files changed, 27 insertions(+), 7 deletions(-) diff --git a/src/diffusers/models/attention_processor.py b/src/diffusers/models/attention_processor.py index a207770f2f30..9f9bc5a46e10 100644 --- a/src/diffusers/models/attention_processor.py +++ b/src/diffusers/models/attention_processor.py @@ -477,7 +477,7 @@ def forward( # For standard processors that are defined here, `**cross_attention_kwargs` is empty attn_parameters = set(inspect.signature(self.processor.__call__).parameters.keys()) - quiet_attn_parameters = {"ip_adapter_masks", "image_rotary_emb"} + quiet_attn_parameters = {"ip_adapter_masks"} unused_kwargs = [ k for k, _ in cross_attention_kwargs.items() if k not in attn_parameters and k not in quiet_attn_parameters ] diff --git a/src/diffusers/pipelines/pyramid_broadcast_utils.py b/src/diffusers/pipelines/pyramid_broadcast_utils.py index e03c4b9dc7fa..2d2f09d28490 100644 --- a/src/diffusers/pipelines/pyramid_broadcast_utils.py +++ b/src/diffusers/pipelines/pyramid_broadcast_utils.py @@ -12,8 +12,10 @@ # See the License for the specific language governing permissions and # limitations under the License. +import inspect from typing import Optional, Tuple +import torch import torch.nn as nn from ..models.attention_processor import Attention, AttentionProcessor @@ -23,7 +25,7 @@ logger = logging.get_logger(__name__) # pylint: disable=invalid-name -class PyramidAttentionBroadcastAttentionProcessor: +class PyramidAttentionBroadcastAttentionProcessorWrapper: def __init__( self, pipeline, processor: AttentionProcessor, skip_range: int, timestep_range: Tuple[int, int] ) -> None: @@ -35,7 +37,21 @@ def __init__( self._prev_hidden_states = None self._iteration = 0 - def __call__(self, *args, **kwargs): + _original_processor_params = set(inspect.signature(self._original_processor).parameters.keys()) + _supported_parameters = {"attn", "hidden_states", "encoder_hidden_states", "attention_mask", "temb", "image_rotary_emb"} + self._attn_processor_params = _supported_parameters.intersection(_original_processor_params) + + def __call__( + self, + attn: Attention, + hidden_states: Optional[torch.Tensor] = None, + encoder_hidden_states: Optional[torch.Tensor] = None, + attention_mask: Optional[torch.Tensor] = None, + temb: Optional[torch.Tensor] = None, + image_rotary_emb: Optional[torch.Tensor] = None, + *args, + **kwargs, + ): if ( hasattr(self.pipeline, "_current_timestep") and self.pipeline._current_timestep is not None @@ -44,7 +60,11 @@ def __call__(self, *args, **kwargs): ): hidden_states = self._prev_hidden_states else: - hidden_states = self._original_processor(*args, **kwargs) + call_kwargs = {} + for param in self._attn_processor_params: + call_kwargs.update({param: locals()[param]}) + call_kwargs.update(kwargs) + hidden_states = self._original_processor(*args, **call_kwargs) self._prev_hidden_states = hidden_states self._iteration = (self._iteration + 1) % self.pipeline.num_timesteps @@ -73,15 +93,15 @@ def _enable_pyramid_attention_broadcast(self) -> None: if skip_range is None: continue - module.processor = PyramidAttentionBroadcastAttentionProcessor( + module.set_processor(PyramidAttentionBroadcastAttentionProcessorWrapper( self, module.processor, skip_range, timestep_range - ) + )) def _disable_pyramid_attention_broadcast(self) -> None: denoiser: nn.Module = self.transformer if hasattr(self, "transformer") else self.unet for name, module in denoiser.named_modules(): if isinstance(module, Attention) and isinstance( - module.processor, PyramidAttentionBroadcastAttentionProcessor + module.processor, PyramidAttentionBroadcastAttentionProcessorWrapper ): logger.debug(f"Disabling Pyramid Attention Broadcast in layer: {name}") module.processor = module.processor._original_processor From 9f6987fdb0c21c30cd80feace15e5178a9ce5137 Mon Sep 17 00:00:00 2001 From: Aryan Date: Thu, 3 Oct 2024 21:22:26 +0200 Subject: [PATCH 06/61] make style --- .../pipelines/pyramid_broadcast_utils.py | 17 +++++++++++++---- 1 file changed, 13 insertions(+), 4 deletions(-) diff --git a/src/diffusers/pipelines/pyramid_broadcast_utils.py b/src/diffusers/pipelines/pyramid_broadcast_utils.py index 2d2f09d28490..78ca9e869c3a 100644 --- a/src/diffusers/pipelines/pyramid_broadcast_utils.py +++ b/src/diffusers/pipelines/pyramid_broadcast_utils.py @@ -38,7 +38,14 @@ def __init__( self._iteration = 0 _original_processor_params = set(inspect.signature(self._original_processor).parameters.keys()) - _supported_parameters = {"attn", "hidden_states", "encoder_hidden_states", "attention_mask", "temb", "image_rotary_emb"} + _supported_parameters = { + "attn", + "hidden_states", + "encoder_hidden_states", + "attention_mask", + "temb", + "image_rotary_emb", + } self._attn_processor_params = _supported_parameters.intersection(_original_processor_params) def __call__( @@ -93,9 +100,11 @@ def _enable_pyramid_attention_broadcast(self) -> None: if skip_range is None: continue - module.set_processor(PyramidAttentionBroadcastAttentionProcessorWrapper( - self, module.processor, skip_range, timestep_range - )) + module.set_processor( + PyramidAttentionBroadcastAttentionProcessorWrapper( + self, module.processor, skip_range, timestep_range + ) + ) def _disable_pyramid_attention_broadcast(self) -> None: denoiser: nn.Module = self.transformer if hasattr(self, "transformer") else self.unet From b3547c647ecae96024c555c43b220a6614d076c6 Mon Sep 17 00:00:00 2001 From: Aryan Date: Fri, 4 Oct 2024 08:16:24 +0200 Subject: [PATCH 07/61] add docs --- docs/source/en/api/pipelines/cogvideox.md | 41 +++++++++++- docs/source/en/api/pipelines/latte.md | 33 +++++++++- .../pipelines/pyramid_broadcast_utils.py | 62 +++++++++++++++++++ 3 files changed, 134 insertions(+), 2 deletions(-) diff --git a/docs/source/en/api/pipelines/cogvideox.md b/docs/source/en/api/pipelines/cogvideox.md index 4cde7a111ae6..5f1e91d52bf1 100644 --- a/docs/source/en/api/pipelines/cogvideox.md +++ b/docs/source/en/api/pipelines/cogvideox.md @@ -15,7 +15,7 @@ # CogVideoX -[CogVideoX: Text-to-Video Diffusion Models with An Expert Transformer](https://arxiv.org/abs/2408.06072) from Tsinghua University & ZhipuAI, by Zhuoyi Yang, Jiayan Teng, Wendi Zheng, Ming Ding, Shiyu Huang, Jiazheng Xu, Yuanming Yang, Wenyi Hong, Xiaohan Zhang, Guanyu Feng, Da Yin, Xiaotao Gu, Yuxuan Zhang, Weihan Wang, Yean Cheng, Ting Liu, Bin Xu, Yuxiao Dong, Jie Tang. +[CogVideoX: Text-to-Video Diffusion Models with An Expert Transformer](https://huggingface.co/papers/2408.06072) from Tsinghua University & ZhipuAI, by Zhuoyi Yang, Jiayan Teng, Wendi Zheng, Ming Ding, Shiyu Huang, Jiazheng Xu, Yuanming Yang, Wenyi Hong, Xiaohan Zhang, Guanyu Feng, Da Yin, Xiaotao Gu, Yuxuan Zhang, Weihan Wang, Yean Cheng, Ting Liu, Bin Xu, Yuxiao Dong, Jie Tang. The abstract from the paper is: @@ -100,6 +100,45 @@ It is also worth noting that torchao quantization is fully compatible with [torc - [torchao](https://gist.github.com/a-r-r-o-w/4d9732d17412888c885480c6521a9897) - [quanto](https://gist.github.com/a-r-r-o-w/31be62828b00a9292821b85c1017effa) +### Pyramid Attention Broadcast + +[Pyramid Attention Broadcast](https://huggingface.co/papers/2408.12588) from Xuanlei Zhao, Xiaolong Jin, Kai Wang, Yang You. + +Pyramid Attention Broadcast (PAB) is a method that speeds up inference in diffusion models by systematically skipping attention computations between successive inference steps, and re-using cached attention states. This is due to the realization that the attention states do not differ too much numerically between successive steps. This difference is most significant/prominent in the spatial attention blocks, lesser so in temporal attention blocks, and least in cross attention blocks. Therefore, many cross attention computation blocks can be skipped, followed by temporal and spatial attention blocks. By combining other techniques like Sequence Parallelism and CFG Parallelism, the authors achieve near real-time video generation. + +PAB can be enabled easily on any pipeline by deriving from the [`PyramidAttentionBroadcastMixin`] and keeping track of current inference timestep in the pipeline. Minimal example to demonstrate how to use PAB with CogVideoX: + +```python +import torch +from diffusers import CogVideoXPipeline +from diffusers.utils import export_to_video + +pipe = CogVideoXPipeline.from_pretrained("THUDM/CogVideoX-5b", torch_dtype=torch.float16) +pipe.to("cuda") + +pipe.enable_pyramid_attention_broadcast( + spatial_attn_skip_range=2, + spatial_attn_timestep_range=[100, 850], +) + +prompt = ( + "A panda, dressed in a small, red jacket and a tiny hat, sits on a wooden stool in a serene bamboo forest. " + "The panda's fluffy paws strum a miniature acoustic guitar, producing soft, melodic tunes. Nearby, a few other " + "pandas gather, watching curiously and some clapping in rhythm. Sunlight filters through the tall bamboo, " + "casting a gentle glow on the scene. The panda's face is expressive, showing concentration and joy as it plays. " + "The background includes a small, flowing stream and vibrant green foliage, enhancing the peaceful and magical " + "atmosphere of this unique musical performance." +) +video = pipe(prompt=prompt, guidance_scale=6, num_inference_steps=50).frames[0] +export_to_video(video, "output.mp4", fps=8) +``` + +| model | model_memory | normal_memory | pab_memory | normal_time | pab_time | speedup | +|:----------:|:--------------:|:---------------:|:------------:|:-------------:|:----------:|:---------:| +| Cog-2b T2V | 12.55 | 35.342 | 35.342 | 86.915 | 63.914 | 1.359 | +| Cog-5b T2V | 19.66 | 40.945 | 40.945 | 246.152 | 168.59 | 1.460 | +| Cog-5b I2V | 19.764 | 42.74 | 42.74 | 246.867 | 170.111 | 1.451 | + ## CogVideoXPipeline [[autodoc]] CogVideoXPipeline diff --git a/docs/source/en/api/pipelines/latte.md b/docs/source/en/api/pipelines/latte.md index c2154d5d47c1..bc62aadbc5ee 100644 --- a/docs/source/en/api/pipelines/latte.md +++ b/docs/source/en/api/pipelines/latte.md @@ -16,7 +16,7 @@ ![latte text-to-video](https://github.com/Vchitect/Latte/blob/52bc0029899babbd6e9250384c83d8ed2670ff7a/visuals/latte.gif?raw=true) -[Latte: Latent Diffusion Transformer for Video Generation](https://arxiv.org/abs/2401.03048) from Monash University, Shanghai AI Lab, Nanjing University, and Nanyang Technological University. +[Latte: Latent Diffusion Transformer for Video Generation](https://huggingface.co/papers/2401.03048) from Monash University, Shanghai AI Lab, Nanjing University, and Nanyang Technological University. The abstract from the paper is: @@ -70,6 +70,37 @@ Without torch.compile(): Average inference time: 16.246 seconds. With torch.compile(): Average inference time: 14.573 seconds. ``` +### Pyramid Attention Broadcast + +[Pyramid Attention Broadcast](https://huggingface.co/papers/2408.12588) from Xuanlei Zhao, Xiaolong Jin, Kai Wang, Yang You. + +Pyramid Attention Broadcast (PAB) is a method that speeds up inference in diffusion models by systematically skipping attention computations between successive inference steps, and re-using cached attention states. This is due to the realization that the attention states do not differ too much numerically between successive steps. This difference is most significant/prominent in the spatial attention blocks, lesser so in temporal attention blocks, and least in cross attention blocks. Therefore, many cross attention computation blocks can be skipped, followed by temporal and spatial attention blocks. By combining other techniques like Sequence Parallelism and CFG Parallelism, the authors achieve near real-time video generation. + +PAB can be enabled easily on any pipeline by deriving from the [`PyramidAttentionBroadcastMixin`] and keeping track of current inference timestep in the pipeline. Minimal example to demonstrate how to use PAB with Latte: + +```python +import torch +from diffusers import LattePipeline +from diffusers.utils import export_to_gif + +pipe = LattePipeline.from_pretrained("maxin-cn/Latte-1", torch_dtype=torch.float16) + +pipe.enable_pyramid_attention_broadcast( + spatial_attn_skip_range=2, + cross_attn_skip_range=6, + spatial_attn_timestep_range=[100, 800], + cross_attn_timestep_range=[100, 800], +) + +prompt = "A small cactus with a happy face in the Sahara desert." +videos = pipe(prompt).frames[0] +export_to_gif(videos, "latte.gif") +``` + +| model | model_memory | normal_memory | pab_memory | normal_time | pab_time | speedup | +|:----------:|:--------------:|:---------------:|:------------:|:-------------:|:----------:|:---------:| +| Latte | 11.007 | 25.594 | 25.594 | 28.026 | 24.073 | 1.164 | + ## LattePipeline [[autodoc]] LattePipeline diff --git a/src/diffusers/pipelines/pyramid_broadcast_utils.py b/src/diffusers/pipelines/pyramid_broadcast_utils.py index 78ca9e869c3a..f3be47757cdd 100644 --- a/src/diffusers/pipelines/pyramid_broadcast_utils.py +++ b/src/diffusers/pipelines/pyramid_broadcast_utils.py @@ -26,6 +26,35 @@ class PyramidAttentionBroadcastAttentionProcessorWrapper: + r""" + Helper attention processor that wraps logic required for Pyramid Attention Broadcast to function. + + PAB works by caching and re-using attention computations from past inference steps. This is due to the realization + that the attention states do not differ too much numerically between successive inference steps. The difference is + most significant/prominent in the spatial attention blocks, lesser so in the temporal attention blocks, and least + in cross attention blocks. + + Currently, only spatial and cross attention block skipping is supported in Diffusers due to not having any models + tested with temporal attention blocks. Feel free to open a PR adding support for this in case there's a model that + you would like to use PAB with. + + Args: + pipeline ([`~diffusers.DiffusionPipeline`]): + The underlying DiffusionPipeline object that inherits from the PAB Mixin and utilized this attention + processor. + processor ([`~diffusers.models.attention_processor.AttentionProcessor`]): + The underlying attention processor that will be wrapped to cache the intermediate attention computation. + skip_range (`int`): + The attention block to execute after skipping intermediate attention blocks. If set to the value `N`, `N - + 1` attention blocks are skipped and every N'th block is executed. Different models have different + tolerances to how much attention computation can be reused based on the differences between successive + blocks. So, this parameter must be adjusted per model after performing experimentation. Setting this value + to `2` is recommended for different models PAB has been experimented with. + timestep_range (`Tuple[int, int]`): + The timestep range between which PAB will remain activated in attention blocks. While activated, PAB will + re-use attention computations between inference steps. + """ + def __init__( self, pipeline, processor: AttentionProcessor, skip_range: int, timestep_range: Tuple[int, int] ) -> None: @@ -59,14 +88,18 @@ def __call__( *args, **kwargs, ): + r"""Method that wraps the underlying call to compute attention and cache states for re-use.""" + if ( hasattr(self.pipeline, "_current_timestep") and self.pipeline._current_timestep is not None and self._iteration % self._skip_range != 0 and (self._timestep_range[0] < self.pipeline._current_timestep < self._timestep_range[1]) ): + # Skip attention computation by re-using past attention states hidden_states = self._prev_hidden_states else: + # Perform attention computation call_kwargs = {} for param in self._attn_processor_params: call_kwargs.update({param: locals()[param]}) @@ -122,6 +155,33 @@ def enable_pyramid_attention_broadcast( spatial_attn_timestep_range: Optional[Tuple[int, int]] = None, cross_attn_timestep_range: Optional[Tuple[int, int]] = None, ) -> None: + r""" + Enable pyramid attention broadcast to speedup inference by re-using attention states and skipping computation + systematically as described in the paper: [Pyramid Attention Broadcast](https://www.arxiv.org/abs/2408.12588). + + Args: + spatial_attn_skip_range (`int`, *optional*): + The attention block to execute after skipping intermediate spatial attention blocks. If set to the + value `N`, `N - 1` attention blocks are skipped and every N'th block is executed. Different models have + different tolerances to how much attention computation can be reused based on the differences between + successive blocks. So, this parameter must be adjusted per model after performing experimentation. + Setting this value to `2` is recommended for different models PAB has been experimented with. + cross_attn_skip_range (`int`, *optional*): + The attention block to execute after skipping intermediate cross attention blocks. If set to the value + `N`, `N - 1` attention blocks are skipped and every N'th block is executed. Different models have + different tolerances to how much attention computation can be reused based on the differences between + successive blocks. So, this parameter must be adjusted per model after performing experimentation. + Setting this value to `6` is recommended for different models PAB has been experimented with. + spatial_attn_timestep_range (`Tuple[int, int]`, *optional*): + The timestep range between which PAB will remain activated in spatial attention blocks. While + activated, PAB will re-use attention computations between inference steps. Setting this to `(100, 850)` + is recommended for different models PAB has been experimented with. + cross_attn_timestep_range (`Tuple[int, int]`, *optional*): + The timestep range between which PAB will remain activated in cross attention blocks. While activated, + PAB will re-use attention computations between inference steps. Setting this to `(100, 800)` is + recommended for different models PAB has been experimented with. + """ + if spatial_attn_timestep_range is None: spatial_attn_timestep_range = (100, 800) if cross_attn_skip_range is None: @@ -145,6 +205,8 @@ def enable_pyramid_attention_broadcast( self._enable_pyramid_attention_broadcast() def disable_pyramid_attention_broadcast(self) -> None: + r"""Disables the pyramid attention broadcast sampling mechanism.""" + self._pab_spatial_attn_skip_range = None self._pab_cross_attn_skip_range = None self._pab_spatial_attn_timestep_range = None From afd0c176d1f517b7cd9165c181ad2c95c980d4d8 Mon Sep 17 00:00:00 2001 From: Aryan Date: Fri, 4 Oct 2024 08:51:21 +0200 Subject: [PATCH 08/61] add tests --- .../transformers/latte_transformer_3d.py | 64 ++++++++++++++++++- tests/pipelines/cogvideo/test_cogvideox.py | 53 ++++++++++++++- .../cogvideo/test_cogvideox_image2video.py | 53 ++++++++++++++- .../cogvideo/test_cogvideox_video2video.py | 53 ++++++++++++++- tests/pipelines/latte/test_latte.py | 45 ++++++++++++- 5 files changed, 259 insertions(+), 9 deletions(-) diff --git a/src/diffusers/models/transformers/latte_transformer_3d.py b/src/diffusers/models/transformers/latte_transformer_3d.py index 71d19216e5ff..18efebba9e19 100644 --- a/src/diffusers/models/transformers/latte_transformer_3d.py +++ b/src/diffusers/models/transformers/latte_transformer_3d.py @@ -11,7 +11,8 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from typing import Optional + +from typing import Dict, Optional, Union import torch from torch import nn @@ -19,6 +20,7 @@ from ...configuration_utils import ConfigMixin, register_to_config from ...models.embeddings import PixArtAlphaTextProjection, get_1d_sincos_pos_embed_from_grid from ..attention import BasicTransformerBlock +from ..attention_processor import AttentionProcessor from ..embeddings import PatchEmbed from ..modeling_outputs import Transformer2DModelOutput from ..modeling_utils import ModelMixin @@ -165,6 +167,66 @@ def __init__( def _set_gradient_checkpointing(self, module, value=False): self.gradient_checkpointing = value + @property + # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.attn_processors + def attn_processors(self) -> Dict[str, AttentionProcessor]: + r""" + Returns: + `dict` of attention processors: A dictionary containing all attention processors used in the model with + indexed by its weight name. + """ + # set recursively + processors = {} + + def fn_recursive_add_processors(name: str, module: torch.nn.Module, processors: Dict[str, AttentionProcessor]): + if hasattr(module, "get_processor"): + processors[f"{name}.processor"] = module.get_processor() + + for sub_name, child in module.named_children(): + fn_recursive_add_processors(f"{name}.{sub_name}", child, processors) + + return processors + + for name, module in self.named_children(): + fn_recursive_add_processors(name, module, processors) + + return processors + + # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.set_attn_processor + def set_attn_processor(self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]]): + r""" + Sets the attention processor to use to compute attention. + + Parameters: + processor (`dict` of `AttentionProcessor` or only `AttentionProcessor`): + The instantiated processor class or a dictionary of processor classes that will be set as the processor + for **all** `Attention` layers. + + If `processor` is a dict, the key needs to define the path to the corresponding cross attention + processor. This is strongly recommended when setting trainable attention processors. + + """ + count = len(self.attn_processors.keys()) + + if isinstance(processor, dict) and len(processor) != count: + raise ValueError( + f"A dict of processors was passed, but the number of processors {len(processor)} does not match the" + f" number of attention layers: {count}. Please make sure to pass {count} processor classes." + ) + + def fn_recursive_attn_processor(name: str, module: torch.nn.Module, processor): + if hasattr(module, "set_processor"): + if not isinstance(processor, dict): + module.set_processor(processor) + else: + module.set_processor(processor.pop(f"{name}.processor")) + + for sub_name, child in module.named_children(): + fn_recursive_attn_processor(f"{name}.{sub_name}", child, processor) + + for name, module in self.named_children(): + fn_recursive_attn_processor(name, module, processor) + def forward( self, hidden_states: torch.Tensor, diff --git a/tests/pipelines/cogvideo/test_cogvideox.py b/tests/pipelines/cogvideo/test_cogvideox.py index 884ddfb2a95a..bbafd3aa7ae2 100644 --- a/tests/pipelines/cogvideo/test_cogvideox.py +++ b/tests/pipelines/cogvideo/test_cogvideox.py @@ -21,6 +21,7 @@ from transformers import AutoTokenizer, T5EncoderModel from diffusers import AutoencoderKLCogVideoX, CogVideoXPipeline, CogVideoXTransformer3DModel, DDIMScheduler +from diffusers.pipelines.pyramid_broadcast_utils import PyramidAttentionBroadcastAttentionProcessorWrapper from diffusers.utils.testing_utils import ( enable_full_determinism, numpy_cosine_similarity_distance, @@ -59,7 +60,7 @@ class CogVideoXPipelineFastTests(PipelineTesterMixin, unittest.TestCase): ) test_xformers_attention = False - def get_dummy_components(self): + def get_dummy_components(self, num_layers: int = 1): torch.manual_seed(0) transformer = CogVideoXTransformer3DModel( # Product of num_attention_heads * attention_head_dim must be divisible by 16 for 3D positional embeddings @@ -71,7 +72,7 @@ def get_dummy_components(self): out_channels=4, time_embed_dim=2, text_embed_dim=32, # Must match with tiny-random-t5 - num_layers=1, + num_layers=num_layers, sample_width=2, # latent width: 2 -> final width: 16 sample_height=2, # latent height: 2 -> final height: 16 sample_frames=9, # latent frames: (9 - 1) / 4 + 1 = 3 -> final frames: 9 @@ -319,6 +320,54 @@ def test_fused_qkv_projections(self): original_image_slice, image_slice_disabled, atol=1e-2, rtol=1e-2 ), "Original outputs should match when fused QKV projections are disabled." + def test_pyramid_attention_broadcast(self): + device = "cpu" # ensure determinism for the device-dependent torch.Generator + num_layers = 4 + components = self.get_dummy_components(num_layers=num_layers) + pipe = self.pipeline_class(**components) + pipe = pipe.to(device) + pipe.set_progress_bar_config(disable=None) + + inputs = self.get_dummy_inputs(device) + inputs["num_inference_steps"] = 4 + frames = pipe(**inputs).frames # [B, F, C, H, W] + original_image_slice = frames[0, -2:, -1, -3:, -3:] + + pipe.enable_pyramid_attention_broadcast(spatial_attn_skip_range=2, spatial_attn_timestep_range=(100, 800)) + assert pipe.pyramid_attention_broadcast_enabled + + num_pab_processors = sum( + [ + isinstance(processor, PyramidAttentionBroadcastAttentionProcessorWrapper) + for processor in pipe.transformer.attn_processors.values() + ] + ) + assert num_pab_processors == num_layers + + inputs = self.get_dummy_inputs(device) + inputs["num_inference_steps"] = 4 + frames = pipe(**inputs).frames + image_slice_pab_enabled = frames[0, -2:, -1, -3:, -3:] + + pipe.disable_pyramid_attention_broadcast() + assert not pipe.pyramid_attention_broadcast_enabled + + inputs = self.get_dummy_inputs(device) + frames = pipe(**inputs).frames + image_slice_pab_disabled = frames[0, -2:, -1, -3:, -3:] + + # We need to use higher tolerance because we are using a random model. With a converged/trained + # model, the tolerance can be lower. + assert np.allclose( + original_image_slice, image_slice_pab_enabled, atol=0.2 + ), "PAB outputs should not differ much in specified timestep range." + assert np.allclose( + image_slice_pab_enabled, image_slice_pab_disabled, atol=0.2 + ), "Outputs, with PAB enabled, shouldn't differ much when PAB is disabled in specified timestep range." + assert np.allclose( + original_image_slice, image_slice_pab_disabled, atol=0.2 + ), "Original outputs should match when PAB is disabled." + @slow @require_torch_gpu diff --git a/tests/pipelines/cogvideo/test_cogvideox_image2video.py b/tests/pipelines/cogvideo/test_cogvideox_image2video.py index ec9a5fdd153e..89235a226640 100644 --- a/tests/pipelines/cogvideo/test_cogvideox_image2video.py +++ b/tests/pipelines/cogvideo/test_cogvideox_image2video.py @@ -22,6 +22,7 @@ from transformers import AutoTokenizer, T5EncoderModel from diffusers import AutoencoderKLCogVideoX, CogVideoXImageToVideoPipeline, CogVideoXTransformer3DModel, DDIMScheduler +from diffusers.pipelines.pyramid_broadcast_utils import PyramidAttentionBroadcastAttentionProcessorWrapper from diffusers.utils import load_image from diffusers.utils.testing_utils import ( enable_full_determinism, @@ -61,7 +62,7 @@ class CogVideoXPipelineFastTests(PipelineTesterMixin, unittest.TestCase): ) test_xformers_attention = False - def get_dummy_components(self): + def get_dummy_components(self, num_layers: int = 1): torch.manual_seed(0) transformer = CogVideoXTransformer3DModel( # Product of num_attention_heads * attention_head_dim must be divisible by 16 for 3D positional embeddings @@ -76,7 +77,7 @@ def get_dummy_components(self): out_channels=4, time_embed_dim=2, text_embed_dim=32, # Must match with tiny-random-t5 - num_layers=1, + num_layers=num_layers, sample_width=2, # latent width: 2 -> final width: 16 sample_height=2, # latent height: 2 -> final height: 16 sample_frames=9, # latent frames: (9 - 1) / 4 + 1 = 3 -> final frames: 9 @@ -342,6 +343,54 @@ def test_fused_qkv_projections(self): original_image_slice, image_slice_disabled, atol=1e-2, rtol=1e-2 ), "Original outputs should match when fused QKV projections are disabled." + def test_pyramid_attention_broadcast(self): + device = "cpu" # ensure determinism for the device-dependent torch.Generator + num_layers = 4 + components = self.get_dummy_components(num_layers=num_layers) + pipe = self.pipeline_class(**components) + pipe = pipe.to(device) + pipe.set_progress_bar_config(disable=None) + + inputs = self.get_dummy_inputs(device) + inputs["num_inference_steps"] = 4 + frames = pipe(**inputs).frames # [B, F, C, H, W] + original_image_slice = frames[0, -2:, -1, -3:, -3:] + + pipe.enable_pyramid_attention_broadcast(spatial_attn_skip_range=2, spatial_attn_timestep_range=(100, 800)) + assert pipe.pyramid_attention_broadcast_enabled + + num_pab_processors = sum( + [ + isinstance(processor, PyramidAttentionBroadcastAttentionProcessorWrapper) + for processor in pipe.transformer.attn_processors.values() + ] + ) + assert num_pab_processors == num_layers + + inputs = self.get_dummy_inputs(device) + inputs["num_inference_steps"] = 4 + frames = pipe(**inputs).frames + image_slice_pab_enabled = frames[0, -2:, -1, -3:, -3:] + + pipe.disable_pyramid_attention_broadcast() + assert not pipe.pyramid_attention_broadcast_enabled + + inputs = self.get_dummy_inputs(device) + frames = pipe(**inputs).frames + image_slice_pab_disabled = frames[0, -2:, -1, -3:, -3:] + + # We need to use higher tolerance because we are using a random model. With a converged/trained + # model, the tolerance can be lower. + assert np.allclose( + original_image_slice, image_slice_pab_enabled, atol=0.2 + ), "PAB outputs should not differ much in specified timestep range." + assert np.allclose( + image_slice_pab_enabled, image_slice_pab_disabled, atol=0.2 + ), "Outputs, with PAB enabled, shouldn't differ much when PAB is disabled in specified timestep range." + assert np.allclose( + original_image_slice, image_slice_pab_disabled, atol=0.2 + ), "Original outputs should match when PAB is disabled." + @unittest.skip("The model 'THUDM/CogVideoX-5b-I2V' is not public yet.") @slow diff --git a/tests/pipelines/cogvideo/test_cogvideox_video2video.py b/tests/pipelines/cogvideo/test_cogvideox_video2video.py index 4d836cb5e2a4..9f05b04d77af 100644 --- a/tests/pipelines/cogvideo/test_cogvideox_video2video.py +++ b/tests/pipelines/cogvideo/test_cogvideox_video2video.py @@ -21,6 +21,7 @@ from transformers import AutoTokenizer, T5EncoderModel from diffusers import AutoencoderKLCogVideoX, CogVideoXTransformer3DModel, CogVideoXVideoToVideoPipeline, DDIMScheduler +from diffusers.pipelines.pyramid_broadcast_utils import PyramidAttentionBroadcastAttentionProcessorWrapper from diffusers.utils.testing_utils import enable_full_determinism, torch_device from ..pipeline_params import TEXT_TO_IMAGE_BATCH_PARAMS, TEXT_TO_IMAGE_IMAGE_PARAMS, TEXT_TO_IMAGE_PARAMS @@ -53,7 +54,7 @@ class CogVideoXVideoToVideoPipelineFastTests(PipelineTesterMixin, unittest.TestC ) test_xformers_attention = False - def get_dummy_components(self): + def get_dummy_components(self, num_layers: int = 1): torch.manual_seed(0) transformer = CogVideoXTransformer3DModel( # Product of num_attention_heads * attention_head_dim must be divisible by 16 for 3D positional embeddings @@ -65,7 +66,7 @@ def get_dummy_components(self): out_channels=4, time_embed_dim=2, text_embed_dim=32, # Must match with tiny-random-t5 - num_layers=1, + num_layers=num_layers, sample_width=2, # latent width: 2 -> final width: 16 sample_height=2, # latent height: 2 -> final height: 16 sample_frames=9, # latent frames: (9 - 1) / 4 + 1 = 3 -> final frames: 9 @@ -323,3 +324,51 @@ def test_fused_qkv_projections(self): assert np.allclose( original_image_slice, image_slice_disabled, atol=1e-2, rtol=1e-2 ), "Original outputs should match when fused QKV projections are disabled." + + def test_pyramid_attention_broadcast(self): + device = "cpu" # ensure determinism for the device-dependent torch.Generator + num_layers = 4 + components = self.get_dummy_components(num_layers=num_layers) + pipe = self.pipeline_class(**components) + pipe = pipe.to(device) + pipe.set_progress_bar_config(disable=None) + + inputs = self.get_dummy_inputs(device) + inputs["num_inference_steps"] = 4 + frames = pipe(**inputs).frames # [B, F, C, H, W] + original_image_slice = frames[0, -2:, -1, -3:, -3:] + + pipe.enable_pyramid_attention_broadcast(spatial_attn_skip_range=2, spatial_attn_timestep_range=(100, 800)) + assert pipe.pyramid_attention_broadcast_enabled + + num_pab_processors = sum( + [ + isinstance(processor, PyramidAttentionBroadcastAttentionProcessorWrapper) + for processor in pipe.transformer.attn_processors.values() + ] + ) + assert num_pab_processors == num_layers + + inputs = self.get_dummy_inputs(device) + inputs["num_inference_steps"] = 4 + frames = pipe(**inputs).frames + image_slice_pab_enabled = frames[0, -2:, -1, -3:, -3:] + + pipe.disable_pyramid_attention_broadcast() + assert not pipe.pyramid_attention_broadcast_enabled + + inputs = self.get_dummy_inputs(device) + frames = pipe(**inputs).frames + image_slice_pab_disabled = frames[0, -2:, -1, -3:, -3:] + + # We need to use higher tolerance because we are using a random model. With a converged/trained + # model, the tolerance can be lower. + assert np.allclose( + original_image_slice, image_slice_pab_enabled, atol=0.2 + ), "PAB outputs should not differ much in specified timestep range." + assert np.allclose( + image_slice_pab_enabled, image_slice_pab_disabled, atol=0.2 + ), "Outputs, with PAB enabled, shouldn't differ much when PAB is disabled in specified timestep range." + assert np.allclose( + original_image_slice, image_slice_pab_disabled, atol=0.2 + ), "Original outputs should match when PAB is disabled." diff --git a/tests/pipelines/latte/test_latte.py b/tests/pipelines/latte/test_latte.py index 9667ebff249d..75f20a446dfd 100644 --- a/tests/pipelines/latte/test_latte.py +++ b/tests/pipelines/latte/test_latte.py @@ -53,11 +53,11 @@ class LattePipelineFastTests(PipelineTesterMixin, unittest.TestCase): required_optional_params = PipelineTesterMixin.required_optional_params - def get_dummy_components(self): + def get_dummy_components(self, num_layers: int = 1): torch.manual_seed(0) transformer = LatteTransformer3DModel( sample_size=8, - num_layers=1, + num_layers=num_layers, patch_size=2, attention_head_dim=8, num_attention_heads=3, @@ -264,6 +264,47 @@ def test_save_load_optional_components(self): def test_xformers_attention_forwardGenerator_pass(self): super()._test_xformers_attention_forwardGenerator_pass(test_mean_pixel_difference=False) + def test_pyramid_attention_broadcast(self): + device = "cpu" # ensure determinism for the device-dependent torch.Generator + num_layers = 4 + components = self.get_dummy_components(num_layers=num_layers) + pipe = self.pipeline_class(**components) + pipe = pipe.to(device) + pipe.set_progress_bar_config(disable=None) + + inputs = self.get_dummy_inputs(device) + inputs["num_inference_steps"] = 4 + frames = pipe(**inputs).frames # [B, F, C, H, W] + original_image_slice = frames[0, -2:, -1, -3:, -3:] + + pipe.enable_pyramid_attention_broadcast(spatial_attn_skip_range=2, spatial_attn_timestep_range=(100, 800)) + assert pipe.pyramid_attention_broadcast_enabled + + inputs = self.get_dummy_inputs(device) + inputs["num_inference_steps"] = 4 + frames = pipe(**inputs).frames + image_slice_pab_enabled = frames[0, -2:, -1, -3:, -3:] + + pipe.disable_pyramid_attention_broadcast() + assert not pipe.pyramid_attention_broadcast_enabled + + inputs = self.get_dummy_inputs(device) + frames = pipe(**inputs).frames + image_slice_pab_disabled = frames[0, -2:, -1, -3:, -3:] + + # We need to use higher tolerance because we are using a random model. With a converged/trained + # model, the tolerance can be lower. + assert np.allclose( + original_image_slice, image_slice_pab_enabled, atol=0.25 + ), "PAB outputs should not differ much in specified timestep range." + print((image_slice_pab_disabled - image_slice_pab_enabled).abs().max()) + assert np.allclose( + image_slice_pab_enabled, image_slice_pab_disabled, atol=0.25 + ), "Outputs, with PAB enabled, shouldn't differ much when PAB is disabled in specified timestep range." + assert np.allclose( + original_image_slice, image_slice_pab_disabled, atol=0.25 + ), "Original outputs should match when PAB is disabled." + @slow @require_torch_gpu From 6265b6546940bc1dd1f5f88309b3b814444692c2 Mon Sep 17 00:00:00 2001 From: Aryan Date: Sat, 5 Oct 2024 22:22:01 +0200 Subject: [PATCH 09/61] update --- src/diffusers/pipelines/pyramid_broadcast_utils.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/src/diffusers/pipelines/pyramid_broadcast_utils.py b/src/diffusers/pipelines/pyramid_broadcast_utils.py index f3be47757cdd..6e917568f33a 100644 --- a/src/diffusers/pipelines/pyramid_broadcast_utils.py +++ b/src/diffusers/pipelines/pyramid_broadcast_utils.py @@ -66,8 +66,8 @@ def __init__( self._prev_hidden_states = None self._iteration = 0 - _original_processor_params = set(inspect.signature(self._original_processor).parameters.keys()) - _supported_parameters = { + original_processor_params = set(inspect.signature(self._original_processor.__call__).parameters.keys()) + supported_parameters = { "attn", "hidden_states", "encoder_hidden_states", @@ -75,7 +75,7 @@ def __init__( "temb", "image_rotary_emb", } - self._attn_processor_params = _supported_parameters.intersection(_original_processor_params) + self._attn_processor_params = supported_parameters.intersection(original_processor_params) def __call__( self, From 9cb4e876bc3f0d6915a852f34667e15eafe3341e Mon Sep 17 00:00:00 2001 From: Aryan Date: Wed, 16 Oct 2024 03:26:45 +0530 Subject: [PATCH 10/61] Update docs/source/en/api/pipelines/cogvideox.md Co-authored-by: Steven Liu <59462357+stevhliu@users.noreply.github.com> --- docs/source/en/api/pipelines/cogvideox.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docs/source/en/api/pipelines/cogvideox.md b/docs/source/en/api/pipelines/cogvideox.md index 5f1e91d52bf1..c1873bc7f5a4 100644 --- a/docs/source/en/api/pipelines/cogvideox.md +++ b/docs/source/en/api/pipelines/cogvideox.md @@ -104,7 +104,7 @@ It is also worth noting that torchao quantization is fully compatible with [torc [Pyramid Attention Broadcast](https://huggingface.co/papers/2408.12588) from Xuanlei Zhao, Xiaolong Jin, Kai Wang, Yang You. -Pyramid Attention Broadcast (PAB) is a method that speeds up inference in diffusion models by systematically skipping attention computations between successive inference steps, and re-using cached attention states. This is due to the realization that the attention states do not differ too much numerically between successive steps. This difference is most significant/prominent in the spatial attention blocks, lesser so in temporal attention blocks, and least in cross attention blocks. Therefore, many cross attention computation blocks can be skipped, followed by temporal and spatial attention blocks. By combining other techniques like Sequence Parallelism and CFG Parallelism, the authors achieve near real-time video generation. +Pyramid Attention Broadcast (PAB) is a method that speeds up inference in diffusion models by systematically skipping attention computations between successive inference steps and reusing cached attention states. The attention states aren't that different between successive steps. The most prominent difference is in the spatial attention blocks, not as much in the temporal attention blocks, and finally the least in the cross attention blocks. Therefore, many cross attention computation blocks can be skipped, followed by the temporal and spatial attention blocks. By combining other techniques like sequence parallelism and classifier-free guidance parallelism, PAB achieves near real-time video generation. PAB can be enabled easily on any pipeline by deriving from the [`PyramidAttentionBroadcastMixin`] and keeping track of current inference timestep in the pipeline. Minimal example to demonstrate how to use PAB with CogVideoX: From 6b1f55ec971dc955f810dd61a2ff619c93170273 Mon Sep 17 00:00:00 2001 From: Aryan Date: Wed, 16 Oct 2024 03:26:58 +0530 Subject: [PATCH 11/61] Update docs/source/en/api/pipelines/cogvideox.md Co-authored-by: Steven Liu <59462357+stevhliu@users.noreply.github.com> --- docs/source/en/api/pipelines/cogvideox.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docs/source/en/api/pipelines/cogvideox.md b/docs/source/en/api/pipelines/cogvideox.md index c1873bc7f5a4..bd7e9e04787a 100644 --- a/docs/source/en/api/pipelines/cogvideox.md +++ b/docs/source/en/api/pipelines/cogvideox.md @@ -106,7 +106,7 @@ It is also worth noting that torchao quantization is fully compatible with [torc Pyramid Attention Broadcast (PAB) is a method that speeds up inference in diffusion models by systematically skipping attention computations between successive inference steps and reusing cached attention states. The attention states aren't that different between successive steps. The most prominent difference is in the spatial attention blocks, not as much in the temporal attention blocks, and finally the least in the cross attention blocks. Therefore, many cross attention computation blocks can be skipped, followed by the temporal and spatial attention blocks. By combining other techniques like sequence parallelism and classifier-free guidance parallelism, PAB achieves near real-time video generation. -PAB can be enabled easily on any pipeline by deriving from the [`PyramidAttentionBroadcastMixin`] and keeping track of current inference timestep in the pipeline. Minimal example to demonstrate how to use PAB with CogVideoX: +Enable PAB with [`~PyramidAttentionBroadcastMixin.enable_pyramind_attention_broadcast`] on any pipeline and keep track of the current inference timestep in the pipeline. ```python import torch From c52cf422d06158deef3175e6c7180963de67ef1e Mon Sep 17 00:00:00 2001 From: Aryan Date: Fri, 8 Nov 2024 21:45:58 +0530 Subject: [PATCH 12/61] Pyramid Attention Broadcast rewrite + introduce hooks (#9826) * rewrite implementation with hooks * make style * update --- src/diffusers/models/hooks.py | 251 ++++++++++++++++++ .../pipelines/allegro/pipeline_allegro.py | 7 +- .../pipeline_cogvideox_fun_control.py | 7 +- src/diffusers/pipelines/pipeline_utils.py | 4 + .../pipelines/pyramid_broadcast_utils.py | 197 ++++++-------- 5 files changed, 344 insertions(+), 122 deletions(-) create mode 100644 src/diffusers/models/hooks.py diff --git a/src/diffusers/models/hooks.py b/src/diffusers/models/hooks.py new file mode 100644 index 000000000000..2b4351d4a94e --- /dev/null +++ b/src/diffusers/models/hooks.py @@ -0,0 +1,251 @@ +# Copyright 2024 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import functools +from typing import Any, Callable, Dict, Tuple, Union + +import torch + + +# Reference: https://github.com/huggingface/accelerate/blob/ba7ab93f5e688466ea56908ea3b056fae2f9a023/src/accelerate/hooks.py +class ModelHook: + r""" + A hook that contains callbacks to be executed just before and after the forward method of a model. The difference + with PyTorch existing hooks is that they get passed along the kwargs. + """ + + def init_hook(self, module: torch.nn.Module) -> torch.nn.Module: + r""" + Hook that is executed when a model is initialized. + + Args: + module (`torch.nn.Module`): + The module attached to this hook. + """ + return module + + def pre_forward(self, module: torch.nn.Module, *args, **kwargs) -> Tuple[Tuple[Any], Dict[str, Any]]: + r""" + Hook that is executed just before the forward method of the model. + + Args: + module (`torch.nn.Module`): + The module whose forward pass will be executed just after this event. + args (`Tuple[Any]`): + The positional arguments passed to the module. + kwargs (`Dict[Str, Any]`): + The keyword arguments passed to the module. + + Returns: + `Tuple[Tuple[Any], Dict[Str, Any]]`: + A tuple with the treated `args` and `kwargs`. + """ + return args, kwargs + + def post_forward(self, module: torch.nn.Module, output: Any) -> Any: + r""" + Hook that is executed just after the forward method of the model. + + Args: + module (`torch.nn.Module`): + The module whose forward pass been executed just before this event. + output (`Any`): + The output of the module. + + Returns: + `Any`: The processed `output`. + """ + return output + + def detach_hook(self, module: torch.nn.Module) -> torch.nn.Module: + r""" + Hook that is executed when the hook is detached from a module. + + Args: + module (`torch.nn.Module`): + The module detached from this hook. + """ + return module + + def reset_state(self, module: torch.nn.Module) -> torch.nn.Module: + return module + + +class SequentialHook(ModelHook): + r"""A hook that can contain several hooks and iterates through them at each event.""" + + def __init__(self, *hooks): + self.hooks = hooks + + def init_hook(self, module): + for hook in self.hooks: + module = hook.init_hook(module) + return module + + def pre_forward(self, module, *args, **kwargs): + for hook in self.hooks: + args, kwargs = hook.pre_forward(module, *args, **kwargs) + return args, kwargs + + def post_forward(self, module, output): + for hook in self.hooks: + output = hook.post_forward(module, output) + return output + + def detach_hook(self, module): + for hook in self.hooks: + module = hook.detach_hook(module) + return module + + def reset_state(self, module): + for hook in self.hooks: + module = hook.reset_state(module) + return module + + +class PyramidAttentionBroadcastHook(ModelHook): + def __init__( + self, + skip_range: int, + timestep_range: Tuple[int, int], + timestep_callback: Callable[[], Union[torch.LongTensor, int]], + ) -> None: + super().__init__() + + self.skip_range = skip_range + self.timestep_range = timestep_range + self.timestep_callback = timestep_callback + + self.attention_cache = None + self._iteration = 0 + + def new_forward(self, module: torch.nn.Module, *args, **kwargs) -> Any: + args, kwargs = module._diffusers_hook.pre_forward(module, *args, **kwargs) + + current_timestep = self.timestep_callback() + is_within_timestep_range = self.timestep_range[0] < current_timestep < self.timestep_range[1] + should_compute_attention = self._iteration % self.skip_range == 0 + + if not is_within_timestep_range or should_compute_attention: + output = module._old_forward(*args, **kwargs) + else: + output = self.attention_cache + + self._iteration = self._iteration + 1 + + return module._diffusers_hook.post_forward(module, output) + + def post_forward(self, module: torch.nn.Module, output: Any) -> Any: + self.attention_cache = output + return output + + def reset_state(self, module: torch.nn.Module) -> torch.nn.Module: + self.attention_cache = None + self._iteration = 0 + return module + + +def add_hook_to_module(module: torch.nn.Module, hook: ModelHook, append: bool = False): + r""" + Adds a hook to a given module. This will rewrite the `forward` method of the module to include the hook, to remove + this behavior and restore the original `forward` method, use `remove_hook_from_module`. + + + + If the module already contains a hook, this will replace it with the new hook passed by default. To chain two hooks + together, pass `append=True`, so it chains the current and new hook into an instance of the `SequentialHook` class. + + + + Args: + module (`torch.nn.Module`): + The module to attach a hook to. + hook (`ModelHook`): + The hook to attach. + append (`bool`, *optional*, defaults to `False`): + Whether the hook should be chained with an existing one (if module already contains a hook) or not. + + Returns: + `torch.nn.Module`: + The same module, with the hook attached (the module is modified in place, so the result can be discarded). + """ + original_hook = hook + + if append and getattr(module, "_diffusers_hook", None) is not None: + old_hook = module._diffusers_hook + remove_hook_from_module(module) + hook = SequentialHook(old_hook, hook) + + if hasattr(module, "_diffusers_hook") and hasattr(module, "_old_forward"): + # If we already put some hook on this module, we replace it with the new one. + old_forward = module._old_forward + else: + old_forward = module.forward + module._old_forward = old_forward + + module = hook.init_hook(module) + module._diffusers_hook = hook + + if hasattr(original_hook, "new_forward"): + new_forward = original_hook.new_forward + else: + + def new_forward(module, *args, **kwargs): + args, kwargs = module._diffusers_hook.pre_forward(module, *args, **kwargs) + output = module._old_forward(*args, **kwargs) + return module._diffusers_hook.post_forward(module, output) + + # Overriding a GraphModuleImpl forward freezes the forward call and later modifications on the graph will fail. + # Reference: https://pytorch.slack.com/archives/C3PDTEV8E/p1705929610405409 + if "GraphModuleImpl" in str(type(module)): + module.__class__.forward = functools.update_wrapper(functools.partial(new_forward, module), old_forward) + else: + module.forward = functools.update_wrapper(functools.partial(new_forward, module), old_forward) + + return module + + +def remove_hook_from_module(module: torch.nn.Module, recurse: bool = False) -> torch.nn.Module: + """ + Removes any hook attached to a module via `add_hook_to_module`. + + Args: + module (`torch.nn.Module`): + The module to attach a hook to. + recurse (`bool`, defaults to `False`): + Whether to remove the hooks recursively + + Returns: + `torch.nn.Module`: + The same module, with the hook detached (the module is modified in place, so the result can be discarded). + """ + + if hasattr(module, "_diffusers_hook"): + module._diffusers_hook.detach_hook(module) + delattr(module, "_diffusers_hook") + + if hasattr(module, "_old_forward"): + # Overriding a GraphModuleImpl forward freezes the forward call and later modifications on the graph will fail. + # Reference: https://pytorch.slack.com/archives/C3PDTEV8E/p1705929610405409 + if "GraphModuleImpl" in str(type(module)): + module.__class__.forward = module._old_forward + else: + module.forward = module._old_forward + delattr(module, "_old_forward") + + if recurse: + for child in module.children(): + remove_hook_from_module(child, recurse) + + return module diff --git a/src/diffusers/pipelines/allegro/pipeline_allegro.py b/src/diffusers/pipelines/allegro/pipeline_allegro.py index 9314960f9618..10dd6455092d 100644 --- a/src/diffusers/pipelines/allegro/pipeline_allegro.py +++ b/src/diffusers/pipelines/allegro/pipeline_allegro.py @@ -38,6 +38,7 @@ ) from ...utils.torch_utils import randn_tensor from ...video_processor import VideoProcessor +from ..pyramid_broadcast_utils import PyramidAttentionBroadcastMixin from .pipeline_output import AllegroPipelineOutput @@ -131,7 +132,7 @@ def retrieve_timesteps( return timesteps, num_inference_steps -class AllegroPipeline(DiffusionPipeline): +class AllegroPipeline(DiffusionPipeline, PyramidAttentionBroadcastMixin): r""" Pipeline for text-to-video generation using Allegro. @@ -786,6 +787,7 @@ def __call__( negative_prompt_attention_mask, ) self._guidance_scale = guidance_scale + self._current_timestep = None self._interrupt = False # 2. Default height and width to transformer @@ -863,6 +865,7 @@ def __call__( if self.interrupt: continue + self._current_timestep = t latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents latent_model_input = self.scheduler.scale_model_input(latent_model_input, t) @@ -901,6 +904,8 @@ def __call__( if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0): progress_bar.update() + self._current_timestep = None + if not output_type == "latent": latents = latents.to(self.vae.dtype) video = self.decode_latents(latents) diff --git a/src/diffusers/pipelines/cogvideo/pipeline_cogvideox_fun_control.py b/src/diffusers/pipelines/cogvideo/pipeline_cogvideox_fun_control.py index 3655075bd519..9eeccec50621 100644 --- a/src/diffusers/pipelines/cogvideo/pipeline_cogvideox_fun_control.py +++ b/src/diffusers/pipelines/cogvideo/pipeline_cogvideox_fun_control.py @@ -30,6 +30,7 @@ from ...utils import logging, replace_example_docstring from ...utils.torch_utils import randn_tensor from ...video_processor import VideoProcessor +from ..pyramid_broadcast_utils import PyramidAttentionBroadcastMixin from .pipeline_output import CogVideoXPipelineOutput @@ -144,7 +145,7 @@ def retrieve_timesteps( return timesteps, num_inference_steps -class CogVideoXFunControlPipeline(DiffusionPipeline, CogVideoXLoraLoaderMixin): +class CogVideoXFunControlPipeline(DiffusionPipeline, CogVideoXLoraLoaderMixin, PyramidAttentionBroadcastMixin): r""" Pipeline for controlled text-to-video generation using CogVideoX Fun. @@ -650,6 +651,7 @@ def __call__( ) self._guidance_scale = guidance_scale self._attention_kwargs = attention_kwargs + self._current_timestep = None self._interrupt = False # 2. Default call parameters @@ -730,6 +732,7 @@ def __call__( if self.interrupt: continue + self._current_timestep = t latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents latent_model_input = self.scheduler.scale_model_input(latent_model_input, t) @@ -779,6 +782,8 @@ def __call__( if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0): progress_bar.update() + self._current_timestep = None + if not output_type == "latent": video = self.decode_latents(latents) video = self.video_processor.postprocess_video(video=video, output_type=output_type) diff --git a/src/diffusers/pipelines/pipeline_utils.py b/src/diffusers/pipelines/pipeline_utils.py index 2e1858b16148..aa790c830d1a 100644 --- a/src/diffusers/pipelines/pipeline_utils.py +++ b/src/diffusers/pipelines/pipeline_utils.py @@ -1082,6 +1082,10 @@ def maybe_free_model_hooks(self): is a no-op. Make sure to add this function to the end of the `__call__` function of your pipeline so that it functions correctly when applying enable_model_cpu_offload. """ + + if hasattr(self, "_diffusers_hook"): + self._diffusers_hook.reset_state() + if not hasattr(self, "_all_hooks") or len(self._all_hooks) == 0: # `enable_model_cpu_offload` has not be called, so silently do nothing return diff --git a/src/diffusers/pipelines/pyramid_broadcast_utils.py b/src/diffusers/pipelines/pyramid_broadcast_utils.py index 6e917568f33a..7fdb6a7f5b93 100644 --- a/src/diffusers/pipelines/pyramid_broadcast_utils.py +++ b/src/diffusers/pipelines/pyramid_broadcast_utils.py @@ -12,106 +12,18 @@ # See the License for the specific language governing permissions and # limitations under the License. -import inspect -from typing import Optional, Tuple +from typing import List, Optional, Tuple -import torch import torch.nn as nn -from ..models.attention_processor import Attention, AttentionProcessor +from ..models.attention_processor import Attention +from ..models.hooks import PyramidAttentionBroadcastHook, add_hook_to_module, remove_hook_from_module from ..utils import logging logger = logging.get_logger(__name__) # pylint: disable=invalid-name -class PyramidAttentionBroadcastAttentionProcessorWrapper: - r""" - Helper attention processor that wraps logic required for Pyramid Attention Broadcast to function. - - PAB works by caching and re-using attention computations from past inference steps. This is due to the realization - that the attention states do not differ too much numerically between successive inference steps. The difference is - most significant/prominent in the spatial attention blocks, lesser so in the temporal attention blocks, and least - in cross attention blocks. - - Currently, only spatial and cross attention block skipping is supported in Diffusers due to not having any models - tested with temporal attention blocks. Feel free to open a PR adding support for this in case there's a model that - you would like to use PAB with. - - Args: - pipeline ([`~diffusers.DiffusionPipeline`]): - The underlying DiffusionPipeline object that inherits from the PAB Mixin and utilized this attention - processor. - processor ([`~diffusers.models.attention_processor.AttentionProcessor`]): - The underlying attention processor that will be wrapped to cache the intermediate attention computation. - skip_range (`int`): - The attention block to execute after skipping intermediate attention blocks. If set to the value `N`, `N - - 1` attention blocks are skipped and every N'th block is executed. Different models have different - tolerances to how much attention computation can be reused based on the differences between successive - blocks. So, this parameter must be adjusted per model after performing experimentation. Setting this value - to `2` is recommended for different models PAB has been experimented with. - timestep_range (`Tuple[int, int]`): - The timestep range between which PAB will remain activated in attention blocks. While activated, PAB will - re-use attention computations between inference steps. - """ - - def __init__( - self, pipeline, processor: AttentionProcessor, skip_range: int, timestep_range: Tuple[int, int] - ) -> None: - self.pipeline = pipeline - self._original_processor = processor - self._skip_range = skip_range - self._timestep_range = timestep_range - - self._prev_hidden_states = None - self._iteration = 0 - - original_processor_params = set(inspect.signature(self._original_processor.__call__).parameters.keys()) - supported_parameters = { - "attn", - "hidden_states", - "encoder_hidden_states", - "attention_mask", - "temb", - "image_rotary_emb", - } - self._attn_processor_params = supported_parameters.intersection(original_processor_params) - - def __call__( - self, - attn: Attention, - hidden_states: Optional[torch.Tensor] = None, - encoder_hidden_states: Optional[torch.Tensor] = None, - attention_mask: Optional[torch.Tensor] = None, - temb: Optional[torch.Tensor] = None, - image_rotary_emb: Optional[torch.Tensor] = None, - *args, - **kwargs, - ): - r"""Method that wraps the underlying call to compute attention and cache states for re-use.""" - - if ( - hasattr(self.pipeline, "_current_timestep") - and self.pipeline._current_timestep is not None - and self._iteration % self._skip_range != 0 - and (self._timestep_range[0] < self.pipeline._current_timestep < self._timestep_range[1]) - ): - # Skip attention computation by re-using past attention states - hidden_states = self._prev_hidden_states - else: - # Perform attention computation - call_kwargs = {} - for param in self._attn_processor_params: - call_kwargs.update({param: locals()[param]}) - call_kwargs.update(kwargs) - hidden_states = self._original_processor(*args, **call_kwargs) - self._prev_hidden_states = hidden_states - - self._iteration = (self._iteration + 1) % self.pipeline.num_timesteps - - return hidden_states - - class PyramidAttentionBroadcastMixin: r"""Mixin class for [Pyramid Attention Broadcast](https://www.arxiv.org/abs/2408.12588).""" @@ -120,40 +32,68 @@ def _enable_pyramid_attention_broadcast(self) -> None: for name, module in denoiser.named_modules(): if isinstance(module, Attention): - logger.debug(f"Enabling Pyramid Attention Broadcast in layer: {name}") + is_spatial_attention = ( + any(x in name for x in self._pab_spatial_attn_layer_identifiers) + and self._pab_spatial_attn_skip_range is not None + and not module.is_cross_attention + ) + is_temporal_attention = ( + any(x in name for x in self._pab_temporal_attn_layer_identifiers) + and self._pab_temporal_attn_skip_range is not None + and not module.is_cross_attention + ) + is_cross_attention = ( + any(x in name for x in self._pab_cross_attn_layer_identifiers) + and self._pab_cross_attn_skip_range is not None + and module.is_cross_attention + ) - skip_range, timestep_range = None, None - if module.is_cross_attention and self._pab_cross_attn_skip_range is not None: - skip_range = self._pab_cross_attn_skip_range - timestep_range = self._pab_cross_attn_timestep_range - if not module.is_cross_attention and self._pab_spatial_attn_skip_range is not None: + if is_spatial_attention: skip_range = self._pab_spatial_attn_skip_range timestep_range = self._pab_spatial_attn_timestep_range + if is_temporal_attention: + skip_range = self._pab_temporal_attn_skip_range + timestep_range = self._pab_temporal_attn_timestep_range + if is_cross_attention: + skip_range = self._pab_cross_attn_skip_range + timestep_range = self._pab_cross_attn_timestep_range if skip_range is None: continue - module.set_processor( - PyramidAttentionBroadcastAttentionProcessorWrapper( - self, module.processor, skip_range, timestep_range - ) + # logger.debug(f"Enabling Pyramid Attention Broadcast in layer: {name}") + print(f"Enabling Pyramid Attention Broadcast in layer: {name}") + + add_hook_to_module( + module, + PyramidAttentionBroadcastHook( + skip_range=skip_range, + timestep_range=timestep_range, + timestep_callback=self._pyramid_attention_broadcast_timestep_callback, + ), + append=True, ) def _disable_pyramid_attention_broadcast(self) -> None: denoiser: nn.Module = self.transformer if hasattr(self, "transformer") else self.unet for name, module in denoiser.named_modules(): - if isinstance(module, Attention) and isinstance( - module.processor, PyramidAttentionBroadcastAttentionProcessorWrapper - ): - logger.debug(f"Disabling Pyramid Attention Broadcast in layer: {name}") - module.processor = module.processor._original_processor + logger.debug(f"Disabling Pyramid Attention Broadcast in layer: {name}") + remove_hook_from_module(module) + + def _pyramid_attention_broadcast_timestep_callback(self): + return self._current_timestep def enable_pyramid_attention_broadcast( self, spatial_attn_skip_range: Optional[int] = None, + spatial_attn_timestep_range: Tuple[int, int] = (100, 800), + temporal_attn_skip_range: Optional[int] = None, cross_attn_skip_range: Optional[int] = None, - spatial_attn_timestep_range: Optional[Tuple[int, int]] = None, - cross_attn_timestep_range: Optional[Tuple[int, int]] = None, + temporal_attn_timestep_range: Tuple[int, int] = (100, 800), + cross_attn_timestep_range: Tuple[int, int] = (100, 800), + spatial_attn_layer_identifiers: List[str] = ["blocks", "transformer_blocks"], + temporal_attn_layer_identifiers: List[str] = ["temporal_transformer_blocks"], + cross_attn_layer_identifiers: List[str] = ["blocks", "transformer_blocks"], ) -> None: r""" Enable pyramid attention broadcast to speedup inference by re-using attention states and skipping computation @@ -166,41 +106,53 @@ def enable_pyramid_attention_broadcast( different tolerances to how much attention computation can be reused based on the differences between successive blocks. So, this parameter must be adjusted per model after performing experimentation. Setting this value to `2` is recommended for different models PAB has been experimented with. + temporal_attn_skip_range (`int`, *optional*): + The attention block to execute after skipping intermediate temporal attention blocks. If set to the + value `N`, `N - 1` attention blocks are skipped and every N'th block is executed. Different models have + different tolerances to how much attention computation can be reused based on the differences between + successive blocks. So, this parameter must be adjusted per model after performing experimentation. + Setting this value to `4` is recommended for different models PAB has been experimented with. cross_attn_skip_range (`int`, *optional*): The attention block to execute after skipping intermediate cross attention blocks. If set to the value `N`, `N - 1` attention blocks are skipped and every N'th block is executed. Different models have different tolerances to how much attention computation can be reused based on the differences between successive blocks. So, this parameter must be adjusted per model after performing experimentation. Setting this value to `6` is recommended for different models PAB has been experimented with. - spatial_attn_timestep_range (`Tuple[int, int]`, *optional*): + spatial_attn_timestep_range (`Tuple[int, int]`, defaults to `(100, 800)`): The timestep range between which PAB will remain activated in spatial attention blocks. While - activated, PAB will re-use attention computations between inference steps. Setting this to `(100, 850)` - is recommended for different models PAB has been experimented with. - cross_attn_timestep_range (`Tuple[int, int]`, *optional*): + activated, PAB will re-use attention computations between inference steps. + temporal_attn_timestep_range (`Tuple[int, int]`, defaults to `(100, 800)`): + The timestep range between which PAB will remain activated in temporal attention blocks. While + activated, PAB will re-use attention computations between inference steps. + cross_attn_timestep_range (`Tuple[int, int]`, defaults to `(100, 800)`): The timestep range between which PAB will remain activated in cross attention blocks. While activated, - PAB will re-use attention computations between inference steps. Setting this to `(100, 800)` is - recommended for different models PAB has been experimented with. + PAB will re-use attention computations between inference steps. """ - if spatial_attn_timestep_range is None: - spatial_attn_timestep_range = (100, 800) - if cross_attn_skip_range is None: - cross_attn_timestep_range = (100, 800) - if spatial_attn_timestep_range[0] > spatial_attn_timestep_range[1]: raise ValueError( "Expected `spatial_attn_timestep_range` to be a tuple of two integers, with first value lesser or equal than second. These correspond to the min and max timestep between which PAB will be applied." ) + if temporal_attn_timestep_range[0] > temporal_attn_timestep_range[1]: + raise ValueError( + "Expected `temporal_attn_timestep_range` to be a tuple of two integers, with first value lesser or equal than second. These correspond to the min and max timestep between which PAB will be applied." + ) if cross_attn_timestep_range[0] > cross_attn_timestep_range[1]: raise ValueError( "Expected `cross_attn_timestep_range` to be a tuple of two integers, with first value lesser or equal than second. These correspond to the min and max timestep between which PAB will be applied." ) self._pab_spatial_attn_skip_range = spatial_attn_skip_range + self._pab_temporal_attn_skip_range = temporal_attn_skip_range self._pab_cross_attn_skip_range = cross_attn_skip_range self._pab_spatial_attn_timestep_range = spatial_attn_timestep_range + self._pab_temporal_attn_timestep_range = temporal_attn_timestep_range self._pab_cross_attn_timestep_range = cross_attn_timestep_range - self._pab_enabled = spatial_attn_skip_range or cross_attn_skip_range + self._pab_spatial_attn_layer_identifiers = spatial_attn_layer_identifiers + self._pab_temporal_attn_layer_identifiers = temporal_attn_layer_identifiers + self._pab_cross_attn_layer_identifiers = cross_attn_layer_identifiers + + self._pab_enabled = spatial_attn_skip_range or temporal_attn_skip_range or cross_attn_skip_range self._enable_pyramid_attention_broadcast() @@ -208,9 +160,14 @@ def disable_pyramid_attention_broadcast(self) -> None: r"""Disables the pyramid attention broadcast sampling mechanism.""" self._pab_spatial_attn_skip_range = None + self._pab_temporal_attn_skip_range = None self._pab_cross_attn_skip_range = None self._pab_spatial_attn_timestep_range = None + self._pab_temporal_attn_timestep_range = None self._pab_cross_attn_timestep_range = None + self._pab_spatial_attn_layer_identifiers = None + self._pab_temporal_attn_layer_identifiers = None + self._pab_cross_attn_layer_identifiers = None self._pab_enabled = False @property From 6090575287ab4c7bdfddd5667712fdfc0f537569 Mon Sep 17 00:00:00 2001 From: Aryan Date: Mon, 9 Dec 2024 13:47:08 +0100 Subject: [PATCH 13/61] merge pyramid-attention-rewrite-2 --- src/diffusers/models/hooks.py | 63 +++-- src/diffusers/pipelines/__init__.py | 10 + .../pipelines/allegro/pipeline_allegro.py | 3 +- .../pipelines/cogvideo/pipeline_cogvideox.py | 3 +- .../pipeline_cogvideox_fun_control.py | 3 +- .../pipeline_cogvideox_image2video.py | 3 +- .../pipeline_cogvideox_video2video.py | 3 +- src/diffusers/pipelines/flux/pipeline_flux.py | 5 +- .../pipelines/latte/pipeline_latte.py | 3 +- .../pyramid_attention_broadcast_utils.py | 248 ++++++++++++++++++ .../pipelines/pyramid_broadcast_utils.py | 175 ------------ tests/pipelines/cogvideo/test_cogvideox.py | 2 +- .../cogvideo/test_cogvideox_image2video.py | 2 +- .../cogvideo/test_cogvideox_video2video.py | 2 +- 14 files changed, 316 insertions(+), 209 deletions(-) create mode 100644 src/diffusers/pipelines/pyramid_attention_broadcast_utils.py delete mode 100644 src/diffusers/pipelines/pyramid_broadcast_utils.py diff --git a/src/diffusers/models/hooks.py b/src/diffusers/models/hooks.py index 2b4351d4a94e..3f5eb97ede85 100644 --- a/src/diffusers/models/hooks.py +++ b/src/diffusers/models/hooks.py @@ -13,7 +13,7 @@ # limitations under the License. import functools -from typing import Any, Callable, Dict, Tuple, Union +from typing import Any, Callable, Dict, Tuple import torch @@ -117,45 +117,72 @@ def reset_state(self, module): class PyramidAttentionBroadcastHook(ModelHook): def __init__( self, - skip_range: int, - timestep_range: Tuple[int, int], - timestep_callback: Callable[[], Union[torch.LongTensor, int]], + skip_callback: Callable[[torch.nn.Module], bool], + # skip_range: int, + # timestep_range: Tuple[int, int], + # timestep_callback: Callable[[], Union[torch.LongTensor, int]], ) -> None: super().__init__() - self.skip_range = skip_range - self.timestep_range = timestep_range - self.timestep_callback = timestep_callback + # self.skip_range = skip_range + # self.timestep_range = timestep_range + # self.timestep_callback = timestep_callback + self.skip_callback = skip_callback - self.attention_cache = None + self.cache = None self._iteration = 0 def new_forward(self, module: torch.nn.Module, *args, **kwargs) -> Any: args, kwargs = module._diffusers_hook.pre_forward(module, *args, **kwargs) - current_timestep = self.timestep_callback() - is_within_timestep_range = self.timestep_range[0] < current_timestep < self.timestep_range[1] - should_compute_attention = self._iteration % self.skip_range == 0 + # current_timestep = self.timestep_callback() + # is_within_timestep_range = self.timestep_range[0] < current_timestep < self.timestep_range[1] + # should_compute_attention = self._iteration % self.skip_range == 0 - if not is_within_timestep_range or should_compute_attention: - output = module._old_forward(*args, **kwargs) - else: - output = self.attention_cache + # if not is_within_timestep_range or should_compute_attention: + # output = module._old_forward(*args, **kwargs) + # else: + # output = self.attention_cache - self._iteration = self._iteration + 1 + if self.cache is not None and self.skip_callback(module): + output = self.cache + else: + output = module._old_forward(*args, **kwargs) return module._diffusers_hook.post_forward(module, output) def post_forward(self, module: torch.nn.Module, output: Any) -> Any: - self.attention_cache = output + self.cache = output return output def reset_state(self, module: torch.nn.Module) -> torch.nn.Module: - self.attention_cache = None + self.cache = None self._iteration = 0 return module +class LayerSkipHook(ModelHook): + def __init__(self, skip_: Callable[[torch.nn.Module], bool]) -> None: + super().__init__() + + self.skip_callback = skip_ + + def new_forward(self, module: torch.nn.Module, *args, **kwargs) -> Any: + args, kwargs = module._diffusers_hook.pre_forward(module, *args, **kwargs) + + if self.skip_callback(module): + # We want to skip this layer, so we have to return the input of the current layer + # as output of the next layer. But at this point, we don't have information about + # the arguments required by next layer. Even if we did, order matters unless we + # always pass kwargs. But that is not the case usually with hidden_states, encoder_hidden_states, + # temb, etc. TODO(aryan): implement correctly later + output = None + else: + output = module._old_forward(*args, **kwargs) + + return module._diffusers_hook.post_forward(module, output) + + def add_hook_to_module(module: torch.nn.Module, hook: ModelHook, append: bool = False): r""" Adds a hook to a given module. This will rewrite the `forward` method of the module to include the hook, to remove diff --git a/src/diffusers/pipelines/__init__.py b/src/diffusers/pipelines/__init__.py index 6d3a20511696..42f9f7294ec7 100644 --- a/src/diffusers/pipelines/__init__.py +++ b/src/diffusers/pipelines/__init__.py @@ -58,6 +58,11 @@ "StableDiffusionMixin", "ImagePipelineOutput", ] + _import_structure["pyramid_attention_broadcast_utils"] = [ + "PyramidAttentionBroadcastConfig", + "apply_pyramid_attention_broadcast", + "apply_pyramid_attention_broadcast_on_module", + ] _import_structure["deprecated"].extend( [ "PNDMPipeline", @@ -447,6 +452,11 @@ ImagePipelineOutput, StableDiffusionMixin, ) + from .pyramid_attention_broadcast_utils import ( + PyramidAttentionBroadcastConfig, + apply_pyramid_attention_broadcast, + apply_pyramid_attention_broadcast_on_module, + ) try: if not (is_torch_available() and is_librosa_available()): diff --git a/src/diffusers/pipelines/allegro/pipeline_allegro.py b/src/diffusers/pipelines/allegro/pipeline_allegro.py index 1ca700e51a70..1941c6cc3529 100644 --- a/src/diffusers/pipelines/allegro/pipeline_allegro.py +++ b/src/diffusers/pipelines/allegro/pipeline_allegro.py @@ -38,7 +38,6 @@ ) from ...utils.torch_utils import randn_tensor from ...video_processor import VideoProcessor -from ..pyramid_broadcast_utils import PyramidAttentionBroadcastMixin from .pipeline_output import AllegroPipelineOutput @@ -132,7 +131,7 @@ def retrieve_timesteps( return timesteps, num_inference_steps -class AllegroPipeline(DiffusionPipeline, PyramidAttentionBroadcastMixin): +class AllegroPipeline(DiffusionPipeline): r""" Pipeline for text-to-video generation using Allegro. diff --git a/src/diffusers/pipelines/cogvideo/pipeline_cogvideox.py b/src/diffusers/pipelines/cogvideo/pipeline_cogvideox.py index 1d8cc48934fa..c6b392393c1d 100644 --- a/src/diffusers/pipelines/cogvideo/pipeline_cogvideox.py +++ b/src/diffusers/pipelines/cogvideo/pipeline_cogvideox.py @@ -29,7 +29,6 @@ from ...utils import logging, replace_example_docstring from ...utils.torch_utils import randn_tensor from ...video_processor import VideoProcessor -from ..pyramid_broadcast_utils import PyramidAttentionBroadcastMixin from .pipeline_output import CogVideoXPipelineOutput @@ -138,7 +137,7 @@ def retrieve_timesteps( return timesteps, num_inference_steps -class CogVideoXPipeline(DiffusionPipeline, CogVideoXLoraLoaderMixin, PyramidAttentionBroadcastMixin): +class CogVideoXPipeline(DiffusionPipeline, CogVideoXLoraLoaderMixin): r""" Pipeline for text-to-video generation using CogVideoX. diff --git a/src/diffusers/pipelines/cogvideo/pipeline_cogvideox_fun_control.py b/src/diffusers/pipelines/cogvideo/pipeline_cogvideox_fun_control.py index 191909e5dbe3..36243784d529 100644 --- a/src/diffusers/pipelines/cogvideo/pipeline_cogvideox_fun_control.py +++ b/src/diffusers/pipelines/cogvideo/pipeline_cogvideox_fun_control.py @@ -30,7 +30,6 @@ from ...utils import logging, replace_example_docstring from ...utils.torch_utils import randn_tensor from ...video_processor import VideoProcessor -from ..pyramid_broadcast_utils import PyramidAttentionBroadcastMixin from .pipeline_output import CogVideoXPipelineOutput @@ -145,7 +144,7 @@ def retrieve_timesteps( return timesteps, num_inference_steps -class CogVideoXFunControlPipeline(DiffusionPipeline, CogVideoXLoraLoaderMixin, PyramidAttentionBroadcastMixin): +class CogVideoXFunControlPipeline(DiffusionPipeline, CogVideoXLoraLoaderMixin): r""" Pipeline for controlled text-to-video generation using CogVideoX Fun. diff --git a/src/diffusers/pipelines/cogvideo/pipeline_cogvideox_image2video.py b/src/diffusers/pipelines/cogvideo/pipeline_cogvideox_image2video.py index febc67892ccd..0e74050926a1 100644 --- a/src/diffusers/pipelines/cogvideo/pipeline_cogvideox_image2video.py +++ b/src/diffusers/pipelines/cogvideo/pipeline_cogvideox_image2video.py @@ -34,7 +34,6 @@ ) from ...utils.torch_utils import randn_tensor from ...video_processor import VideoProcessor -from ..pyramid_broadcast_utils import PyramidAttentionBroadcastMixin from .pipeline_output import CogVideoXPipelineOutput @@ -154,7 +153,7 @@ def retrieve_latents( raise AttributeError("Could not access latents of provided encoder_output") -class CogVideoXImageToVideoPipeline(DiffusionPipeline, CogVideoXLoraLoaderMixin, PyramidAttentionBroadcastMixin): +class CogVideoXImageToVideoPipeline(DiffusionPipeline, CogVideoXLoraLoaderMixin): r""" Pipeline for image-to-video generation using CogVideoX. diff --git a/src/diffusers/pipelines/cogvideo/pipeline_cogvideox_video2video.py b/src/diffusers/pipelines/cogvideo/pipeline_cogvideox_video2video.py index 342607d291cd..2ece7c223681 100644 --- a/src/diffusers/pipelines/cogvideo/pipeline_cogvideox_video2video.py +++ b/src/diffusers/pipelines/cogvideo/pipeline_cogvideox_video2video.py @@ -30,7 +30,6 @@ from ...utils import logging, replace_example_docstring from ...utils.torch_utils import randn_tensor from ...video_processor import VideoProcessor -from ..pyramid_broadcast_utils import PyramidAttentionBroadcastMixin from .pipeline_output import CogVideoXPipelineOutput @@ -160,7 +159,7 @@ def retrieve_latents( raise AttributeError("Could not access latents of provided encoder_output") -class CogVideoXVideoToVideoPipeline(DiffusionPipeline, CogVideoXLoraLoaderMixin, PyramidAttentionBroadcastMixin): +class CogVideoXVideoToVideoPipeline(DiffusionPipeline, CogVideoXLoraLoaderMixin): r""" Pipeline for video-to-video generation using CogVideoX. diff --git a/src/diffusers/pipelines/flux/pipeline_flux.py b/src/diffusers/pipelines/flux/pipeline_flux.py index ec2801625552..2e1fa5eb2cd7 100644 --- a/src/diffusers/pipelines/flux/pipeline_flux.py +++ b/src/diffusers/pipelines/flux/pipeline_flux.py @@ -655,6 +655,7 @@ def __call__( self._guidance_scale = guidance_scale self._joint_attention_kwargs = joint_attention_kwargs + self._current_timestep = None self._interrupt = False # 2. Define call parameters @@ -731,6 +732,7 @@ def __call__( if self.interrupt: continue + self._current_timestep = t # broadcast to batch dimension in a way that's compatible with ONNX/Core ML timestep = t.expand(latents.shape[0]).to(latents.dtype) @@ -771,9 +773,10 @@ def __call__( if XLA_AVAILABLE: xm.mark_step() + self._current_timestep = None + if output_type == "latent": image = latents - else: latents = self._unpack_latents(latents, height, width, self.vae_scale_factor) latents = (latents / self.vae.config.scaling_factor) + self.vae.config.shift_factor diff --git a/src/diffusers/pipelines/latte/pipeline_latte.py b/src/diffusers/pipelines/latte/pipeline_latte.py index e9d1a69c7c64..c030e368048d 100644 --- a/src/diffusers/pipelines/latte/pipeline_latte.py +++ b/src/diffusers/pipelines/latte/pipeline_latte.py @@ -37,7 +37,6 @@ ) from ...utils.torch_utils import is_compiled_module, randn_tensor from ...video_processor import VideoProcessor -from ..pyramid_broadcast_utils import PyramidAttentionBroadcastMixin logger = logging.get_logger(__name__) # pylint: disable=invalid-name @@ -133,7 +132,7 @@ class LattePipelineOutput(BaseOutput): frames: torch.Tensor -class LattePipeline(DiffusionPipeline, PyramidAttentionBroadcastMixin): +class LattePipeline(DiffusionPipeline): r""" Pipeline for text-to-video generation using Latte. diff --git a/src/diffusers/pipelines/pyramid_attention_broadcast_utils.py b/src/diffusers/pipelines/pyramid_attention_broadcast_utils.py new file mode 100644 index 000000000000..d898939534f9 --- /dev/null +++ b/src/diffusers/pipelines/pyramid_attention_broadcast_utils.py @@ -0,0 +1,248 @@ +# Copyright 2024 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from dataclasses import dataclass +from typing import Callable, Optional, Protocol, Tuple + +import torch.nn as nn + +from ..models.attention_processor import Attention +from ..models.hooks import PyramidAttentionBroadcastHook, add_hook_to_module +from ..utils import logging +from .pipeline_utils import DiffusionPipeline + + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + + +_ATTENTION_CLASSES = (Attention,) + +_SPATIAL_ATTENTION_BLOCK_IDENTIFIERS = ("blocks", "transformer_blocks") +_TEMPORAL_ATTENTION_BLOCK_IDENTIFIERS = "temporal_transformer_blocks" +_CROSS_ATTENTION_BLOCK_IDENTIFIERS = ("blocks", "transformer_blocks") + + +@dataclass +class PyramidAttentionBroadcastConfig: + r""" + Configuration for Pyramid Attention Broadcast. + + Args: + spatial_attention_block_skip_range (`int`, *optional*, defaults to `None`): + The number of blocks to skip in the spatial attention layer. If `None`, the spatial attention layer + computations will not be skipped. + temporal_attention_block_skip_range (`int`, *optional*, defaults to `None`): + The number of blocks to skip in the temporal attention layer. If `None`, the temporal attention layer + computations will not be skipped. + cross_attention_block_skip_range (`int`, *optional*, defaults to `None`): + The number of blocks to skip in the cross-attention layer. If `None`, the cross-attention layer computations + will not be skipped. + spatial_attention_timestep_skip_range (`Tuple[int, int]`, defaults to `(100, 800)`): + The range of timesteps to skip in the spatial attention layer. The attention computations will be skipped + if the current timestep is within the specified range. + temporal_attention_timestep_skip_range (`Tuple[int, int]`, defaults to `(100, 800)`): + The range of timesteps to skip in the temporal attention layer. The attention computations will be skipped + if the current timestep is within the specified range. + cross_attention_timestep_skip_range (`Tuple[int, int]`, defaults to `(100, 800)`): + The range of timesteps to skip in the cross-attention layer. The attention computations will be skipped if + the current timestep is within the specified range. + spatial_attention_block_identifiers (`Tuple[str, ...]`, defaults to `("blocks", "transformer_blocks")`): + The identifiers to match against the layer names to determine if the layer is a spatial attention layer. + temporal_attention_block_identifiers (`Tuple[str, ...]`, defaults to `("temporal_transformer_blocks",)`): + The identifiers to match against the layer names to determine if the layer is a temporal attention layer. + cross_attention_block_identifiers (`Tuple[str, ...]`, defaults to `("blocks", "transformer_blocks")`): + The identifiers to match against the layer names to determine if the layer is a cross-attention layer. + """ + spatial_attention_block_skip_range: Optional[int] = None + temporal_attention_block_skip_range: Optional[int] = None + cross_attention_block_skip_range: Optional[int] = None + + spatial_attention_timestep_skip_range: Tuple[int, int] = (100, 800) + temporal_attention_timestep_skip_range: Tuple[int, int] = (100, 800) + cross_attention_timestep_skip_range: Tuple[int, int] = (100, 800) + + spatial_attention_block_identifiers: Tuple[str, ...] = _SPATIAL_ATTENTION_BLOCK_IDENTIFIERS + temporal_attention_block_identifiers: Tuple[str, ...] = _TEMPORAL_ATTENTION_BLOCK_IDENTIFIERS + cross_attention_block_identifiers: Tuple[str, ...] = _CROSS_ATTENTION_BLOCK_IDENTIFIERS + + +class PyramidAttentionBroadcastState: + r""" + State for Pyramid Attention Broadcast. + + Attributes: + iteration (`int`): + The current iteration of the Pyramid Attention Broadcast. It is necessary to ensure that `reset_state` is + called before starting a new inference forward pass for PAB to work correctly. + """ + def __init__(self) -> None: + self.iteration = 0 + + def reset_state(self): + self.iteration = 0 + + +class nnModulePAB(Protocol): + r""" + Type hint for a torch.nn.Module that contains a `_pyramid_attention_broadcast_state` attribute. + + Attributes: + _pyramid_attention_broadcast_state (`PyramidAttentionBroadcastState`): + The state of Pyramid Attention Broadcast. + """ + _pyramid_attention_broadcast_state: PyramidAttentionBroadcastState + + +def apply_pyramid_attention_broadcast( + pipeline: DiffusionPipeline, + config: Optional[PyramidAttentionBroadcastConfig] = None, + denoiser: Optional[nn.Module] = None, +): + r""" + Apply [Pyramid Attention Broadcast](https://huggingface.co/papers/2408.12588) to a given pipeline. + + PAB is an attention approximation method that leverages the similarity in attention states between timesteps to + reduce the computational cost of attention computation. The key takeaway from the paper is that the attention + similarity in the cross-attention layers between timesteps is high, followed by less similarity in the temporal and + spatial layers. This allows for the skipping of attention computation in the cross-attention layers more frequently + than in the temporal and spatial layers. Applying PAB will, therefore, speedup the inference process. + + Args: + pipeline (`DiffusionPipeline`): + The diffusion pipeline to apply Pyramid Attention Broadcast to. + config (`Optional[PyramidAttentionBroadcastConfig]`, `optional`, defaults to `None`): + The configuration to use for Pyramid Attention Broadcast. + denoiser (`Optional[nn.Module]`, `optional`, defaults to `None`): + The denoiser module to apply Pyramid Attention Broadcast to. If `None`, the pipeline's transformer or unet + module will be used. + + Example: + + ```python + >>> from diffusers import CogVideoXPipeline, PyramidAttentionBroadcastConfig, apply_pyramid_attention_broadcast + + >>> pipe = CogVideoXPipeline.from_pretrained("THUDM/CogVideoX-5b", torch_dtype=torch.bfloat16) + >>> pipe.to("cuda") + + >>> config = PyramidAttentionBroadcastConfig( + ... spatial_attention_block_skip_range=2, spatial_attention_timestep_skip_range=(100, 800) + ... ) + >>> apply_pyramid_attention_broadcast(pipe, config) + ``` + """ + # We present Pyramid Attention Broadcast (PAB), a real-time, high quality and training-free approach for DiT-based video generation. Our method is founded on the observation that attention difference in the diffusion process exhibits a U-shaped pattern, indicating significant redundancy. We mitigate this by broadcasting attention outputs to subsequent steps in a pyramid style. It applies different broadcast strategies to each attention based on their variance for best efficiency. We further introduce broadcast sequence parallel for more efficient distributed inference. PAB demonstrates superior results across three models compared to baselines, achieving real-time generation for up to 720p videos. We anticipate that our simple yet effective method will serve as a robust baseline and facilitate future research and application for video generation. + if config is None: + config = PyramidAttentionBroadcastConfig() + + if ( + config.spatial_attention_block_skip_range is None + and config.temporal_attention_block_skip_range is None + and config.cross_attention_block_skip_range is None + ): + logger.warning( + "Pyramid Attention Broadcast requires one or more of `spatial_attention_block_skip_range`, `temporal_attention_block_skip_range` " + "or `cross_attention_block_skip_range` parameters to be set to an integer, not `None`. Defaulting to using `spatial_attention_block_skip_range=2`. " + "To avoid this warning, please set one of the above parameters." + ) + config.spatial_attention_block_skip_range = 2 + + if denoiser is None: + denoiser = pipeline.transformer if hasattr(pipeline, "transformer") else pipeline.unet + + for name, module in denoiser.named_modules(): + if not isinstance(module, _ATTENTION_CLASSES): + continue + if isinstance(module, Attention): + _apply_pyramid_attention_broadcast_on_attention_class(pipeline, name, module, config) + + +def apply_pyramid_attention_broadcast_on_module( + module: Attention, + skip_callback: Callable[[nn.Module], bool], +): + r""" + Apply [Pyramid Attention Broadcast](https://huggingface.co/papers/2408.12588) to a given torch.nn.Module. + + Args: + module (`torch.nn.Module`): + The module to apply Pyramid Attention Broadcast to. + skip_callback (`Callable[[nn.Module], bool]`): + A callback function that determines whether the attention computation should be skipped or not. The + callback function should return a boolean value, where `True` indicates that the attention computation + should be skipped, and `False` indicates that the attention computation should not be skipped. The callback + function will receive a torch.nn.Module containing a `_pyramid_attention_broadcast_state` attribute that + can should be used to retrieve and update the state of PAB for the given module. + """ + module._pyramid_attention_broadcast_state = PyramidAttentionBroadcastState() + hook = PyramidAttentionBroadcastHook(skip_callback=skip_callback) + add_hook_to_module(module, hook, append=True) + + +def _apply_pyramid_attention_broadcast_on_attention_class( + pipeline: DiffusionPipeline, name: str, module: Attention, config: PyramidAttentionBroadcastConfig +): + # Similar check as PEFT to determine if a string layer name matches a module name + # TODO(aryan): make this regex based + is_spatial_self_attention = ( + any( + f"{identifier}." in name or identifier == name for identifier in config.spatial_attention_block_identifiers + ) + and config.spatial_attention_block_skip_range is not None + and not module.is_cross_attention + ) + is_temporal_self_attention = ( + any( + f"{identifier}." in name or identifier == name + for identifier in config.temporal_attention_block_identifiers + ) + and config.temporal_attention_block_skip_range is not None + and not module.is_cross_attention + ) + is_cross_attention = ( + any(f"{identifier}." in name or identifier == name for identifier in config.cross_attention_block_identifiers) + and config.cross_attention_block_skip_range is not None + and not module.is_cross_attention + ) + + block_skip_range, timestep_skip_range = None, None + if is_spatial_self_attention: + block_skip_range = config.spatial_attention_block_skip_range + timestep_skip_range = config.spatial_attention_timestep_skip_range + elif is_temporal_self_attention: + block_skip_range = config.temporal_attention_block_skip_range + timestep_skip_range = config.temporal_attention_timestep_skip_range + elif is_cross_attention: + block_skip_range = config.cross_attention_block_skip_range + timestep_skip_range = config.cross_attention_timestep_skip_range + + if block_skip_range is None or timestep_skip_range is None: + logger.warning(f"Unable to apply Pyramid Attention Broadcast to the selected layer: {name}.") + return + + def skip_callback(module: nnModulePAB) -> bool: + pab_state = module._pyramid_attention_broadcast_state + current_timestep = pipeline._current_timestep + is_within_timestep_range = timestep_skip_range[0] < current_timestep < timestep_skip_range[1] + + if is_within_timestep_range: + should_compute_attention = pab_state.iteration > 0 and pab_state.iteration % block_skip_range == 0 + pab_state.iteration += 1 + return not should_compute_attention + + # We are still not yet in the phase of inference where skipping attention is possible without minimal quality + # loss, as described in the paper. So, the attention computation cannot be skipped + return False + + logger.debug(f"Enabling Pyramid Attention Broadcast in layer: {name}") + apply_pyramid_attention_broadcast_on_module(module, skip_callback) diff --git a/src/diffusers/pipelines/pyramid_broadcast_utils.py b/src/diffusers/pipelines/pyramid_broadcast_utils.py deleted file mode 100644 index 7fdb6a7f5b93..000000000000 --- a/src/diffusers/pipelines/pyramid_broadcast_utils.py +++ /dev/null @@ -1,175 +0,0 @@ -# Copyright 2024 The HuggingFace Team. All rights reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -from typing import List, Optional, Tuple - -import torch.nn as nn - -from ..models.attention_processor import Attention -from ..models.hooks import PyramidAttentionBroadcastHook, add_hook_to_module, remove_hook_from_module -from ..utils import logging - - -logger = logging.get_logger(__name__) # pylint: disable=invalid-name - - -class PyramidAttentionBroadcastMixin: - r"""Mixin class for [Pyramid Attention Broadcast](https://www.arxiv.org/abs/2408.12588).""" - - def _enable_pyramid_attention_broadcast(self) -> None: - denoiser: nn.Module = self.transformer if hasattr(self, "transformer") else self.unet - - for name, module in denoiser.named_modules(): - if isinstance(module, Attention): - is_spatial_attention = ( - any(x in name for x in self._pab_spatial_attn_layer_identifiers) - and self._pab_spatial_attn_skip_range is not None - and not module.is_cross_attention - ) - is_temporal_attention = ( - any(x in name for x in self._pab_temporal_attn_layer_identifiers) - and self._pab_temporal_attn_skip_range is not None - and not module.is_cross_attention - ) - is_cross_attention = ( - any(x in name for x in self._pab_cross_attn_layer_identifiers) - and self._pab_cross_attn_skip_range is not None - and module.is_cross_attention - ) - - if is_spatial_attention: - skip_range = self._pab_spatial_attn_skip_range - timestep_range = self._pab_spatial_attn_timestep_range - if is_temporal_attention: - skip_range = self._pab_temporal_attn_skip_range - timestep_range = self._pab_temporal_attn_timestep_range - if is_cross_attention: - skip_range = self._pab_cross_attn_skip_range - timestep_range = self._pab_cross_attn_timestep_range - - if skip_range is None: - continue - - # logger.debug(f"Enabling Pyramid Attention Broadcast in layer: {name}") - print(f"Enabling Pyramid Attention Broadcast in layer: {name}") - - add_hook_to_module( - module, - PyramidAttentionBroadcastHook( - skip_range=skip_range, - timestep_range=timestep_range, - timestep_callback=self._pyramid_attention_broadcast_timestep_callback, - ), - append=True, - ) - - def _disable_pyramid_attention_broadcast(self) -> None: - denoiser: nn.Module = self.transformer if hasattr(self, "transformer") else self.unet - for name, module in denoiser.named_modules(): - logger.debug(f"Disabling Pyramid Attention Broadcast in layer: {name}") - remove_hook_from_module(module) - - def _pyramid_attention_broadcast_timestep_callback(self): - return self._current_timestep - - def enable_pyramid_attention_broadcast( - self, - spatial_attn_skip_range: Optional[int] = None, - spatial_attn_timestep_range: Tuple[int, int] = (100, 800), - temporal_attn_skip_range: Optional[int] = None, - cross_attn_skip_range: Optional[int] = None, - temporal_attn_timestep_range: Tuple[int, int] = (100, 800), - cross_attn_timestep_range: Tuple[int, int] = (100, 800), - spatial_attn_layer_identifiers: List[str] = ["blocks", "transformer_blocks"], - temporal_attn_layer_identifiers: List[str] = ["temporal_transformer_blocks"], - cross_attn_layer_identifiers: List[str] = ["blocks", "transformer_blocks"], - ) -> None: - r""" - Enable pyramid attention broadcast to speedup inference by re-using attention states and skipping computation - systematically as described in the paper: [Pyramid Attention Broadcast](https://www.arxiv.org/abs/2408.12588). - - Args: - spatial_attn_skip_range (`int`, *optional*): - The attention block to execute after skipping intermediate spatial attention blocks. If set to the - value `N`, `N - 1` attention blocks are skipped and every N'th block is executed. Different models have - different tolerances to how much attention computation can be reused based on the differences between - successive blocks. So, this parameter must be adjusted per model after performing experimentation. - Setting this value to `2` is recommended for different models PAB has been experimented with. - temporal_attn_skip_range (`int`, *optional*): - The attention block to execute after skipping intermediate temporal attention blocks. If set to the - value `N`, `N - 1` attention blocks are skipped and every N'th block is executed. Different models have - different tolerances to how much attention computation can be reused based on the differences between - successive blocks. So, this parameter must be adjusted per model after performing experimentation. - Setting this value to `4` is recommended for different models PAB has been experimented with. - cross_attn_skip_range (`int`, *optional*): - The attention block to execute after skipping intermediate cross attention blocks. If set to the value - `N`, `N - 1` attention blocks are skipped and every N'th block is executed. Different models have - different tolerances to how much attention computation can be reused based on the differences between - successive blocks. So, this parameter must be adjusted per model after performing experimentation. - Setting this value to `6` is recommended for different models PAB has been experimented with. - spatial_attn_timestep_range (`Tuple[int, int]`, defaults to `(100, 800)`): - The timestep range between which PAB will remain activated in spatial attention blocks. While - activated, PAB will re-use attention computations between inference steps. - temporal_attn_timestep_range (`Tuple[int, int]`, defaults to `(100, 800)`): - The timestep range between which PAB will remain activated in temporal attention blocks. While - activated, PAB will re-use attention computations between inference steps. - cross_attn_timestep_range (`Tuple[int, int]`, defaults to `(100, 800)`): - The timestep range between which PAB will remain activated in cross attention blocks. While activated, - PAB will re-use attention computations between inference steps. - """ - - if spatial_attn_timestep_range[0] > spatial_attn_timestep_range[1]: - raise ValueError( - "Expected `spatial_attn_timestep_range` to be a tuple of two integers, with first value lesser or equal than second. These correspond to the min and max timestep between which PAB will be applied." - ) - if temporal_attn_timestep_range[0] > temporal_attn_timestep_range[1]: - raise ValueError( - "Expected `temporal_attn_timestep_range` to be a tuple of two integers, with first value lesser or equal than second. These correspond to the min and max timestep between which PAB will be applied." - ) - if cross_attn_timestep_range[0] > cross_attn_timestep_range[1]: - raise ValueError( - "Expected `cross_attn_timestep_range` to be a tuple of two integers, with first value lesser or equal than second. These correspond to the min and max timestep between which PAB will be applied." - ) - - self._pab_spatial_attn_skip_range = spatial_attn_skip_range - self._pab_temporal_attn_skip_range = temporal_attn_skip_range - self._pab_cross_attn_skip_range = cross_attn_skip_range - self._pab_spatial_attn_timestep_range = spatial_attn_timestep_range - self._pab_temporal_attn_timestep_range = temporal_attn_timestep_range - self._pab_cross_attn_timestep_range = cross_attn_timestep_range - self._pab_spatial_attn_layer_identifiers = spatial_attn_layer_identifiers - self._pab_temporal_attn_layer_identifiers = temporal_attn_layer_identifiers - self._pab_cross_attn_layer_identifiers = cross_attn_layer_identifiers - - self._pab_enabled = spatial_attn_skip_range or temporal_attn_skip_range or cross_attn_skip_range - - self._enable_pyramid_attention_broadcast() - - def disable_pyramid_attention_broadcast(self) -> None: - r"""Disables the pyramid attention broadcast sampling mechanism.""" - - self._pab_spatial_attn_skip_range = None - self._pab_temporal_attn_skip_range = None - self._pab_cross_attn_skip_range = None - self._pab_spatial_attn_timestep_range = None - self._pab_temporal_attn_timestep_range = None - self._pab_cross_attn_timestep_range = None - self._pab_spatial_attn_layer_identifiers = None - self._pab_temporal_attn_layer_identifiers = None - self._pab_cross_attn_layer_identifiers = None - self._pab_enabled = False - - @property - def pyramid_attention_broadcast_enabled(self): - return hasattr(self, "_pab_enabled") and self._pab_enabled diff --git a/tests/pipelines/cogvideo/test_cogvideox.py b/tests/pipelines/cogvideo/test_cogvideox.py index bbafd3aa7ae2..775a351fc2ac 100644 --- a/tests/pipelines/cogvideo/test_cogvideox.py +++ b/tests/pipelines/cogvideo/test_cogvideox.py @@ -21,7 +21,7 @@ from transformers import AutoTokenizer, T5EncoderModel from diffusers import AutoencoderKLCogVideoX, CogVideoXPipeline, CogVideoXTransformer3DModel, DDIMScheduler -from diffusers.pipelines.pyramid_broadcast_utils import PyramidAttentionBroadcastAttentionProcessorWrapper +from diffusers.pipelines.pyramid_attention_broadcast_utils import PyramidAttentionBroadcastAttentionProcessorWrapper from diffusers.utils.testing_utils import ( enable_full_determinism, numpy_cosine_similarity_distance, diff --git a/tests/pipelines/cogvideo/test_cogvideox_image2video.py b/tests/pipelines/cogvideo/test_cogvideox_image2video.py index c83470cf6a2d..a5f0b3211924 100644 --- a/tests/pipelines/cogvideo/test_cogvideox_image2video.py +++ b/tests/pipelines/cogvideo/test_cogvideox_image2video.py @@ -22,7 +22,7 @@ from transformers import AutoTokenizer, T5EncoderModel from diffusers import AutoencoderKLCogVideoX, CogVideoXImageToVideoPipeline, CogVideoXTransformer3DModel, DDIMScheduler -from diffusers.pipelines.pyramid_broadcast_utils import PyramidAttentionBroadcastAttentionProcessorWrapper +from diffusers.pipelines.pyramid_attention_broadcast_utils import PyramidAttentionBroadcastAttentionProcessorWrapper from diffusers.utils import load_image from diffusers.utils.testing_utils import ( enable_full_determinism, diff --git a/tests/pipelines/cogvideo/test_cogvideox_video2video.py b/tests/pipelines/cogvideo/test_cogvideox_video2video.py index 9f05b04d77af..a717074e0161 100644 --- a/tests/pipelines/cogvideo/test_cogvideox_video2video.py +++ b/tests/pipelines/cogvideo/test_cogvideox_video2video.py @@ -21,7 +21,7 @@ from transformers import AutoTokenizer, T5EncoderModel from diffusers import AutoencoderKLCogVideoX, CogVideoXTransformer3DModel, CogVideoXVideoToVideoPipeline, DDIMScheduler -from diffusers.pipelines.pyramid_broadcast_utils import PyramidAttentionBroadcastAttentionProcessorWrapper +from diffusers.pipelines.pyramid_attention_broadcast_utils import PyramidAttentionBroadcastAttentionProcessorWrapper from diffusers.utils.testing_utils import enable_full_determinism, torch_device from ..pipeline_params import TEXT_TO_IMAGE_BATCH_PARAMS, TEXT_TO_IMAGE_IMAGE_PARAMS, TEXT_TO_IMAGE_PARAMS From 903514ffccc7f9a87f0ba61372711beb83233ec0 Mon Sep 17 00:00:00 2001 From: Aryan Date: Mon, 9 Dec 2024 13:52:58 +0100 Subject: [PATCH 14/61] make style --- .../pyramid_attention_broadcast_utils.py | 30 +++++++++++-------- 1 file changed, 18 insertions(+), 12 deletions(-) diff --git a/src/diffusers/pipelines/pyramid_attention_broadcast_utils.py b/src/diffusers/pipelines/pyramid_attention_broadcast_utils.py index d898939534f9..49aa74a6fc86 100644 --- a/src/diffusers/pipelines/pyramid_attention_broadcast_utils.py +++ b/src/diffusers/pipelines/pyramid_attention_broadcast_utils.py @@ -40,23 +40,26 @@ class PyramidAttentionBroadcastConfig: Args: spatial_attention_block_skip_range (`int`, *optional*, defaults to `None`): - The number of blocks to skip in the spatial attention layer. If `None`, the spatial attention layer - computations will not be skipped. + The number of times a specific spatial attention broadcast is skipped before computing the attention states + to re-use. If this is set to the value `N`, the attention computation will be skipped `N - 1` times (i.e., + old attention states will be re-used) before computing the new attention states again. temporal_attention_block_skip_range (`int`, *optional*, defaults to `None`): - The number of blocks to skip in the temporal attention layer. If `None`, the temporal attention layer - computations will not be skipped. + The number of times a specific temporal attention broadcast is skipped before computing the attention + states to re-use. If this is set to the value `N`, the attention computation will be skipped `N - 1` times + (i.e., old attention states will be re-used) before computing the new attention states again. cross_attention_block_skip_range (`int`, *optional*, defaults to `None`): - The number of blocks to skip in the cross-attention layer. If `None`, the cross-attention layer computations - will not be skipped. + The number of times a specific cross-attention broadcast is skipped before computing the attention states + to re-use. If this is set to the value `N`, the attention computation will be skipped `N - 1` times (i.e., + old attention states will be re-used) before computing the new attention states again. spatial_attention_timestep_skip_range (`Tuple[int, int]`, defaults to `(100, 800)`): - The range of timesteps to skip in the spatial attention layer. The attention computations will be skipped - if the current timestep is within the specified range. + The range of timesteps to skip in the spatial attention layer. The attention computations will be + conditionally skipped if the current timestep is within the specified range. temporal_attention_timestep_skip_range (`Tuple[int, int]`, defaults to `(100, 800)`): - The range of timesteps to skip in the temporal attention layer. The attention computations will be skipped - if the current timestep is within the specified range. + The range of timesteps to skip in the temporal attention layer. The attention computations will be + conditionally skipped if the current timestep is within the specified range. cross_attention_timestep_skip_range (`Tuple[int, int]`, defaults to `(100, 800)`): - The range of timesteps to skip in the cross-attention layer. The attention computations will be skipped if - the current timestep is within the specified range. + The range of timesteps to skip in the cross-attention layer. The attention computations will be + conditionally skipped if the current timestep is within the specified range. spatial_attention_block_identifiers (`Tuple[str, ...]`, defaults to `("blocks", "transformer_blocks")`): The identifiers to match against the layer names to determine if the layer is a spatial attention layer. temporal_attention_block_identifiers (`Tuple[str, ...]`, defaults to `("temporal_transformer_blocks",)`): @@ -64,6 +67,7 @@ class PyramidAttentionBroadcastConfig: cross_attention_block_identifiers (`Tuple[str, ...]`, defaults to `("blocks", "transformer_blocks")`): The identifiers to match against the layer names to determine if the layer is a cross-attention layer. """ + spatial_attention_block_skip_range: Optional[int] = None temporal_attention_block_skip_range: Optional[int] = None cross_attention_block_skip_range: Optional[int] = None @@ -86,6 +90,7 @@ class PyramidAttentionBroadcastState: The current iteration of the Pyramid Attention Broadcast. It is necessary to ensure that `reset_state` is called before starting a new inference forward pass for PAB to work correctly. """ + def __init__(self) -> None: self.iteration = 0 @@ -101,6 +106,7 @@ class nnModulePAB(Protocol): _pyramid_attention_broadcast_state (`PyramidAttentionBroadcastState`): The state of Pyramid Attention Broadcast. """ + _pyramid_attention_broadcast_state: PyramidAttentionBroadcastState From b690db2951ab39a390e58202c6fb2931c2d2f846 Mon Sep 17 00:00:00 2001 From: Aryan Date: Mon, 9 Dec 2024 13:54:05 +0100 Subject: [PATCH 15/61] remove changes from latte transformer --- .../transformers/latte_transformer_3d.py | 60 ------------------- 1 file changed, 60 deletions(-) diff --git a/src/diffusers/models/transformers/latte_transformer_3d.py b/src/diffusers/models/transformers/latte_transformer_3d.py index e567760d8aa9..bc9a5bfcfd12 100644 --- a/src/diffusers/models/transformers/latte_transformer_3d.py +++ b/src/diffusers/models/transformers/latte_transformer_3d.py @@ -167,66 +167,6 @@ def __init__( def _set_gradient_checkpointing(self, module, value=False): self.gradient_checkpointing = value - @property - # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.attn_processors - def attn_processors(self) -> Dict[str, AttentionProcessor]: - r""" - Returns: - `dict` of attention processors: A dictionary containing all attention processors used in the model with - indexed by its weight name. - """ - # set recursively - processors = {} - - def fn_recursive_add_processors(name: str, module: torch.nn.Module, processors: Dict[str, AttentionProcessor]): - if hasattr(module, "get_processor"): - processors[f"{name}.processor"] = module.get_processor() - - for sub_name, child in module.named_children(): - fn_recursive_add_processors(f"{name}.{sub_name}", child, processors) - - return processors - - for name, module in self.named_children(): - fn_recursive_add_processors(name, module, processors) - - return processors - - # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.set_attn_processor - def set_attn_processor(self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]]): - r""" - Sets the attention processor to use to compute attention. - - Parameters: - processor (`dict` of `AttentionProcessor` or only `AttentionProcessor`): - The instantiated processor class or a dictionary of processor classes that will be set as the processor - for **all** `Attention` layers. - - If `processor` is a dict, the key needs to define the path to the corresponding cross attention - processor. This is strongly recommended when setting trainable attention processors. - - """ - count = len(self.attn_processors.keys()) - - if isinstance(processor, dict) and len(processor) != count: - raise ValueError( - f"A dict of processors was passed, but the number of processors {len(processor)} does not match the" - f" number of attention layers: {count}. Please make sure to pass {count} processor classes." - ) - - def fn_recursive_attn_processor(name: str, module: torch.nn.Module, processor): - if hasattr(module, "set_processor"): - if not isinstance(processor, dict): - module.set_processor(processor) - else: - module.set_processor(processor.pop(f"{name}.processor")) - - for sub_name, child in module.named_children(): - fn_recursive_attn_processor(f"{name}.{sub_name}", child, processor) - - for name, module in self.named_children(): - fn_recursive_attn_processor(name, module, processor) - def forward( self, hidden_states: torch.Tensor, From 63ab886a43950f1b825373cc5ee31673b4837956 Mon Sep 17 00:00:00 2001 From: Aryan Date: Mon, 9 Dec 2024 13:55:18 +0100 Subject: [PATCH 16/61] revert docs changes --- docs/source/en/api/pipelines/cogvideox.md | 41 +---------------------- docs/source/en/api/pipelines/latte.md | 33 +----------------- 2 files changed, 2 insertions(+), 72 deletions(-) diff --git a/docs/source/en/api/pipelines/cogvideox.md b/docs/source/en/api/pipelines/cogvideox.md index 1a341a9f4c01..c29d60fcc72b 100644 --- a/docs/source/en/api/pipelines/cogvideox.md +++ b/docs/source/en/api/pipelines/cogvideox.md @@ -15,7 +15,7 @@ # CogVideoX -[CogVideoX: Text-to-Video Diffusion Models with An Expert Transformer](https://huggingface.co/papers/2408.06072) from Tsinghua University & ZhipuAI, by Zhuoyi Yang, Jiayan Teng, Wendi Zheng, Ming Ding, Shiyu Huang, Jiazheng Xu, Yuanming Yang, Wenyi Hong, Xiaohan Zhang, Guanyu Feng, Da Yin, Xiaotao Gu, Yuxuan Zhang, Weihan Wang, Yean Cheng, Ting Liu, Bin Xu, Yuxiao Dong, Jie Tang. +[CogVideoX: Text-to-Video Diffusion Models with An Expert Transformer](https://arxiv.org/abs/2408.06072) from Tsinghua University & ZhipuAI, by Zhuoyi Yang, Jiayan Teng, Wendi Zheng, Ming Ding, Shiyu Huang, Jiazheng Xu, Yuanming Yang, Wenyi Hong, Xiaohan Zhang, Guanyu Feng, Da Yin, Xiaotao Gu, Yuxuan Zhang, Weihan Wang, Yean Cheng, Ting Liu, Bin Xu, Yuxiao Dong, Jie Tang. The abstract from the paper is: @@ -120,45 +120,6 @@ It is also worth noting that torchao quantization is fully compatible with [torc - [torchao](https://gist.github.com/a-r-r-o-w/4d9732d17412888c885480c6521a9897) - [quanto](https://gist.github.com/a-r-r-o-w/31be62828b00a9292821b85c1017effa) -### Pyramid Attention Broadcast - -[Pyramid Attention Broadcast](https://huggingface.co/papers/2408.12588) from Xuanlei Zhao, Xiaolong Jin, Kai Wang, Yang You. - -Pyramid Attention Broadcast (PAB) is a method that speeds up inference in diffusion models by systematically skipping attention computations between successive inference steps and reusing cached attention states. The attention states aren't that different between successive steps. The most prominent difference is in the spatial attention blocks, not as much in the temporal attention blocks, and finally the least in the cross attention blocks. Therefore, many cross attention computation blocks can be skipped, followed by the temporal and spatial attention blocks. By combining other techniques like sequence parallelism and classifier-free guidance parallelism, PAB achieves near real-time video generation. - -Enable PAB with [`~PyramidAttentionBroadcastMixin.enable_pyramind_attention_broadcast`] on any pipeline and keep track of the current inference timestep in the pipeline. - -```python -import torch -from diffusers import CogVideoXPipeline -from diffusers.utils import export_to_video - -pipe = CogVideoXPipeline.from_pretrained("THUDM/CogVideoX-5b", torch_dtype=torch.float16) -pipe.to("cuda") - -pipe.enable_pyramid_attention_broadcast( - spatial_attn_skip_range=2, - spatial_attn_timestep_range=[100, 850], -) - -prompt = ( - "A panda, dressed in a small, red jacket and a tiny hat, sits on a wooden stool in a serene bamboo forest. " - "The panda's fluffy paws strum a miniature acoustic guitar, producing soft, melodic tunes. Nearby, a few other " - "pandas gather, watching curiously and some clapping in rhythm. Sunlight filters through the tall bamboo, " - "casting a gentle glow on the scene. The panda's face is expressive, showing concentration and joy as it plays. " - "The background includes a small, flowing stream and vibrant green foliage, enhancing the peaceful and magical " - "atmosphere of this unique musical performance." -) -video = pipe(prompt=prompt, guidance_scale=6, num_inference_steps=50).frames[0] -export_to_video(video, "output.mp4", fps=8) -``` - -| model | model_memory | normal_memory | pab_memory | normal_time | pab_time | speedup | -|:----------:|:--------------:|:---------------:|:------------:|:-------------:|:----------:|:---------:| -| Cog-2b T2V | 12.55 | 35.342 | 35.342 | 86.915 | 63.914 | 1.359 | -| Cog-5b T2V | 19.66 | 40.945 | 40.945 | 246.152 | 168.59 | 1.460 | -| Cog-5b I2V | 19.764 | 42.74 | 42.74 | 246.867 | 170.111 | 1.451 | - ## CogVideoXPipeline [[autodoc]] CogVideoXPipeline diff --git a/docs/source/en/api/pipelines/latte.md b/docs/source/en/api/pipelines/latte.md index bc62aadbc5ee..c2154d5d47c1 100644 --- a/docs/source/en/api/pipelines/latte.md +++ b/docs/source/en/api/pipelines/latte.md @@ -16,7 +16,7 @@ ![latte text-to-video](https://github.com/Vchitect/Latte/blob/52bc0029899babbd6e9250384c83d8ed2670ff7a/visuals/latte.gif?raw=true) -[Latte: Latent Diffusion Transformer for Video Generation](https://huggingface.co/papers/2401.03048) from Monash University, Shanghai AI Lab, Nanjing University, and Nanyang Technological University. +[Latte: Latent Diffusion Transformer for Video Generation](https://arxiv.org/abs/2401.03048) from Monash University, Shanghai AI Lab, Nanjing University, and Nanyang Technological University. The abstract from the paper is: @@ -70,37 +70,6 @@ Without torch.compile(): Average inference time: 16.246 seconds. With torch.compile(): Average inference time: 14.573 seconds. ``` -### Pyramid Attention Broadcast - -[Pyramid Attention Broadcast](https://huggingface.co/papers/2408.12588) from Xuanlei Zhao, Xiaolong Jin, Kai Wang, Yang You. - -Pyramid Attention Broadcast (PAB) is a method that speeds up inference in diffusion models by systematically skipping attention computations between successive inference steps, and re-using cached attention states. This is due to the realization that the attention states do not differ too much numerically between successive steps. This difference is most significant/prominent in the spatial attention blocks, lesser so in temporal attention blocks, and least in cross attention blocks. Therefore, many cross attention computation blocks can be skipped, followed by temporal and spatial attention blocks. By combining other techniques like Sequence Parallelism and CFG Parallelism, the authors achieve near real-time video generation. - -PAB can be enabled easily on any pipeline by deriving from the [`PyramidAttentionBroadcastMixin`] and keeping track of current inference timestep in the pipeline. Minimal example to demonstrate how to use PAB with Latte: - -```python -import torch -from diffusers import LattePipeline -from diffusers.utils import export_to_gif - -pipe = LattePipeline.from_pretrained("maxin-cn/Latte-1", torch_dtype=torch.float16) - -pipe.enable_pyramid_attention_broadcast( - spatial_attn_skip_range=2, - cross_attn_skip_range=6, - spatial_attn_timestep_range=[100, 800], - cross_attn_timestep_range=[100, 800], -) - -prompt = "A small cactus with a happy face in the Sahara desert." -videos = pipe(prompt).frames[0] -export_to_gif(videos, "latte.gif") -``` - -| model | model_memory | normal_memory | pab_memory | normal_time | pab_time | speedup | -|:----------:|:--------------:|:---------------:|:------------:|:-------------:|:----------:|:---------:| -| Latte | 11.007 | 25.594 | 25.594 | 28.026 | 24.073 | 1.164 | - ## LattePipeline [[autodoc]] LattePipeline From d40bcedad7430c9b06e58edcc6a5e5aa51f6cab3 Mon Sep 17 00:00:00 2001 From: Aryan Date: Mon, 9 Dec 2024 13:59:54 +0100 Subject: [PATCH 17/61] better debug message --- .../pipelines/pyramid_attention_broadcast_utils.py | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/src/diffusers/pipelines/pyramid_attention_broadcast_utils.py b/src/diffusers/pipelines/pyramid_attention_broadcast_utils.py index 49aa74a6fc86..8349e0b8a10d 100644 --- a/src/diffusers/pipelines/pyramid_attention_broadcast_utils.py +++ b/src/diffusers/pipelines/pyramid_attention_broadcast_utils.py @@ -221,16 +221,19 @@ def _apply_pyramid_attention_broadcast_on_attention_class( and not module.is_cross_attention ) - block_skip_range, timestep_skip_range = None, None + block_skip_range, timestep_skip_range, block_type = None, None, None if is_spatial_self_attention: block_skip_range = config.spatial_attention_block_skip_range timestep_skip_range = config.spatial_attention_timestep_skip_range + block_type = "spatial" elif is_temporal_self_attention: block_skip_range = config.temporal_attention_block_skip_range timestep_skip_range = config.temporal_attention_timestep_skip_range + block_type = "temporal" elif is_cross_attention: block_skip_range = config.cross_attention_block_skip_range timestep_skip_range = config.cross_attention_timestep_skip_range + block_type = "cross" if block_skip_range is None or timestep_skip_range is None: logger.warning(f"Unable to apply Pyramid Attention Broadcast to the selected layer: {name}.") @@ -250,5 +253,5 @@ def skip_callback(module: nnModulePAB) -> bool: # loss, as described in the paper. So, the attention computation cannot be skipped return False - logger.debug(f"Enabling Pyramid Attention Broadcast in layer: {name}") + logger.debug(f"Enabling Pyramid Attention Broadcast ({block_type}) in layer: {name}") apply_pyramid_attention_broadcast_on_module(module, skip_callback) From 0ea904ec402380377fcea8027cd1b6ab1a53183e Mon Sep 17 00:00:00 2001 From: Aryan Date: Mon, 9 Dec 2024 14:06:10 +0100 Subject: [PATCH 18/61] add todos for future --- src/diffusers/pipelines/pyramid_attention_broadcast_utils.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/src/diffusers/pipelines/pyramid_attention_broadcast_utils.py b/src/diffusers/pipelines/pyramid_attention_broadcast_utils.py index 8349e0b8a10d..b074243bd07d 100644 --- a/src/diffusers/pipelines/pyramid_attention_broadcast_utils.py +++ b/src/diffusers/pipelines/pyramid_attention_broadcast_utils.py @@ -80,6 +80,9 @@ class PyramidAttentionBroadcastConfig: temporal_attention_block_identifiers: Tuple[str, ...] = _TEMPORAL_ATTENTION_BLOCK_IDENTIFIERS cross_attention_block_identifiers: Tuple[str, ...] = _CROSS_ATTENTION_BLOCK_IDENTIFIERS + # TODO(aryan): add PAB for MLP layers (very limited speedup from testing with original codebase + # so not added for now) + class PyramidAttentionBroadcastState: r""" From 9d452dc401974ad9bdfacbd54998aa9f13477277 Mon Sep 17 00:00:00 2001 From: Aryan Date: Mon, 9 Dec 2024 14:14:48 +0100 Subject: [PATCH 19/61] update tests --- tests/pipelines/cogvideo/test_cogvideox.py | 30 +++++-------------- .../cogvideo/test_cogvideox_image2video.py | 30 +++++-------------- .../cogvideo/test_cogvideox_video2video.py | 30 +++++-------------- 3 files changed, 21 insertions(+), 69 deletions(-) diff --git a/tests/pipelines/cogvideo/test_cogvideox.py b/tests/pipelines/cogvideo/test_cogvideox.py index 775a351fc2ac..495341b69b77 100644 --- a/tests/pipelines/cogvideo/test_cogvideox.py +++ b/tests/pipelines/cogvideo/test_cogvideox.py @@ -21,7 +21,10 @@ from transformers import AutoTokenizer, T5EncoderModel from diffusers import AutoencoderKLCogVideoX, CogVideoXPipeline, CogVideoXTransformer3DModel, DDIMScheduler -from diffusers.pipelines.pyramid_attention_broadcast_utils import PyramidAttentionBroadcastAttentionProcessorWrapper +from diffusers.pipelines.pyramid_attention_broadcast_utils import ( + PyramidAttentionBroadcastConfig, + apply_pyramid_attention_broadcast, +) from diffusers.utils.testing_utils import ( enable_full_determinism, numpy_cosine_similarity_distance, @@ -333,40 +336,21 @@ def test_pyramid_attention_broadcast(self): frames = pipe(**inputs).frames # [B, F, C, H, W] original_image_slice = frames[0, -2:, -1, -3:, -3:] - pipe.enable_pyramid_attention_broadcast(spatial_attn_skip_range=2, spatial_attn_timestep_range=(100, 800)) - assert pipe.pyramid_attention_broadcast_enabled - - num_pab_processors = sum( - [ - isinstance(processor, PyramidAttentionBroadcastAttentionProcessorWrapper) - for processor in pipe.transformer.attn_processors.values() - ] + config = PyramidAttentionBroadcastConfig( + spatial_attention_block_skip_range=2, spatial_attention_timestep_skip_range=(100, 800) ) - assert num_pab_processors == num_layers + apply_pyramid_attention_broadcast(pipe, config) inputs = self.get_dummy_inputs(device) inputs["num_inference_steps"] = 4 frames = pipe(**inputs).frames image_slice_pab_enabled = frames[0, -2:, -1, -3:, -3:] - pipe.disable_pyramid_attention_broadcast() - assert not pipe.pyramid_attention_broadcast_enabled - - inputs = self.get_dummy_inputs(device) - frames = pipe(**inputs).frames - image_slice_pab_disabled = frames[0, -2:, -1, -3:, -3:] - # We need to use higher tolerance because we are using a random model. With a converged/trained # model, the tolerance can be lower. assert np.allclose( original_image_slice, image_slice_pab_enabled, atol=0.2 ), "PAB outputs should not differ much in specified timestep range." - assert np.allclose( - image_slice_pab_enabled, image_slice_pab_disabled, atol=0.2 - ), "Outputs, with PAB enabled, shouldn't differ much when PAB is disabled in specified timestep range." - assert np.allclose( - original_image_slice, image_slice_pab_disabled, atol=0.2 - ), "Original outputs should match when PAB is disabled." @slow diff --git a/tests/pipelines/cogvideo/test_cogvideox_image2video.py b/tests/pipelines/cogvideo/test_cogvideox_image2video.py index a5f0b3211924..03422e586156 100644 --- a/tests/pipelines/cogvideo/test_cogvideox_image2video.py +++ b/tests/pipelines/cogvideo/test_cogvideox_image2video.py @@ -22,7 +22,10 @@ from transformers import AutoTokenizer, T5EncoderModel from diffusers import AutoencoderKLCogVideoX, CogVideoXImageToVideoPipeline, CogVideoXTransformer3DModel, DDIMScheduler -from diffusers.pipelines.pyramid_attention_broadcast_utils import PyramidAttentionBroadcastAttentionProcessorWrapper +from diffusers.pipelines.pyramid_attention_broadcast_utils import ( + PyramidAttentionBroadcastConfig, + apply_pyramid_attention_broadcast, +) from diffusers.utils import load_image from diffusers.utils.testing_utils import ( enable_full_determinism, @@ -356,40 +359,21 @@ def test_pyramid_attention_broadcast(self): frames = pipe(**inputs).frames # [B, F, C, H, W] original_image_slice = frames[0, -2:, -1, -3:, -3:] - pipe.enable_pyramid_attention_broadcast(spatial_attn_skip_range=2, spatial_attn_timestep_range=(100, 800)) - assert pipe.pyramid_attention_broadcast_enabled - - num_pab_processors = sum( - [ - isinstance(processor, PyramidAttentionBroadcastAttentionProcessorWrapper) - for processor in pipe.transformer.attn_processors.values() - ] + config = PyramidAttentionBroadcastConfig( + spatial_attention_block_skip_range=2, spatial_attention_timestep_skip_range=(100, 800) ) - assert num_pab_processors == num_layers + apply_pyramid_attention_broadcast(pipe, config) inputs = self.get_dummy_inputs(device) inputs["num_inference_steps"] = 4 frames = pipe(**inputs).frames image_slice_pab_enabled = frames[0, -2:, -1, -3:, -3:] - pipe.disable_pyramid_attention_broadcast() - assert not pipe.pyramid_attention_broadcast_enabled - - inputs = self.get_dummy_inputs(device) - frames = pipe(**inputs).frames - image_slice_pab_disabled = frames[0, -2:, -1, -3:, -3:] - # We need to use higher tolerance because we are using a random model. With a converged/trained # model, the tolerance can be lower. assert np.allclose( original_image_slice, image_slice_pab_enabled, atol=0.2 ), "PAB outputs should not differ much in specified timestep range." - assert np.allclose( - image_slice_pab_enabled, image_slice_pab_disabled, atol=0.2 - ), "Outputs, with PAB enabled, shouldn't differ much when PAB is disabled in specified timestep range." - assert np.allclose( - original_image_slice, image_slice_pab_disabled, atol=0.2 - ), "Original outputs should match when PAB is disabled." @slow diff --git a/tests/pipelines/cogvideo/test_cogvideox_video2video.py b/tests/pipelines/cogvideo/test_cogvideox_video2video.py index a717074e0161..bfe1bc835c4d 100644 --- a/tests/pipelines/cogvideo/test_cogvideox_video2video.py +++ b/tests/pipelines/cogvideo/test_cogvideox_video2video.py @@ -21,7 +21,10 @@ from transformers import AutoTokenizer, T5EncoderModel from diffusers import AutoencoderKLCogVideoX, CogVideoXTransformer3DModel, CogVideoXVideoToVideoPipeline, DDIMScheduler -from diffusers.pipelines.pyramid_attention_broadcast_utils import PyramidAttentionBroadcastAttentionProcessorWrapper +from diffusers.pipelines.pyramid_attention_broadcast_utils import ( + PyramidAttentionBroadcastConfig, + apply_pyramid_attention_broadcast, +) from diffusers.utils.testing_utils import enable_full_determinism, torch_device from ..pipeline_params import TEXT_TO_IMAGE_BATCH_PARAMS, TEXT_TO_IMAGE_IMAGE_PARAMS, TEXT_TO_IMAGE_PARAMS @@ -338,37 +341,18 @@ def test_pyramid_attention_broadcast(self): frames = pipe(**inputs).frames # [B, F, C, H, W] original_image_slice = frames[0, -2:, -1, -3:, -3:] - pipe.enable_pyramid_attention_broadcast(spatial_attn_skip_range=2, spatial_attn_timestep_range=(100, 800)) - assert pipe.pyramid_attention_broadcast_enabled - - num_pab_processors = sum( - [ - isinstance(processor, PyramidAttentionBroadcastAttentionProcessorWrapper) - for processor in pipe.transformer.attn_processors.values() - ] + config = PyramidAttentionBroadcastConfig( + spatial_attention_block_skip_range=2, spatial_attention_timestep_skip_range=(100, 800) ) - assert num_pab_processors == num_layers + apply_pyramid_attention_broadcast(pipe, config) inputs = self.get_dummy_inputs(device) inputs["num_inference_steps"] = 4 frames = pipe(**inputs).frames image_slice_pab_enabled = frames[0, -2:, -1, -3:, -3:] - pipe.disable_pyramid_attention_broadcast() - assert not pipe.pyramid_attention_broadcast_enabled - - inputs = self.get_dummy_inputs(device) - frames = pipe(**inputs).frames - image_slice_pab_disabled = frames[0, -2:, -1, -3:, -3:] - # We need to use higher tolerance because we are using a random model. With a converged/trained # model, the tolerance can be lower. assert np.allclose( original_image_slice, image_slice_pab_enabled, atol=0.2 ), "PAB outputs should not differ much in specified timestep range." - assert np.allclose( - image_slice_pab_enabled, image_slice_pab_disabled, atol=0.2 - ), "Outputs, with PAB enabled, shouldn't differ much when PAB is disabled in specified timestep range." - assert np.allclose( - original_image_slice, image_slice_pab_disabled, atol=0.2 - ), "Original outputs should match when PAB is disabled." From cfe392162847e5178fe2d639a10d292a96a2fcb7 Mon Sep 17 00:00:00 2001 From: Aryan Date: Mon, 9 Dec 2024 14:17:26 +0100 Subject: [PATCH 20/61] make style --- src/diffusers/models/transformers/latte_transformer_3d.py | 3 +-- src/diffusers/pipelines/pyramid_attention_broadcast_utils.py | 1 - 2 files changed, 1 insertion(+), 3 deletions(-) diff --git a/src/diffusers/models/transformers/latte_transformer_3d.py b/src/diffusers/models/transformers/latte_transformer_3d.py index bc9a5bfcfd12..72473cd71f9d 100644 --- a/src/diffusers/models/transformers/latte_transformer_3d.py +++ b/src/diffusers/models/transformers/latte_transformer_3d.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -from typing import Dict, Optional, Union +from typing import Optional import torch from torch import nn @@ -20,7 +20,6 @@ from ...configuration_utils import ConfigMixin, register_to_config from ...models.embeddings import PixArtAlphaTextProjection, get_1d_sincos_pos_embed_from_grid from ..attention import BasicTransformerBlock -from ..attention_processor import AttentionProcessor from ..embeddings import PatchEmbed from ..modeling_outputs import Transformer2DModelOutput from ..modeling_utils import ModelMixin diff --git a/src/diffusers/pipelines/pyramid_attention_broadcast_utils.py b/src/diffusers/pipelines/pyramid_attention_broadcast_utils.py index b074243bd07d..cebf1469500c 100644 --- a/src/diffusers/pipelines/pyramid_attention_broadcast_utils.py +++ b/src/diffusers/pipelines/pyramid_attention_broadcast_utils.py @@ -150,7 +150,6 @@ def apply_pyramid_attention_broadcast( >>> apply_pyramid_attention_broadcast(pipe, config) ``` """ - # We present Pyramid Attention Broadcast (PAB), a real-time, high quality and training-free approach for DiT-based video generation. Our method is founded on the observation that attention difference in the diffusion process exhibits a U-shaped pattern, indicating significant redundancy. We mitigate this by broadcasting attention outputs to subsequent steps in a pyramid style. It applies different broadcast strategies to each attention based on their variance for best efficiency. We further introduce broadcast sequence parallel for more efficient distributed inference. PAB demonstrates superior results across three models compared to baselines, achieving real-time generation for up to 720p videos. We anticipate that our simple yet effective method will serve as a robust baseline and facilitate future research and application for video generation. if config is None: config = PyramidAttentionBroadcastConfig() From b972c4b12186335fcd2960aa8817a00f3de2b763 Mon Sep 17 00:00:00 2001 From: Aryan Date: Mon, 9 Dec 2024 14:29:38 +0100 Subject: [PATCH 21/61] cleanup --- src/diffusers/models/hooks.py | 20 +------------------ .../pyramid_attention_broadcast_utils.py | 2 +- 2 files changed, 2 insertions(+), 20 deletions(-) diff --git a/src/diffusers/models/hooks.py b/src/diffusers/models/hooks.py index 3f5eb97ede85..eb14577757f7 100644 --- a/src/diffusers/models/hooks.py +++ b/src/diffusers/models/hooks.py @@ -115,18 +115,9 @@ def reset_state(self, module): class PyramidAttentionBroadcastHook(ModelHook): - def __init__( - self, - skip_callback: Callable[[torch.nn.Module], bool], - # skip_range: int, - # timestep_range: Tuple[int, int], - # timestep_callback: Callable[[], Union[torch.LongTensor, int]], - ) -> None: + def __init__(self, skip_callback: Callable[[torch.nn.Module], bool]) -> None: super().__init__() - # self.skip_range = skip_range - # self.timestep_range = timestep_range - # self.timestep_callback = timestep_callback self.skip_callback = skip_callback self.cache = None @@ -135,15 +126,6 @@ def __init__( def new_forward(self, module: torch.nn.Module, *args, **kwargs) -> Any: args, kwargs = module._diffusers_hook.pre_forward(module, *args, **kwargs) - # current_timestep = self.timestep_callback() - # is_within_timestep_range = self.timestep_range[0] < current_timestep < self.timestep_range[1] - # should_compute_attention = self._iteration % self.skip_range == 0 - - # if not is_within_timestep_range or should_compute_attention: - # output = module._old_forward(*args, **kwargs) - # else: - # output = self.attention_cache - if self.cache is not None and self.skip_callback(module): output = self.cache else: diff --git a/src/diffusers/pipelines/pyramid_attention_broadcast_utils.py b/src/diffusers/pipelines/pyramid_attention_broadcast_utils.py index cebf1469500c..44c692bc8832 100644 --- a/src/diffusers/pipelines/pyramid_attention_broadcast_utils.py +++ b/src/diffusers/pipelines/pyramid_attention_broadcast_utils.py @@ -251,7 +251,7 @@ def skip_callback(module: nnModulePAB) -> bool: pab_state.iteration += 1 return not should_compute_attention - # We are still not yet in the phase of inference where skipping attention is possible without minimal quality + # We are still not in the phase of inference where skipping attention is possible without minimal quality # loss, as described in the paper. So, the attention computation cannot be skipped return False From 2b558ffa4b42eb38fa84c3576865162022102a5d Mon Sep 17 00:00:00 2001 From: Aryan Date: Mon, 9 Dec 2024 15:03:20 +0100 Subject: [PATCH 22/61] fix --- src/diffusers/pipelines/pyramid_attention_broadcast_utils.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/diffusers/pipelines/pyramid_attention_broadcast_utils.py b/src/diffusers/pipelines/pyramid_attention_broadcast_utils.py index 44c692bc8832..581663a06ad1 100644 --- a/src/diffusers/pipelines/pyramid_attention_broadcast_utils.py +++ b/src/diffusers/pipelines/pyramid_attention_broadcast_utils.py @@ -29,7 +29,7 @@ _ATTENTION_CLASSES = (Attention,) _SPATIAL_ATTENTION_BLOCK_IDENTIFIERS = ("blocks", "transformer_blocks") -_TEMPORAL_ATTENTION_BLOCK_IDENTIFIERS = "temporal_transformer_blocks" +_TEMPORAL_ATTENTION_BLOCK_IDENTIFIERS = ("temporal_transformer_blocks",) _CROSS_ATTENTION_BLOCK_IDENTIFIERS = ("blocks", "transformer_blocks") From 0b2629db040f4ff242e0681b3279f6121f5a1e1b Mon Sep 17 00:00:00 2001 From: Aryan Date: Mon, 9 Dec 2024 15:13:06 +0100 Subject: [PATCH 23/61] improve log message; fix latte test --- .../pyramid_attention_broadcast_utils.py | 10 ++++-- tests/pipelines/latte/test_latte.py | 34 +++++++------------ 2 files changed, 20 insertions(+), 24 deletions(-) diff --git a/src/diffusers/pipelines/pyramid_attention_broadcast_utils.py b/src/diffusers/pipelines/pyramid_attention_broadcast_utils.py index 581663a06ad1..06a9f19ce027 100644 --- a/src/diffusers/pipelines/pyramid_attention_broadcast_utils.py +++ b/src/diffusers/pipelines/pyramid_attention_broadcast_utils.py @@ -220,7 +220,7 @@ def _apply_pyramid_attention_broadcast_on_attention_class( is_cross_attention = ( any(f"{identifier}." in name or identifier == name for identifier in config.cross_attention_block_identifiers) and config.cross_attention_block_skip_range is not None - and not module.is_cross_attention + and module.is_cross_attention ) block_skip_range, timestep_skip_range, block_type = None, None, None @@ -238,7 +238,13 @@ def _apply_pyramid_attention_broadcast_on_attention_class( block_type = "cross" if block_skip_range is None or timestep_skip_range is None: - logger.warning(f"Unable to apply Pyramid Attention Broadcast to the selected layer: {name}.") + logger.info( + f'Unable to apply Pyramid Attention Broadcast to the selected layer: "{name}" because it does ' + f"not match any of the required criteria for spatial, temporal or cross attention layers. Note, " + f"however, that this layer may still be valid for applying PAB. Please specify the correct " + f"block identifiers in the configuration or use the specialized `apply_pyramid_attention_broadcast_on_module` " + f"function to apply PAB to this layer." + ) return def skip_callback(module: nnModulePAB) -> bool: diff --git a/tests/pipelines/latte/test_latte.py b/tests/pipelines/latte/test_latte.py index 75f20a446dfd..fb08f468a4a3 100644 --- a/tests/pipelines/latte/test_latte.py +++ b/tests/pipelines/latte/test_latte.py @@ -22,11 +22,10 @@ import torch from transformers import AutoTokenizer, T5EncoderModel -from diffusers import ( - AutoencoderKL, - DDIMScheduler, - LattePipeline, - LatteTransformer3DModel, +from diffusers import AutoencoderKL, DDIMScheduler, LattePipeline, LatteTransformer3DModel +from diffusers.pipelines.pyramid_attention_broadcast_utils import ( + PyramidAttentionBroadcastConfig, + apply_pyramid_attention_broadcast, ) from diffusers.utils.import_utils import is_xformers_available from diffusers.utils.testing_utils import ( @@ -277,33 +276,24 @@ def test_pyramid_attention_broadcast(self): frames = pipe(**inputs).frames # [B, F, C, H, W] original_image_slice = frames[0, -2:, -1, -3:, -3:] - pipe.enable_pyramid_attention_broadcast(spatial_attn_skip_range=2, spatial_attn_timestep_range=(100, 800)) - assert pipe.pyramid_attention_broadcast_enabled + config = PyramidAttentionBroadcastConfig( + spatial_attention_block_skip_range=2, + temporal_attention_block_skip_range=3, + spatial_attention_timestep_skip_range=(100, 800), + temporal_attention_timestep_skip_range=(100, 800), + ) + apply_pyramid_attention_broadcast(pipe, config) inputs = self.get_dummy_inputs(device) inputs["num_inference_steps"] = 4 frames = pipe(**inputs).frames image_slice_pab_enabled = frames[0, -2:, -1, -3:, -3:] - pipe.disable_pyramid_attention_broadcast() - assert not pipe.pyramid_attention_broadcast_enabled - - inputs = self.get_dummy_inputs(device) - frames = pipe(**inputs).frames - image_slice_pab_disabled = frames[0, -2:, -1, -3:, -3:] - # We need to use higher tolerance because we are using a random model. With a converged/trained # model, the tolerance can be lower. assert np.allclose( - original_image_slice, image_slice_pab_enabled, atol=0.25 + original_image_slice, image_slice_pab_enabled, atol=0.2 ), "PAB outputs should not differ much in specified timestep range." - print((image_slice_pab_disabled - image_slice_pab_enabled).abs().max()) - assert np.allclose( - image_slice_pab_enabled, image_slice_pab_disabled, atol=0.25 - ), "Outputs, with PAB enabled, shouldn't differ much when PAB is disabled in specified timestep range." - assert np.allclose( - original_image_slice, image_slice_pab_disabled, atol=0.25 - ), "Original outputs should match when PAB is disabled." @slow From 9182f57c49bb2ef8271eead7346985ba9ec7edf0 Mon Sep 17 00:00:00 2001 From: Aryan Date: Mon, 9 Dec 2024 21:00:16 +0100 Subject: [PATCH 24/61] refactor --- src/diffusers/models/hooks.py | 61 +------------------ src/diffusers/pipelines/pipeline_utils.py | 4 -- .../pyramid_attention_broadcast_utils.py | 54 ++++++++++------ 3 files changed, 36 insertions(+), 83 deletions(-) diff --git a/src/diffusers/models/hooks.py b/src/diffusers/models/hooks.py index eb14577757f7..0a1096e6e5f8 100644 --- a/src/diffusers/models/hooks.py +++ b/src/diffusers/models/hooks.py @@ -13,7 +13,7 @@ # limitations under the License. import functools -from typing import Any, Callable, Dict, Tuple +from typing import Any, Dict, Tuple import torch @@ -78,9 +78,6 @@ def detach_hook(self, module: torch.nn.Module) -> torch.nn.Module: """ return module - def reset_state(self, module: torch.nn.Module) -> torch.nn.Module: - return module - class SequentialHook(ModelHook): r"""A hook that can contain several hooks and iterates through them at each event.""" @@ -108,62 +105,6 @@ def detach_hook(self, module): module = hook.detach_hook(module) return module - def reset_state(self, module): - for hook in self.hooks: - module = hook.reset_state(module) - return module - - -class PyramidAttentionBroadcastHook(ModelHook): - def __init__(self, skip_callback: Callable[[torch.nn.Module], bool]) -> None: - super().__init__() - - self.skip_callback = skip_callback - - self.cache = None - self._iteration = 0 - - def new_forward(self, module: torch.nn.Module, *args, **kwargs) -> Any: - args, kwargs = module._diffusers_hook.pre_forward(module, *args, **kwargs) - - if self.cache is not None and self.skip_callback(module): - output = self.cache - else: - output = module._old_forward(*args, **kwargs) - - return module._diffusers_hook.post_forward(module, output) - - def post_forward(self, module: torch.nn.Module, output: Any) -> Any: - self.cache = output - return output - - def reset_state(self, module: torch.nn.Module) -> torch.nn.Module: - self.cache = None - self._iteration = 0 - return module - - -class LayerSkipHook(ModelHook): - def __init__(self, skip_: Callable[[torch.nn.Module], bool]) -> None: - super().__init__() - - self.skip_callback = skip_ - - def new_forward(self, module: torch.nn.Module, *args, **kwargs) -> Any: - args, kwargs = module._diffusers_hook.pre_forward(module, *args, **kwargs) - - if self.skip_callback(module): - # We want to skip this layer, so we have to return the input of the current layer - # as output of the next layer. But at this point, we don't have information about - # the arguments required by next layer. Even if we did, order matters unless we - # always pass kwargs. But that is not the case usually with hidden_states, encoder_hidden_states, - # temb, etc. TODO(aryan): implement correctly later - output = None - else: - output = module._old_forward(*args, **kwargs) - - return module._diffusers_hook.post_forward(module, output) - def add_hook_to_module(module: torch.nn.Module, hook: ModelHook, append: bool = False): r""" diff --git a/src/diffusers/pipelines/pipeline_utils.py b/src/diffusers/pipelines/pipeline_utils.py index 6725fe49dfc1..a504184ea2f2 100644 --- a/src/diffusers/pipelines/pipeline_utils.py +++ b/src/diffusers/pipelines/pipeline_utils.py @@ -1088,10 +1088,6 @@ def maybe_free_model_hooks(self): is a no-op. Make sure to add this function to the end of the `__call__` function of your pipeline so that it functions correctly when applying enable_model_cpu_offload. """ - - if hasattr(self, "_diffusers_hook"): - self._diffusers_hook.reset_state() - if not hasattr(self, "_all_hooks") or len(self._all_hooks) == 0: # `enable_model_cpu_offload` has not be called, so silently do nothing return diff --git a/src/diffusers/pipelines/pyramid_attention_broadcast_utils.py b/src/diffusers/pipelines/pyramid_attention_broadcast_utils.py index 06a9f19ce027..895ec8e84cbd 100644 --- a/src/diffusers/pipelines/pyramid_attention_broadcast_utils.py +++ b/src/diffusers/pipelines/pyramid_attention_broadcast_utils.py @@ -13,12 +13,12 @@ # limitations under the License. from dataclasses import dataclass -from typing import Callable, Optional, Protocol, Tuple +from typing import Any, Callable, Optional, Tuple import torch.nn as nn from ..models.attention_processor import Attention -from ..models.hooks import PyramidAttentionBroadcastHook, add_hook_to_module +from ..models.hooks import ModelHook, add_hook_to_module from ..utils import logging from .pipeline_utils import DiffusionPipeline @@ -28,7 +28,7 @@ _ATTENTION_CLASSES = (Attention,) -_SPATIAL_ATTENTION_BLOCK_IDENTIFIERS = ("blocks", "transformer_blocks") +_SPATIAL_ATTENTION_BLOCK_IDENTIFIERS = ("blocks", "transformer_blocks", "single_transformer_blocks") _TEMPORAL_ATTENTION_BLOCK_IDENTIFIERS = ("temporal_transformer_blocks",) _CROSS_ATTENTION_BLOCK_IDENTIFIERS = ("blocks", "transformer_blocks") @@ -96,21 +96,15 @@ class PyramidAttentionBroadcastState: def __init__(self) -> None: self.iteration = 0 + self.cache = None + + def update_state(self, output: Any) -> None: + self.iteration += 1 + self.cache = output def reset_state(self): self.iteration = 0 - - -class nnModulePAB(Protocol): - r""" - Type hint for a torch.nn.Module that contains a `_pyramid_attention_broadcast_state` attribute. - - Attributes: - _pyramid_attention_broadcast_state (`PyramidAttentionBroadcastState`): - The state of Pyramid Attention Broadcast. - """ - - _pyramid_attention_broadcast_state: PyramidAttentionBroadcastState + self.cache = None def apply_pyramid_attention_broadcast( @@ -247,14 +241,15 @@ def _apply_pyramid_attention_broadcast_on_attention_class( ) return - def skip_callback(module: nnModulePAB) -> bool: + def skip_callback(module: nn.Module) -> bool: pab_state = module._pyramid_attention_broadcast_state - current_timestep = pipeline._current_timestep - is_within_timestep_range = timestep_skip_range[0] < current_timestep < timestep_skip_range[1] + if pab_state.cache is None: + return False + + is_within_timestep_range = timestep_skip_range[0] < pipeline._current_timestep < timestep_skip_range[1] if is_within_timestep_range: should_compute_attention = pab_state.iteration > 0 and pab_state.iteration % block_skip_range == 0 - pab_state.iteration += 1 return not should_compute_attention # We are still not in the phase of inference where skipping attention is possible without minimal quality @@ -263,3 +258,24 @@ def skip_callback(module: nnModulePAB) -> bool: logger.debug(f"Enabling Pyramid Attention Broadcast ({block_type}) in layer: {name}") apply_pyramid_attention_broadcast_on_module(module, skip_callback) + + +class PyramidAttentionBroadcastHook(ModelHook): + def __init__(self, skip_callback: Callable[[nn.Module], bool]) -> None: + super().__init__() + + self.skip_callback = skip_callback + + def new_forward(self, module: nn.Module, *args, **kwargs) -> Any: + args, kwargs = module._diffusers_hook.pre_forward(module, *args, **kwargs) + + if self.skip_callback(module): + output = module._pyramid_attention_broadcast_state.cache + else: + output = module._old_forward(*args, **kwargs) + + return module._diffusers_hook.post_forward(module, output) + + def post_forward(self, module: nn.Module, output: Any) -> Any: + module._pyramid_attention_broadcast_state.update_state(output) + return output From 62b5b8dde53056a3e35a6914709eb93e375139d0 Mon Sep 17 00:00:00 2001 From: Aryan Date: Mon, 30 Dec 2024 15:15:01 +0100 Subject: [PATCH 25/61] update --- src/diffusers/models/hooks.py | 34 ++++++++- .../pipelines/cogvideo/pipeline_cogvideox.py | 2 + .../pipeline_cogvideox_fun_control.py | 2 + .../pipeline_cogvideox_image2video.py | 2 + .../pipeline_cogvideox_video2video.py | 2 + src/diffusers/pipelines/flux/pipeline_flux.py | 5 +- .../hunyuan_video/pipeline_hunyuan_video.py | 6 ++ .../pipelines/latte/pipeline_latte.py | 2 + .../pipelines/mochi/pipeline_mochi.py | 13 +++- .../pyramid_attention_broadcast_utils.py | 75 ++++++++++--------- 10 files changed, 101 insertions(+), 42 deletions(-) diff --git a/src/diffusers/models/hooks.py b/src/diffusers/models/hooks.py index 0a1096e6e5f8..e3e976ddb849 100644 --- a/src/diffusers/models/hooks.py +++ b/src/diffusers/models/hooks.py @@ -21,10 +21,11 @@ # Reference: https://github.com/huggingface/accelerate/blob/ba7ab93f5e688466ea56908ea3b056fae2f9a023/src/accelerate/hooks.py class ModelHook: r""" - A hook that contains callbacks to be executed just before and after the forward method of a model. The difference - with PyTorch existing hooks is that they get passed along the kwargs. + A hook that contains callbacks to be executed just before and after the forward method of a model. """ + _is_stateful = False + def init_hook(self, module: torch.nn.Module) -> torch.nn.Module: r""" Hook that is executed when a model is initialized. @@ -78,6 +79,10 @@ def detach_hook(self, module: torch.nn.Module) -> torch.nn.Module: """ return module + def reset_state(self, module: torch.nn.Module): + if self._is_stateful: + raise NotImplementedError("This hook is stateful and needs to implement the `reset_state` method.") + class SequentialHook(ModelHook): r"""A hook that can contain several hooks and iterates through them at each event.""" @@ -105,8 +110,13 @@ def detach_hook(self, module): module = hook.detach_hook(module) return module + def reset_state(self, module): + for hook in self.hooks: + if hook._is_stateful: + hook.reset_state(module) + -def add_hook_to_module(module: torch.nn.Module, hook: ModelHook, append: bool = False): +def add_hook_to_module(module: torch.nn.Module, hook: ModelHook, append: bool = False) -> torch.nn.Module: r""" Adds a hook to a given module. This will rewrite the `forward` method of the module to include the hook, to remove this behavior and restore the original `forward` method, use `remove_hook_from_module`. @@ -199,3 +209,21 @@ def remove_hook_from_module(module: torch.nn.Module, recurse: bool = False) -> t remove_hook_from_module(child, recurse) return module + + +def reset_stateful_hooks(module: torch.nn.Module, recurse: bool = False): + """ + Resets the state of all stateful hooks attached to a module. + + Args: + module (`torch.nn.Module`): + The module to reset the stateful hooks from. + """ + if hasattr(module, "_diffusers_hook") and ( + module._diffusers_hook._is_stateful or isinstance(module._diffusers_hook, SequentialHook) + ): + module._diffusers_hook.reset_state(module) + + if recurse: + for child in module.children(): + reset_stateful_hooks(child, recurse) diff --git a/src/diffusers/pipelines/cogvideo/pipeline_cogvideox.py b/src/diffusers/pipelines/cogvideo/pipeline_cogvideox.py index 312155c816fa..112b4c132261 100644 --- a/src/diffusers/pipelines/cogvideo/pipeline_cogvideox.py +++ b/src/diffusers/pipelines/cogvideo/pipeline_cogvideox.py @@ -24,6 +24,7 @@ from ...loaders import CogVideoXLoraLoaderMixin from ...models import AutoencoderKLCogVideoX, CogVideoXTransformer3DModel from ...models.embeddings import get_3d_rotary_pos_embed +from ...models.hooks import reset_stateful_hooks from ...pipelines.pipeline_utils import DiffusionPipeline from ...schedulers import CogVideoXDDIMScheduler, CogVideoXDPMScheduler from ...utils import logging, replace_example_docstring @@ -769,6 +770,7 @@ def __call__( # Offload all models self.maybe_free_model_hooks() + reset_stateful_hooks(self.transformer, recurse=True) if not return_dict: return (video,) diff --git a/src/diffusers/pipelines/cogvideo/pipeline_cogvideox_fun_control.py b/src/diffusers/pipelines/cogvideo/pipeline_cogvideox_fun_control.py index 7732a7ff1433..361901dcb37e 100644 --- a/src/diffusers/pipelines/cogvideo/pipeline_cogvideox_fun_control.py +++ b/src/diffusers/pipelines/cogvideo/pipeline_cogvideox_fun_control.py @@ -25,6 +25,7 @@ from ...loaders import CogVideoXLoraLoaderMixin from ...models import AutoencoderKLCogVideoX, CogVideoXTransformer3DModel from ...models.embeddings import get_3d_rotary_pos_embed +from ...models.hooks import reset_stateful_hooks from ...pipelines.pipeline_utils import DiffusionPipeline from ...schedulers import KarrasDiffusionSchedulers from ...utils import logging, replace_example_docstring @@ -822,6 +823,7 @@ def __call__( # Offload all models self.maybe_free_model_hooks() + reset_stateful_hooks(self.transformer, recurse=True) if not return_dict: return (video,) diff --git a/src/diffusers/pipelines/cogvideo/pipeline_cogvideox_image2video.py b/src/diffusers/pipelines/cogvideo/pipeline_cogvideox_image2video.py index 49d556d68e37..81e01412086d 100644 --- a/src/diffusers/pipelines/cogvideo/pipeline_cogvideox_image2video.py +++ b/src/diffusers/pipelines/cogvideo/pipeline_cogvideox_image2video.py @@ -26,6 +26,7 @@ from ...loaders import CogVideoXLoraLoaderMixin from ...models import AutoencoderKLCogVideoX, CogVideoXTransformer3DModel from ...models.embeddings import get_3d_rotary_pos_embed +from ...models.hooks import reset_stateful_hooks from ...pipelines.pipeline_utils import DiffusionPipeline from ...schedulers import CogVideoXDDIMScheduler, CogVideoXDPMScheduler from ...utils import ( @@ -882,6 +883,7 @@ def __call__( # Offload all models self.maybe_free_model_hooks() + reset_stateful_hooks(self.transformer, recurse=True) if not return_dict: return (video,) diff --git a/src/diffusers/pipelines/cogvideo/pipeline_cogvideox_video2video.py b/src/diffusers/pipelines/cogvideo/pipeline_cogvideox_video2video.py index 74da56f597ae..ab001ba6a6b0 100644 --- a/src/diffusers/pipelines/cogvideo/pipeline_cogvideox_video2video.py +++ b/src/diffusers/pipelines/cogvideo/pipeline_cogvideox_video2video.py @@ -25,6 +25,7 @@ from ...loaders import CogVideoXLoraLoaderMixin from ...models import AutoencoderKLCogVideoX, CogVideoXTransformer3DModel from ...models.embeddings import get_3d_rotary_pos_embed +from ...models.hooks import reset_stateful_hooks from ...pipelines.pipeline_utils import DiffusionPipeline from ...schedulers import CogVideoXDDIMScheduler, CogVideoXDPMScheduler from ...utils import logging, replace_example_docstring @@ -848,6 +849,7 @@ def __call__( # Offload all models self.maybe_free_model_hooks() + reset_stateful_hooks(self.transformer, recurse=True) if not return_dict: return (video,) diff --git a/src/diffusers/pipelines/flux/pipeline_flux.py b/src/diffusers/pipelines/flux/pipeline_flux.py index b5b128a43e48..819b0cd5c7c8 100644 --- a/src/diffusers/pipelines/flux/pipeline_flux.py +++ b/src/diffusers/pipelines/flux/pipeline_flux.py @@ -28,8 +28,8 @@ from ...image_processor import PipelineImageInput, VaeImageProcessor from ...loaders import FluxIPAdapterMixin, FluxLoraLoaderMixin, FromSingleFileMixin, TextualInversionLoaderMixin -from ...models.autoencoders import AutoencoderKL -from ...models.transformers import FluxTransformer2DModel +from ...models import AutoencoderKL, FluxTransformer2DModel +from ...models.hooks import reset_stateful_hooks from ...schedulers import FlowMatchEulerDiscreteScheduler from ...utils import ( USE_PEFT_BACKEND, @@ -953,6 +953,7 @@ def __call__( # Offload all models self.maybe_free_model_hooks() + reset_stateful_hooks(self.transformer, recurse=True) if not return_dict: return (image,) diff --git a/src/diffusers/pipelines/hunyuan_video/pipeline_hunyuan_video.py b/src/diffusers/pipelines/hunyuan_video/pipeline_hunyuan_video.py index 3b0956a32da3..e2200ef39e3e 100644 --- a/src/diffusers/pipelines/hunyuan_video/pipeline_hunyuan_video.py +++ b/src/diffusers/pipelines/hunyuan_video/pipeline_hunyuan_video.py @@ -22,6 +22,7 @@ from ...callbacks import MultiPipelineCallbacks, PipelineCallback from ...loaders import HunyuanVideoLoraLoaderMixin from ...models import AutoencoderKLHunyuanVideo, HunyuanVideoTransformer3DModel +from ...models.hooks import reset_stateful_hooks from ...schedulers import FlowMatchEulerDiscreteScheduler from ...utils import logging, replace_example_docstring from ...utils.torch_utils import randn_tensor @@ -573,6 +574,7 @@ def __call__( self._guidance_scale = guidance_scale self._attention_kwargs = attention_kwargs + self._current_timestep = None self._interrupt = False device = self._execution_device @@ -640,6 +642,7 @@ def __call__( if self.interrupt: continue + self._current_timestep = t latent_model_input = latents.to(transformer_dtype) # broadcast to batch dimension in a way that's compatible with ONNX/Core ML timestep = t.expand(latents.shape[0]).to(latents.dtype) @@ -671,6 +674,8 @@ def __call__( if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0): progress_bar.update() + self._current_timestep = None + if not output_type == "latent": latents = latents.to(self.vae.dtype) / self.vae.config.scaling_factor video = self.vae.decode(latents, return_dict=False)[0] @@ -680,6 +685,7 @@ def __call__( # Offload all models self.maybe_free_model_hooks() + reset_stateful_hooks(self.transformer, recurse=True) if not return_dict: return (video,) diff --git a/src/diffusers/pipelines/latte/pipeline_latte.py b/src/diffusers/pipelines/latte/pipeline_latte.py index c030e368048d..1c992648379c 100644 --- a/src/diffusers/pipelines/latte/pipeline_latte.py +++ b/src/diffusers/pipelines/latte/pipeline_latte.py @@ -25,6 +25,7 @@ from ...callbacks import MultiPipelineCallbacks, PipelineCallback from ...models import AutoencoderKL, LatteTransformer3DModel +from ...models.hooks import reset_stateful_hooks from ...pipelines.pipeline_utils import DiffusionPipeline from ...schedulers import KarrasDiffusionSchedulers from ...utils import ( @@ -848,6 +849,7 @@ def __call__( # Offload all models self.maybe_free_model_hooks() + reset_stateful_hooks(self.transformer, recurse=True) if not return_dict: return (video,) diff --git a/src/diffusers/pipelines/mochi/pipeline_mochi.py b/src/diffusers/pipelines/mochi/pipeline_mochi.py index aac4e32e33f0..7899ab5f409c 100644 --- a/src/diffusers/pipelines/mochi/pipeline_mochi.py +++ b/src/diffusers/pipelines/mochi/pipeline_mochi.py @@ -21,8 +21,8 @@ from ...callbacks import MultiPipelineCallbacks, PipelineCallback from ...loaders import Mochi1LoraLoaderMixin -from ...models.autoencoders import AutoencoderKL -from ...models.transformers import MochiTransformer3DModel +from ...models import AutoencoderKLHunyuanVideo, MochiTransformer3DModel +from ...models.hooks import reset_stateful_hooks from ...schedulers import FlowMatchEulerDiscreteScheduler from ...utils import ( is_torch_xla_available, @@ -184,7 +184,7 @@ class MochiPipeline(DiffusionPipeline, Mochi1LoraLoaderMixin): def __init__( self, scheduler: FlowMatchEulerDiscreteScheduler, - vae: AutoencoderKL, + vae: AutoencoderKLHunyuanVideo, text_encoder: T5EncoderModel, tokenizer: T5TokenizerFast, transformer: MochiTransformer3DModel, @@ -604,6 +604,7 @@ def __call__( self._guidance_scale = guidance_scale self._attention_kwargs = attention_kwargs + self._current_timestep = None self._interrupt = False # 2. Define call parameters @@ -673,6 +674,9 @@ def __call__( if self.interrupt: continue + # Note: Mochi uses reversed timesteps. To ensure compatibility with methods like FasterCache, we need + # to make sure we're using the correct non-reversed timestep values. + self._current_timestep = 1000 - t latent_model_input = torch.cat([latents] * 2) if self.do_classifier_free_guidance else latents # broadcast to batch dimension in a way that's compatible with ONNX/Core ML timestep = t.expand(latent_model_input.shape[0]).to(latents.dtype) @@ -718,6 +722,8 @@ def __call__( if XLA_AVAILABLE: xm.mark_step() + self._current_timestep = None + if output_type == "latent": video = latents else: @@ -741,6 +747,7 @@ def __call__( # Offload all models self.maybe_free_model_hooks() + reset_stateful_hooks(self.transformer, recurse=True) if not return_dict: return (video,) diff --git a/src/diffusers/pipelines/pyramid_attention_broadcast_utils.py b/src/diffusers/pipelines/pyramid_attention_broadcast_utils.py index 895ec8e84cbd..bf08c69f130a 100644 --- a/src/diffusers/pipelines/pyramid_attention_broadcast_utils.py +++ b/src/diffusers/pipelines/pyramid_attention_broadcast_utils.py @@ -12,6 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. +import re from dataclasses import dataclass from typing import Any, Callable, Optional, Tuple @@ -88,21 +89,30 @@ class PyramidAttentionBroadcastState: r""" State for Pyramid Attention Broadcast. + Args: + skip_callback (`Callable[[nn.Module], bool]`): + A callback function that determines whether the attention computation should be skipped or not. The + callback function should return a boolean value, where `True` indicates that the attention computation + should be skipped, and `False` indicates that the attention computation should not be skipped. The callback + function will receive a torch.nn.Module containing a `_pyramid_attention_broadcast_state` attribute that + can should be used to retrieve and update the state of PAB for the given module. + Attributes: iteration (`int`): The current iteration of the Pyramid Attention Broadcast. It is necessary to ensure that `reset_state` is called before starting a new inference forward pass for PAB to work correctly. + cache (`Any`): + The cached output from the previous forward pass. This is used to re-use the attention states when the + attention computation is skipped. It is either a tensor or a tuple of tensors, depending on the module. """ - def __init__(self) -> None: + def __init__(self, skip_callback: Callable[[nn.Module], bool]) -> None: + self.skip_callback = skip_callback + self.iteration = 0 self.cache = None - def update_state(self, output: Any) -> None: - self.iteration += 1 - self.cache = output - - def reset_state(self): + def reset(self): self.iteration = 0 self.cache = None @@ -186,33 +196,26 @@ def apply_pyramid_attention_broadcast_on_module( function will receive a torch.nn.Module containing a `_pyramid_attention_broadcast_state` attribute that can should be used to retrieve and update the state of PAB for the given module. """ - module._pyramid_attention_broadcast_state = PyramidAttentionBroadcastState() - hook = PyramidAttentionBroadcastHook(skip_callback=skip_callback) + module._pyramid_attention_broadcast_state = PyramidAttentionBroadcastState(skip_callback=skip_callback) + hook = PyramidAttentionBroadcastHook() add_hook_to_module(module, hook, append=True) def _apply_pyramid_attention_broadcast_on_attention_class( pipeline: DiffusionPipeline, name: str, module: Attention, config: PyramidAttentionBroadcastConfig -): - # Similar check as PEFT to determine if a string layer name matches a module name - # TODO(aryan): make this regex based +) -> bool: is_spatial_self_attention = ( - any( - f"{identifier}." in name or identifier == name for identifier in config.spatial_attention_block_identifiers - ) + any(re.search(identifier, name) is not None for identifier in config.spatial_attention_block_identifiers) and config.spatial_attention_block_skip_range is not None and not module.is_cross_attention ) is_temporal_self_attention = ( - any( - f"{identifier}." in name or identifier == name - for identifier in config.temporal_attention_block_identifiers - ) + any(re.search(identifier, name) is not None for identifier in config.temporal_attention_block_identifiers) and config.temporal_attention_block_skip_range is not None and not module.is_cross_attention ) is_cross_attention = ( - any(f"{identifier}." in name or identifier == name for identifier in config.cross_attention_block_identifiers) + any(re.search(identifier, name) is not None for identifier in config.cross_attention_block_identifiers) and config.cross_attention_block_skip_range is not None and module.is_cross_attention ) @@ -239,7 +242,7 @@ def _apply_pyramid_attention_broadcast_on_attention_class( f"block identifiers in the configuration or use the specialized `apply_pyramid_attention_broadcast_on_module` " f"function to apply PAB to this layer." ) - return + return False def skip_callback(module: nn.Module) -> bool: pab_state = module._pyramid_attention_broadcast_state @@ -247,35 +250,39 @@ def skip_callback(module: nn.Module) -> bool: return False is_within_timestep_range = timestep_skip_range[0] < pipeline._current_timestep < timestep_skip_range[1] + if not is_within_timestep_range: + # We are still not in the phase of inference where skipping attention is possible without minimal quality + # loss, as described in the paper. So, the attention computation cannot be skipped + return False - if is_within_timestep_range: - should_compute_attention = pab_state.iteration > 0 and pab_state.iteration % block_skip_range == 0 - return not should_compute_attention - - # We are still not in the phase of inference where skipping attention is possible without minimal quality - # loss, as described in the paper. So, the attention computation cannot be skipped - return False + should_compute_attention = pab_state.iteration > 0 and pab_state.iteration % block_skip_range == 0 + return not should_compute_attention logger.debug(f"Enabling Pyramid Attention Broadcast ({block_type}) in layer: {name}") apply_pyramid_attention_broadcast_on_module(module, skip_callback) + return True class PyramidAttentionBroadcastHook(ModelHook): - def __init__(self, skip_callback: Callable[[nn.Module], bool]) -> None: - super().__init__() + r"""A hook that applies Pyramid Attention Broadcast to a given module.""" - self.skip_callback = skip_callback + _is_stateful = True + + def __init__(self) -> None: + super().__init__() def new_forward(self, module: nn.Module, *args, **kwargs) -> Any: args, kwargs = module._diffusers_hook.pre_forward(module, *args, **kwargs) + state: PyramidAttentionBroadcastState = module._pyramid_attention_broadcast_state - if self.skip_callback(module): + if state.skip_callback(module): output = module._pyramid_attention_broadcast_state.cache else: output = module._old_forward(*args, **kwargs) + state.cache = output + state.iteration += 1 return module._diffusers_hook.post_forward(module, output) - def post_forward(self, module: nn.Module, output: Any) -> Any: - module._pyramid_attention_broadcast_state.update_state(output) - return output + def reset_state(self, module: nn.Module) -> None: + module._pyramid_attention_broadcast_state.reset() From bb250d60e5d63a3c1c298fce6613bb066e1de351 Mon Sep 17 00:00:00 2001 From: Aryan Date: Mon, 30 Dec 2024 16:04:32 +0100 Subject: [PATCH 26/61] update --- src/diffusers/models/attention_processor.py | 2 ++ .../pyramid_attention_broadcast_utils.py | 15 ++++++++++++--- 2 files changed, 14 insertions(+), 3 deletions(-) diff --git a/src/diffusers/models/attention_processor.py b/src/diffusers/models/attention_processor.py index 4d7ae6bef26e..c91ddd95c861 100644 --- a/src/diffusers/models/attention_processor.py +++ b/src/diffusers/models/attention_processor.py @@ -930,6 +930,8 @@ def __init__( self.out_dim = out_dim if out_dim is not None else query_dim self.out_context_dim = out_context_dim if out_context_dim else query_dim self.context_pre_only = context_pre_only + # TODO(aryan): Maybe try to improve the checks in PAB instead + self.is_cross_attention = False self.heads = out_dim // dim_head if out_dim is not None else heads diff --git a/src/diffusers/pipelines/pyramid_attention_broadcast_utils.py b/src/diffusers/pipelines/pyramid_attention_broadcast_utils.py index bf08c69f130a..80a6b57637f3 100644 --- a/src/diffusers/pipelines/pyramid_attention_broadcast_utils.py +++ b/src/diffusers/pipelines/pyramid_attention_broadcast_utils.py @@ -18,7 +18,7 @@ import torch.nn as nn -from ..models.attention_processor import Attention +from ..models.attention_processor import Attention, MochiAttention from ..models.hooks import ModelHook, add_hook_to_module from ..utils import logging from .pipeline_utils import DiffusionPipeline @@ -27,7 +27,7 @@ logger = logging.get_logger(__name__) # pylint: disable=invalid-name -_ATTENTION_CLASSES = (Attention,) +_ATTENTION_CLASSES = (Attention, MochiAttention) _SPATIAL_ATTENTION_BLOCK_IDENTIFIERS = ("blocks", "transformer_blocks", "single_transformer_blocks") _TEMPORAL_ATTENTION_BLOCK_IDENTIFIERS = ("temporal_transformer_blocks",) @@ -175,8 +175,10 @@ def apply_pyramid_attention_broadcast( for name, module in denoiser.named_modules(): if not isinstance(module, _ATTENTION_CLASSES): continue - if isinstance(module, Attention): + if isinstance(module, (Attention)): _apply_pyramid_attention_broadcast_on_attention_class(pipeline, name, module, config) + if isinstance(module, MochiAttention): + _apply_pyramid_attention_broadcast_on_mochi_attention_class(pipeline, name, module, config) def apply_pyramid_attention_broadcast_on_module( @@ -263,6 +265,13 @@ def skip_callback(module: nn.Module) -> bool: return True +def _apply_pyramid_attention_broadcast_on_mochi_attention_class( + pipeline: DiffusionPipeline, name: str, module: MochiAttention, config: PyramidAttentionBroadcastConfig +) -> bool: + # The same logic as Attention class works here, so just use that for now + return _apply_pyramid_attention_broadcast_on_attention_class(pipeline, name, module, config) + + class PyramidAttentionBroadcastHook(ModelHook): r"""A hook that applies Pyramid Attention Broadcast to a given module.""" From cbc086f1fffc7f968c91586ab174bf623ef7a841 Mon Sep 17 00:00:00 2001 From: Aryan Date: Mon, 30 Dec 2024 20:49:01 +0100 Subject: [PATCH 27/61] update --- src/diffusers/__init__.py | 6 ++++++ .../pipelines/pyramid_attention_broadcast_utils.py | 5 +++++ 2 files changed, 11 insertions(+) diff --git a/src/diffusers/__init__.py b/src/diffusers/__init__.py index 5e9ab2a117d1..cdc934fe76ca 100644 --- a/src/diffusers/__init__.py +++ b/src/diffusers/__init__.py @@ -336,6 +336,7 @@ "PixArtAlphaPipeline", "PixArtSigmaPAGPipeline", "PixArtSigmaPipeline", + "PyramidAttentionBroadcastConfig", "ReduxImageEncoder", "SanaPAGPipeline", "SanaPipeline", @@ -422,6 +423,8 @@ "WuerstchenCombinedPipeline", "WuerstchenDecoderPipeline", "WuerstchenPriorPipeline", + "apply_pyramid_attention_broadcast", + "apply_pyramid_attention_broadcast_on_module", ] ) @@ -825,6 +828,7 @@ PixArtAlphaPipeline, PixArtSigmaPAGPipeline, PixArtSigmaPipeline, + PyramidAttentionBroadcastConfig, ReduxImageEncoder, SanaPAGPipeline, SanaPipeline, @@ -909,6 +913,8 @@ WuerstchenCombinedPipeline, WuerstchenDecoderPipeline, WuerstchenPriorPipeline, + apply_pyramid_attention_broadcast, + apply_pyramid_attention_broadcast_on_module, ) try: diff --git a/src/diffusers/pipelines/pyramid_attention_broadcast_utils.py b/src/diffusers/pipelines/pyramid_attention_broadcast_utils.py index 80a6b57637f3..e9689361a629 100644 --- a/src/diffusers/pipelines/pyramid_attention_broadcast_utils.py +++ b/src/diffusers/pipelines/pyramid_attention_broadcast_utils.py @@ -143,6 +143,7 @@ def apply_pyramid_attention_broadcast( Example: ```python + >>> import torch >>> from diffusers import CogVideoXPipeline, PyramidAttentionBroadcastConfig, apply_pyramid_attention_broadcast >>> pipe = CogVideoXPipeline.from_pretrained("THUDM/CogVideoX-5b", torch_dtype=torch.bfloat16) @@ -169,6 +170,10 @@ def apply_pyramid_attention_broadcast( ) config.spatial_attention_block_skip_range = 2 + # Note: For diffusers models, we know that it will be either a transformer or a unet, and we also follow + # the naming convention. The option to specify a denoiser is provided for flexibility when this function + # is used outside of the diffusers library, say for a ComfyUI custom model. In that case, the user can + # specify pipeline as None but provide the denoiser module. if denoiser is None: denoiser = pipeline.transformer if hasattr(pipeline, "transformer") else pipeline.unet From 7debcec2280852ec115bcef7c114dd3017e530da Mon Sep 17 00:00:00 2001 From: Aryan Date: Mon, 30 Dec 2024 20:49:08 +0100 Subject: [PATCH 28/61] revert changes to tests --- tests/pipelines/cogvideo/test_cogvideox.py | 37 +-------------- .../cogvideo/test_cogvideox_image2video.py | 37 +-------------- .../cogvideo/test_cogvideox_video2video.py | 37 +-------------- tests/pipelines/latte/test_latte.py | 45 +++---------------- 4 files changed, 13 insertions(+), 143 deletions(-) diff --git a/tests/pipelines/cogvideo/test_cogvideox.py b/tests/pipelines/cogvideo/test_cogvideox.py index 495341b69b77..884ddfb2a95a 100644 --- a/tests/pipelines/cogvideo/test_cogvideox.py +++ b/tests/pipelines/cogvideo/test_cogvideox.py @@ -21,10 +21,6 @@ from transformers import AutoTokenizer, T5EncoderModel from diffusers import AutoencoderKLCogVideoX, CogVideoXPipeline, CogVideoXTransformer3DModel, DDIMScheduler -from diffusers.pipelines.pyramid_attention_broadcast_utils import ( - PyramidAttentionBroadcastConfig, - apply_pyramid_attention_broadcast, -) from diffusers.utils.testing_utils import ( enable_full_determinism, numpy_cosine_similarity_distance, @@ -63,7 +59,7 @@ class CogVideoXPipelineFastTests(PipelineTesterMixin, unittest.TestCase): ) test_xformers_attention = False - def get_dummy_components(self, num_layers: int = 1): + def get_dummy_components(self): torch.manual_seed(0) transformer = CogVideoXTransformer3DModel( # Product of num_attention_heads * attention_head_dim must be divisible by 16 for 3D positional embeddings @@ -75,7 +71,7 @@ def get_dummy_components(self, num_layers: int = 1): out_channels=4, time_embed_dim=2, text_embed_dim=32, # Must match with tiny-random-t5 - num_layers=num_layers, + num_layers=1, sample_width=2, # latent width: 2 -> final width: 16 sample_height=2, # latent height: 2 -> final height: 16 sample_frames=9, # latent frames: (9 - 1) / 4 + 1 = 3 -> final frames: 9 @@ -323,35 +319,6 @@ def test_fused_qkv_projections(self): original_image_slice, image_slice_disabled, atol=1e-2, rtol=1e-2 ), "Original outputs should match when fused QKV projections are disabled." - def test_pyramid_attention_broadcast(self): - device = "cpu" # ensure determinism for the device-dependent torch.Generator - num_layers = 4 - components = self.get_dummy_components(num_layers=num_layers) - pipe = self.pipeline_class(**components) - pipe = pipe.to(device) - pipe.set_progress_bar_config(disable=None) - - inputs = self.get_dummy_inputs(device) - inputs["num_inference_steps"] = 4 - frames = pipe(**inputs).frames # [B, F, C, H, W] - original_image_slice = frames[0, -2:, -1, -3:, -3:] - - config = PyramidAttentionBroadcastConfig( - spatial_attention_block_skip_range=2, spatial_attention_timestep_skip_range=(100, 800) - ) - apply_pyramid_attention_broadcast(pipe, config) - - inputs = self.get_dummy_inputs(device) - inputs["num_inference_steps"] = 4 - frames = pipe(**inputs).frames - image_slice_pab_enabled = frames[0, -2:, -1, -3:, -3:] - - # We need to use higher tolerance because we are using a random model. With a converged/trained - # model, the tolerance can be lower. - assert np.allclose( - original_image_slice, image_slice_pab_enabled, atol=0.2 - ), "PAB outputs should not differ much in specified timestep range." - @slow @require_torch_gpu diff --git a/tests/pipelines/cogvideo/test_cogvideox_image2video.py b/tests/pipelines/cogvideo/test_cogvideox_image2video.py index 03422e586156..f7e1fe7fd6c7 100644 --- a/tests/pipelines/cogvideo/test_cogvideox_image2video.py +++ b/tests/pipelines/cogvideo/test_cogvideox_image2video.py @@ -22,10 +22,6 @@ from transformers import AutoTokenizer, T5EncoderModel from diffusers import AutoencoderKLCogVideoX, CogVideoXImageToVideoPipeline, CogVideoXTransformer3DModel, DDIMScheduler -from diffusers.pipelines.pyramid_attention_broadcast_utils import ( - PyramidAttentionBroadcastConfig, - apply_pyramid_attention_broadcast, -) from diffusers.utils import load_image from diffusers.utils.testing_utils import ( enable_full_determinism, @@ -65,7 +61,7 @@ class CogVideoXImageToVideoPipelineFastTests(PipelineTesterMixin, unittest.TestC ) test_xformers_attention = False - def get_dummy_components(self, num_layers: int = 1): + def get_dummy_components(self): torch.manual_seed(0) transformer = CogVideoXTransformer3DModel( # Product of num_attention_heads * attention_head_dim must be divisible by 16 for 3D positional embeddings @@ -80,7 +76,7 @@ def get_dummy_components(self, num_layers: int = 1): out_channels=4, time_embed_dim=2, text_embed_dim=32, # Must match with tiny-random-t5 - num_layers=num_layers, + num_layers=1, sample_width=2, # latent width: 2 -> final width: 16 sample_height=2, # latent height: 2 -> final height: 16 sample_frames=9, # latent frames: (9 - 1) / 4 + 1 = 3 -> final frames: 9 @@ -346,35 +342,6 @@ def test_fused_qkv_projections(self): original_image_slice, image_slice_disabled, atol=1e-2, rtol=1e-2 ), "Original outputs should match when fused QKV projections are disabled." - def test_pyramid_attention_broadcast(self): - device = "cpu" # ensure determinism for the device-dependent torch.Generator - num_layers = 4 - components = self.get_dummy_components(num_layers=num_layers) - pipe = self.pipeline_class(**components) - pipe = pipe.to(device) - pipe.set_progress_bar_config(disable=None) - - inputs = self.get_dummy_inputs(device) - inputs["num_inference_steps"] = 4 - frames = pipe(**inputs).frames # [B, F, C, H, W] - original_image_slice = frames[0, -2:, -1, -3:, -3:] - - config = PyramidAttentionBroadcastConfig( - spatial_attention_block_skip_range=2, spatial_attention_timestep_skip_range=(100, 800) - ) - apply_pyramid_attention_broadcast(pipe, config) - - inputs = self.get_dummy_inputs(device) - inputs["num_inference_steps"] = 4 - frames = pipe(**inputs).frames - image_slice_pab_enabled = frames[0, -2:, -1, -3:, -3:] - - # We need to use higher tolerance because we are using a random model. With a converged/trained - # model, the tolerance can be lower. - assert np.allclose( - original_image_slice, image_slice_pab_enabled, atol=0.2 - ), "PAB outputs should not differ much in specified timestep range." - @slow @require_torch_gpu diff --git a/tests/pipelines/cogvideo/test_cogvideox_video2video.py b/tests/pipelines/cogvideo/test_cogvideox_video2video.py index bfe1bc835c4d..4d836cb5e2a4 100644 --- a/tests/pipelines/cogvideo/test_cogvideox_video2video.py +++ b/tests/pipelines/cogvideo/test_cogvideox_video2video.py @@ -21,10 +21,6 @@ from transformers import AutoTokenizer, T5EncoderModel from diffusers import AutoencoderKLCogVideoX, CogVideoXTransformer3DModel, CogVideoXVideoToVideoPipeline, DDIMScheduler -from diffusers.pipelines.pyramid_attention_broadcast_utils import ( - PyramidAttentionBroadcastConfig, - apply_pyramid_attention_broadcast, -) from diffusers.utils.testing_utils import enable_full_determinism, torch_device from ..pipeline_params import TEXT_TO_IMAGE_BATCH_PARAMS, TEXT_TO_IMAGE_IMAGE_PARAMS, TEXT_TO_IMAGE_PARAMS @@ -57,7 +53,7 @@ class CogVideoXVideoToVideoPipelineFastTests(PipelineTesterMixin, unittest.TestC ) test_xformers_attention = False - def get_dummy_components(self, num_layers: int = 1): + def get_dummy_components(self): torch.manual_seed(0) transformer = CogVideoXTransformer3DModel( # Product of num_attention_heads * attention_head_dim must be divisible by 16 for 3D positional embeddings @@ -69,7 +65,7 @@ def get_dummy_components(self, num_layers: int = 1): out_channels=4, time_embed_dim=2, text_embed_dim=32, # Must match with tiny-random-t5 - num_layers=num_layers, + num_layers=1, sample_width=2, # latent width: 2 -> final width: 16 sample_height=2, # latent height: 2 -> final height: 16 sample_frames=9, # latent frames: (9 - 1) / 4 + 1 = 3 -> final frames: 9 @@ -327,32 +323,3 @@ def test_fused_qkv_projections(self): assert np.allclose( original_image_slice, image_slice_disabled, atol=1e-2, rtol=1e-2 ), "Original outputs should match when fused QKV projections are disabled." - - def test_pyramid_attention_broadcast(self): - device = "cpu" # ensure determinism for the device-dependent torch.Generator - num_layers = 4 - components = self.get_dummy_components(num_layers=num_layers) - pipe = self.pipeline_class(**components) - pipe = pipe.to(device) - pipe.set_progress_bar_config(disable=None) - - inputs = self.get_dummy_inputs(device) - inputs["num_inference_steps"] = 4 - frames = pipe(**inputs).frames # [B, F, C, H, W] - original_image_slice = frames[0, -2:, -1, -3:, -3:] - - config = PyramidAttentionBroadcastConfig( - spatial_attention_block_skip_range=2, spatial_attention_timestep_skip_range=(100, 800) - ) - apply_pyramid_attention_broadcast(pipe, config) - - inputs = self.get_dummy_inputs(device) - inputs["num_inference_steps"] = 4 - frames = pipe(**inputs).frames - image_slice_pab_enabled = frames[0, -2:, -1, -3:, -3:] - - # We need to use higher tolerance because we are using a random model. With a converged/trained - # model, the tolerance can be lower. - assert np.allclose( - original_image_slice, image_slice_pab_enabled, atol=0.2 - ), "PAB outputs should not differ much in specified timestep range." diff --git a/tests/pipelines/latte/test_latte.py b/tests/pipelines/latte/test_latte.py index fb08f468a4a3..9667ebff249d 100644 --- a/tests/pipelines/latte/test_latte.py +++ b/tests/pipelines/latte/test_latte.py @@ -22,10 +22,11 @@ import torch from transformers import AutoTokenizer, T5EncoderModel -from diffusers import AutoencoderKL, DDIMScheduler, LattePipeline, LatteTransformer3DModel -from diffusers.pipelines.pyramid_attention_broadcast_utils import ( - PyramidAttentionBroadcastConfig, - apply_pyramid_attention_broadcast, +from diffusers import ( + AutoencoderKL, + DDIMScheduler, + LattePipeline, + LatteTransformer3DModel, ) from diffusers.utils.import_utils import is_xformers_available from diffusers.utils.testing_utils import ( @@ -52,11 +53,11 @@ class LattePipelineFastTests(PipelineTesterMixin, unittest.TestCase): required_optional_params = PipelineTesterMixin.required_optional_params - def get_dummy_components(self, num_layers: int = 1): + def get_dummy_components(self): torch.manual_seed(0) transformer = LatteTransformer3DModel( sample_size=8, - num_layers=num_layers, + num_layers=1, patch_size=2, attention_head_dim=8, num_attention_heads=3, @@ -263,38 +264,6 @@ def test_save_load_optional_components(self): def test_xformers_attention_forwardGenerator_pass(self): super()._test_xformers_attention_forwardGenerator_pass(test_mean_pixel_difference=False) - def test_pyramid_attention_broadcast(self): - device = "cpu" # ensure determinism for the device-dependent torch.Generator - num_layers = 4 - components = self.get_dummy_components(num_layers=num_layers) - pipe = self.pipeline_class(**components) - pipe = pipe.to(device) - pipe.set_progress_bar_config(disable=None) - - inputs = self.get_dummy_inputs(device) - inputs["num_inference_steps"] = 4 - frames = pipe(**inputs).frames # [B, F, C, H, W] - original_image_slice = frames[0, -2:, -1, -3:, -3:] - - config = PyramidAttentionBroadcastConfig( - spatial_attention_block_skip_range=2, - temporal_attention_block_skip_range=3, - spatial_attention_timestep_skip_range=(100, 800), - temporal_attention_timestep_skip_range=(100, 800), - ) - apply_pyramid_attention_broadcast(pipe, config) - - inputs = self.get_dummy_inputs(device) - inputs["num_inference_steps"] = 4 - frames = pipe(**inputs).frames - image_slice_pab_enabled = frames[0, -2:, -1, -3:, -3:] - - # We need to use higher tolerance because we are using a random model. With a converged/trained - # model, the tolerance can be lower. - assert np.allclose( - original_image_slice, image_slice_pab_enabled, atol=0.2 - ), "PAB outputs should not differ much in specified timestep range." - @slow @require_torch_gpu From a5c34afccdbdea4f3b3533ee7e64be7ffb7dcc0a Mon Sep 17 00:00:00 2001 From: Aryan Date: Mon, 30 Dec 2024 21:24:18 +0100 Subject: [PATCH 29/61] update docs --- docs/source/en/_toctree.yml | 2 ++ docs/source/en/api/cache.md | 28 ++++++++++++++++++++++++++++ 2 files changed, 30 insertions(+) create mode 100644 docs/source/en/api/cache.md diff --git a/docs/source/en/_toctree.yml b/docs/source/en/_toctree.yml index 134a127d4320..fae25c72958c 100644 --- a/docs/source/en/_toctree.yml +++ b/docs/source/en/_toctree.yml @@ -590,6 +590,8 @@ title: Attention Processor - local: api/activations title: Custom activation functions + - local: api/cache + title: Caching techniques - local: api/normalization title: Custom normalization layers - local: api/utilities diff --git a/docs/source/en/api/cache.md b/docs/source/en/api/cache.md new file mode 100644 index 000000000000..74ee343f8d46 --- /dev/null +++ b/docs/source/en/api/cache.md @@ -0,0 +1,28 @@ + + +# Caching Methods + +## Pyramid Attention Broadcast + +[Pyramid Attention Broadcast](https://huggingface.co/papers/2408.12588) from Xuanlei Zhao, Xiaolong Jin, Kai Wang, Yang You. + +Pyramid Attention Broadcast (PAB) is a method that speeds up inference in diffusion models by systematically skipping attention computations between successive inference steps and reusing cached attention states. The attention states are not very different between successive inference steps. The most prominent difference is in the spatial attention blocks, not as much in the temporal attention blocks, and finally the least in the cross attention blocks. Therefore, many cross attention computation blocks can be skipped, followed by the temporal and spatial attention blocks. By combining other techniques like sequence parallelism and classifier-free guidance parallelism, PAB achieves near real-time video generation. + +Enable PAB with [`~PyramidAttentionBroadcastConfig`] on any pipeline. For some benchmarks, refer to [this](https://github.com/huggingface/diffusers/pull/9562) pull request. + +## PyramidAttentionBroadcastConfig + +[[autodoc]] PyramidAttentionBroadcastConfig + +[[autodoc]] apply_pyramid_attention_broadcast + +[[autodoc]] apply_pyramid_attention_broadcast_on_module From bbcde6b09f1aa3f1cfb5ac5006589df84d3e0ac1 Mon Sep 17 00:00:00 2001 From: Aryan Date: Mon, 30 Dec 2024 22:43:43 +0100 Subject: [PATCH 30/61] update tests --- .../dummy_torch_and_transformers_objects.py | 23 ++++ tests/pipelines/allegro/test_allegro.py | 8 +- tests/pipelines/cogvideo/test_cogvideox.py | 7 +- tests/pipelines/flux/test_pipeline_flux.py | 9 +- .../hunyuan_video/test_hunyuan_video.py | 8 +- tests/pipelines/latte/test_latte.py | 21 +++- tests/pipelines/test_pipelines_common.py | 106 ++++++++++++++++++ 7 files changed, 164 insertions(+), 18 deletions(-) diff --git a/src/diffusers/utils/dummy_torch_and_transformers_objects.py b/src/diffusers/utils/dummy_torch_and_transformers_objects.py index 9b36be9e0604..9876af2bafe9 100644 --- a/src/diffusers/utils/dummy_torch_and_transformers_objects.py +++ b/src/diffusers/utils/dummy_torch_and_transformers_objects.py @@ -1277,6 +1277,21 @@ def from_pretrained(cls, *args, **kwargs): requires_backends(cls, ["torch", "transformers"]) +class PyramidAttentionBroadcastConfig(metaclass=DummyObject): + _backends = ["torch", "transformers"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch", "transformers"]) + + @classmethod + def from_config(cls, *args, **kwargs): + requires_backends(cls, ["torch", "transformers"]) + + @classmethod + def from_pretrained(cls, *args, **kwargs): + requires_backends(cls, ["torch", "transformers"]) + + class ReduxImageEncoder(metaclass=DummyObject): _backends = ["torch", "transformers"] @@ -2535,3 +2550,11 @@ def from_config(cls, *args, **kwargs): @classmethod def from_pretrained(cls, *args, **kwargs): requires_backends(cls, ["torch", "transformers"]) + + +def apply_pyramid_attention_broadcast(*args, **kwargs): + requires_backends(apply_pyramid_attention_broadcast, ["torch", "transformers"]) + + +def apply_pyramid_attention_broadcast_on_module(*args, **kwargs): + requires_backends(apply_pyramid_attention_broadcast_on_module, ["torch", "transformers"]) diff --git a/tests/pipelines/allegro/test_allegro.py b/tests/pipelines/allegro/test_allegro.py index d09fc0488378..726ab5f40221 100644 --- a/tests/pipelines/allegro/test_allegro.py +++ b/tests/pipelines/allegro/test_allegro.py @@ -30,13 +30,13 @@ ) from ..pipeline_params import TEXT_TO_IMAGE_BATCH_PARAMS, TEXT_TO_IMAGE_IMAGE_PARAMS, TEXT_TO_IMAGE_PARAMS -from ..test_pipelines_common import PipelineTesterMixin, to_np +from ..test_pipelines_common import PipelineTesterMixin, PyramidAttentionBroadcastTesterMixin, to_np enable_full_determinism() -class AllegroPipelineFastTests(PipelineTesterMixin, unittest.TestCase): +class AllegroPipelineFastTests(PipelineTesterMixin, PyramidAttentionBroadcastTesterMixin, unittest.TestCase): pipeline_class = AllegroPipeline params = TEXT_TO_IMAGE_PARAMS - {"cross_attention_kwargs"} batch_params = TEXT_TO_IMAGE_BATCH_PARAMS @@ -54,14 +54,14 @@ class AllegroPipelineFastTests(PipelineTesterMixin, unittest.TestCase): ) test_xformers_attention = False - def get_dummy_components(self): + def get_dummy_components(self, num_layers: int = 1): torch.manual_seed(0) transformer = AllegroTransformer3DModel( num_attention_heads=2, attention_head_dim=12, in_channels=4, out_channels=4, - num_layers=1, + num_layers=num_layers, cross_attention_dim=24, sample_width=8, sample_height=8, diff --git a/tests/pipelines/cogvideo/test_cogvideox.py b/tests/pipelines/cogvideo/test_cogvideox.py index 884ddfb2a95a..c7f5af8c9a88 100644 --- a/tests/pipelines/cogvideo/test_cogvideox.py +++ b/tests/pipelines/cogvideo/test_cogvideox.py @@ -32,6 +32,7 @@ from ..pipeline_params import TEXT_TO_IMAGE_BATCH_PARAMS, TEXT_TO_IMAGE_IMAGE_PARAMS, TEXT_TO_IMAGE_PARAMS from ..test_pipelines_common import ( PipelineTesterMixin, + PyramidAttentionBroadcastTesterMixin, check_qkv_fusion_matches_attn_procs_length, check_qkv_fusion_processors_exist, to_np, @@ -41,7 +42,7 @@ enable_full_determinism() -class CogVideoXPipelineFastTests(PipelineTesterMixin, unittest.TestCase): +class CogVideoXPipelineFastTests(PipelineTesterMixin, PyramidAttentionBroadcastTesterMixin, unittest.TestCase): pipeline_class = CogVideoXPipeline params = TEXT_TO_IMAGE_PARAMS - {"cross_attention_kwargs"} batch_params = TEXT_TO_IMAGE_BATCH_PARAMS @@ -59,7 +60,7 @@ class CogVideoXPipelineFastTests(PipelineTesterMixin, unittest.TestCase): ) test_xformers_attention = False - def get_dummy_components(self): + def get_dummy_components(self, num_layers: int = 1): torch.manual_seed(0) transformer = CogVideoXTransformer3DModel( # Product of num_attention_heads * attention_head_dim must be divisible by 16 for 3D positional embeddings @@ -71,7 +72,7 @@ def get_dummy_components(self): out_channels=4, time_embed_dim=2, text_embed_dim=32, # Must match with tiny-random-t5 - num_layers=1, + num_layers=num_layers, sample_width=2, # latent width: 2 -> final width: 16 sample_height=2, # latent height: 2 -> final height: 16 sample_frames=9, # latent frames: (9 - 1) / 4 + 1 = 3 -> final frames: 9 diff --git a/tests/pipelines/flux/test_pipeline_flux.py b/tests/pipelines/flux/test_pipeline_flux.py index 7981e6c2a93b..11009e4cae97 100644 --- a/tests/pipelines/flux/test_pipeline_flux.py +++ b/tests/pipelines/flux/test_pipeline_flux.py @@ -18,12 +18,15 @@ from ..test_pipelines_common import ( FluxIPAdapterTesterMixin, PipelineTesterMixin, + PyramidAttentionBroadcastTesterMixin, check_qkv_fusion_matches_attn_procs_length, check_qkv_fusion_processors_exist, ) -class FluxPipelineFastTests(unittest.TestCase, PipelineTesterMixin, FluxIPAdapterTesterMixin): +class FluxPipelineFastTests( + unittest.TestCase, PipelineTesterMixin, FluxIPAdapterTesterMixin, PyramidAttentionBroadcastTesterMixin +): pipeline_class = FluxPipeline params = frozenset(["prompt", "height", "width", "guidance_scale", "prompt_embeds", "pooled_prompt_embeds"]) batch_params = frozenset(["prompt"]) @@ -31,12 +34,12 @@ class FluxPipelineFastTests(unittest.TestCase, PipelineTesterMixin, FluxIPAdapte # there is no xformers processor for Flux test_xformers_attention = False - def get_dummy_components(self): + def get_dummy_components(self, num_layers: int = 1): torch.manual_seed(0) transformer = FluxTransformer2DModel( patch_size=1, in_channels=4, - num_layers=1, + num_layers=num_layers, num_single_layers=1, attention_head_dim=16, num_attention_heads=2, diff --git a/tests/pipelines/hunyuan_video/test_hunyuan_video.py b/tests/pipelines/hunyuan_video/test_hunyuan_video.py index 567002268106..ee3fe34bbde0 100644 --- a/tests/pipelines/hunyuan_video/test_hunyuan_video.py +++ b/tests/pipelines/hunyuan_video/test_hunyuan_video.py @@ -30,13 +30,13 @@ torch_device, ) -from ..test_pipelines_common import PipelineTesterMixin, to_np +from ..test_pipelines_common import PipelineTesterMixin, PyramidAttentionBroadcastTesterMixin, to_np enable_full_determinism() -class HunyuanVideoPipelineFastTests(PipelineTesterMixin, unittest.TestCase): +class HunyuanVideoPipelineFastTests(PipelineTesterMixin, PyramidAttentionBroadcastTesterMixin, unittest.TestCase): pipeline_class = HunyuanVideoPipeline params = frozenset(["prompt", "height", "width", "guidance_scale", "prompt_embeds", "pooled_prompt_embeds"]) batch_params = frozenset(["prompt"]) @@ -54,14 +54,14 @@ class HunyuanVideoPipelineFastTests(PipelineTesterMixin, unittest.TestCase): # there is no xformers processor for Flux test_xformers_attention = False - def get_dummy_components(self): + def get_dummy_components(self, num_layers: int = 1): torch.manual_seed(0) transformer = HunyuanVideoTransformer3DModel( in_channels=4, out_channels=4, num_attention_heads=2, attention_head_dim=10, - num_layers=1, + num_layers=num_layers, num_single_layers=1, num_refiner_layers=1, patch_size=1, diff --git a/tests/pipelines/latte/test_latte.py b/tests/pipelines/latte/test_latte.py index 9667ebff249d..ef979cb252ca 100644 --- a/tests/pipelines/latte/test_latte.py +++ b/tests/pipelines/latte/test_latte.py @@ -27,6 +27,7 @@ DDIMScheduler, LattePipeline, LatteTransformer3DModel, + PyramidAttentionBroadcastConfig, ) from diffusers.utils.import_utils import is_xformers_available from diffusers.utils.testing_utils import ( @@ -38,13 +39,13 @@ ) from ..pipeline_params import TEXT_TO_IMAGE_BATCH_PARAMS, TEXT_TO_IMAGE_IMAGE_PARAMS, TEXT_TO_IMAGE_PARAMS -from ..test_pipelines_common import PipelineTesterMixin, to_np +from ..test_pipelines_common import PipelineTesterMixin, PyramidAttentionBroadcastTesterMixin, to_np enable_full_determinism() -class LattePipelineFastTests(PipelineTesterMixin, unittest.TestCase): +class LattePipelineFastTests(PipelineTesterMixin, PyramidAttentionBroadcastTesterMixin, unittest.TestCase): pipeline_class = LattePipeline params = TEXT_TO_IMAGE_PARAMS - {"cross_attention_kwargs"} batch_params = TEXT_TO_IMAGE_BATCH_PARAMS @@ -53,11 +54,23 @@ class LattePipelineFastTests(PipelineTesterMixin, unittest.TestCase): required_optional_params = PipelineTesterMixin.required_optional_params - def get_dummy_components(self): + pab_config = PyramidAttentionBroadcastConfig( + spatial_attention_block_skip_range=2, + temporal_attention_block_skip_range=2, + cross_attention_block_skip_range=2, + spatial_attention_timestep_skip_range=(100, 700), + temporal_attention_timestep_skip_range=(100, 800), + cross_attention_timestep_skip_range=(100, 800), + spatial_attention_block_identifiers=["transformer_blocks"], + temporal_attention_block_identifiers=["temporal_transformer_blocks"], + cross_attention_block_identifiers=["transformer_blocks"], + ) + + def get_dummy_components(self, num_layers: int = 1): torch.manual_seed(0) transformer = LatteTransformer3DModel( sample_size=8, - num_layers=1, + num_layers=num_layers, patch_size=2, attention_head_dim=8, num_attention_heads=3, diff --git a/tests/pipelines/test_pipelines_common.py b/tests/pipelines/test_pipelines_common.py index 764be1890cc5..c8ba038b5b3b 100644 --- a/tests/pipelines/test_pipelines_common.py +++ b/tests/pipelines/test_pipelines_common.py @@ -24,9 +24,11 @@ DDIMScheduler, DiffusionPipeline, KolorsPipeline, + PyramidAttentionBroadcastConfig, StableDiffusionPipeline, StableDiffusionXLPipeline, UNet2DConditionModel, + apply_pyramid_attention_broadcast, ) from diffusers.image_processor import VaeImageProcessor from diffusers.loaders import FluxIPAdapterMixin, IPAdapterMixin @@ -36,6 +38,7 @@ from diffusers.models.unets.unet_i2vgen_xl import I2VGenXLUNet from diffusers.models.unets.unet_motion_model import UNetMotionModel from diffusers.pipelines.pipeline_utils import StableDiffusionMixin +from diffusers.pipelines.pyramid_attention_broadcast_utils import PyramidAttentionBroadcastHook from diffusers.schedulers import KarrasDiffusionSchedulers from diffusers.utils import logging from diffusers.utils.import_utils import is_xformers_available @@ -2271,6 +2274,109 @@ def _test_save_load_optional_components(self, expected_max_difference=1e-4): self.assertLess(max_diff, expected_max_difference) +class PyramidAttentionBroadcastTesterMixin: + pab_config = PyramidAttentionBroadcastConfig( + spatial_attention_block_skip_range=2, + spatial_attention_timestep_skip_range=(100, 800), + spatial_attention_block_identifiers=["transformer_blocks"], + ) + + def test_pyramid_attention_broadcast_layers(self): + device = "cpu" # ensure determinism for the device-dependent torch.Generator + num_layers = 2 + components = self.get_dummy_components(num_layers=num_layers) + pipe = self.pipeline_class(**components) + pipe.set_progress_bar_config(disable=None) + + apply_pyramid_attention_broadcast(pipe, self.pab_config) + + expected_hooks = 0 + if self.pab_config.spatial_attention_block_skip_range is not None: + expected_hooks += num_layers + if self.pab_config.temporal_attention_block_skip_range is not None: + expected_hooks += num_layers + if self.pab_config.cross_attention_block_skip_range is not None: + expected_hooks += num_layers + + denoiser = pipe.transformer if hasattr(pipe, "transformer") else pipe.unet + count = 0 + for module in denoiser.modules(): + if hasattr(module, "_diffusers_hook"): + count += 1 + self.assertTrue( + isinstance(module._diffusers_hook, PyramidAttentionBroadcastHook), + "Hook should be of type PyramidAttentionBroadcastHook.", + ) + self.assertTrue( + hasattr(module, "_pyramid_attention_broadcast_state"), + "PAB state should be initialized when enabled.", + ) + self.assertTrue( + module._pyramid_attention_broadcast_state.cache is None, "Cache should be None at initialization." + ) + self.assertEqual(count, expected_hooks, "Number of hooks should match the expected number.") + + # Perform dummy inference step to ensure state is updated + def pab_state_check_callback(pipe, i, t, kwargs): + for module in denoiser.modules(): + if hasattr(module, "_diffusers_hook"): + self.assertTrue( + module._pyramid_attention_broadcast_state.cache is not None, + "Cache should have updated during inference.", + ) + self.assertTrue( + module._pyramid_attention_broadcast_state.iteration == i + 1, + "Hook iteration state should have updated during inference.", + ) + return {} + + inputs = self.get_dummy_inputs(device) + inputs["num_inference_steps"] = 2 + inputs["callback_on_step_end"] = pab_state_check_callback + pipe(**inputs)[0] + + # After inference, reset_stateful_hooks is called within the pipeline, which should have reset the states + for module in denoiser.modules(): + if hasattr(module, "_diffusers_hook"): + self.assertTrue( + module._pyramid_attention_broadcast_state.cache is None, + "Cache should be reset to None after inference.", + ) + self.assertTrue( + module._pyramid_attention_broadcast_state.iteration == 0, + "Iteration should be reset to 0 after inference.", + ) + + def test_pyramid_attention_broadcast_inference(self, expected_atol: float = 0.2): + # We need to use higher tolerance because we are using a random model. With a converged/trained + # model, the tolerance can be lower. + + device = "cpu" # ensure determinism for the device-dependent torch.Generator + num_layers = 2 + components = self.get_dummy_components(num_layers=num_layers) + pipe = self.pipeline_class(**components) + pipe = pipe.to(device) + pipe.set_progress_bar_config(disable=None) + + inputs = self.get_dummy_inputs(device) + inputs["num_inference_steps"] = 4 + output = pipe(**inputs)[0] + original_image_slice = output.flatten() + original_image_slice = np.concatenate((original_image_slice[:8], original_image_slice[-8:])) + + apply_pyramid_attention_broadcast(pipe, self.pab_config) + + inputs = self.get_dummy_inputs(device) + inputs["num_inference_steps"] = 4 + output = pipe(**inputs)[0] + image_slice_pab_enabled = output.flatten() + image_slice_pab_enabled = np.concatenate((image_slice_pab_enabled[:8], image_slice_pab_enabled[-8:])) + + assert np.allclose( + original_image_slice, image_slice_pab_enabled, atol=expected_atol + ), "PAB outputs should not differ much in specified timestep range." + + # Some models (e.g. unCLIP) are extremely likely to significantly deviate depending on which hardware is used. # This helper function is used to check that the image doesn't deviate on average more than 10 pixels from a # reference image. From b148ab4b2638bdf3dd15b15652c80cc387534f74 Mon Sep 17 00:00:00 2001 From: Aryan Date: Tue, 31 Dec 2024 20:02:06 +0530 Subject: [PATCH 31/61] Apply suggestions from code review Co-authored-by: Steven Liu <59462357+stevhliu@users.noreply.github.com> --- docs/source/en/_toctree.yml | 2 +- docs/source/en/api/cache.md | 4 ++-- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/docs/source/en/_toctree.yml b/docs/source/en/_toctree.yml index fae25c72958c..b7240f2ed41c 100644 --- a/docs/source/en/_toctree.yml +++ b/docs/source/en/_toctree.yml @@ -591,7 +591,7 @@ - local: api/activations title: Custom activation functions - local: api/cache - title: Caching techniques + title: Caching methods - local: api/normalization title: Custom normalization layers - local: api/utilities diff --git a/docs/source/en/api/cache.md b/docs/source/en/api/cache.md index 74ee343f8d46..5a747bec64cf 100644 --- a/docs/source/en/api/cache.md +++ b/docs/source/en/api/cache.md @@ -9,7 +9,7 @@ Unless required by applicable law or agreed to in writing, software distributed an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. --> -# Caching Methods +# Caching methods ## Pyramid Attention Broadcast @@ -19,7 +19,7 @@ Pyramid Attention Broadcast (PAB) is a method that speeds up inference in diffus Enable PAB with [`~PyramidAttentionBroadcastConfig`] on any pipeline. For some benchmarks, refer to [this](https://github.com/huggingface/diffusers/pull/9562) pull request. -## PyramidAttentionBroadcastConfig +### PyramidAttentionBroadcastConfig [[autodoc]] PyramidAttentionBroadcastConfig From d4ecd6c95bca6aab832255370828eb57afae3a4b Mon Sep 17 00:00:00 2001 From: Aryan Date: Tue, 31 Dec 2024 15:38:19 +0100 Subject: [PATCH 32/61] update --- .../pipelines/allegro/pipeline_allegro.py | 2 ++ .../hunyuan_video/test_hunyuan_video.py | 4 ++-- tests/pipelines/test_pipelines_common.py | 21 ++++++++++++++----- 3 files changed, 20 insertions(+), 7 deletions(-) diff --git a/src/diffusers/pipelines/allegro/pipeline_allegro.py b/src/diffusers/pipelines/allegro/pipeline_allegro.py index 8af6f9e0bc31..2feadf18f67a 100644 --- a/src/diffusers/pipelines/allegro/pipeline_allegro.py +++ b/src/diffusers/pipelines/allegro/pipeline_allegro.py @@ -26,6 +26,7 @@ from ...callbacks import MultiPipelineCallbacks, PipelineCallback from ...models import AllegroTransformer3DModel, AutoencoderKLAllegro from ...models.embeddings import get_3d_rotary_pos_embed_allegro +from ...models.hooks import reset_stateful_hooks from ...pipelines.pipeline_utils import DiffusionPipeline from ...schedulers import KarrasDiffusionSchedulers from ...utils import ( @@ -935,6 +936,7 @@ def __call__( # Offload all models self.maybe_free_model_hooks() + reset_stateful_hooks(self.transformer, recurse=True) if not return_dict: return (video,) diff --git a/tests/pipelines/hunyuan_video/test_hunyuan_video.py b/tests/pipelines/hunyuan_video/test_hunyuan_video.py index ee3fe34bbde0..991c611e5f16 100644 --- a/tests/pipelines/hunyuan_video/test_hunyuan_video.py +++ b/tests/pipelines/hunyuan_video/test_hunyuan_video.py @@ -54,7 +54,7 @@ class HunyuanVideoPipelineFastTests(PipelineTesterMixin, PyramidAttentionBroadca # there is no xformers processor for Flux test_xformers_attention = False - def get_dummy_components(self, num_layers: int = 1): + def get_dummy_components(self, num_layers: int = 1, num_single_layers: int = 1): torch.manual_seed(0) transformer = HunyuanVideoTransformer3DModel( in_channels=4, @@ -62,7 +62,7 @@ def get_dummy_components(self, num_layers: int = 1): num_attention_heads=2, attention_head_dim=10, num_layers=num_layers, - num_single_layers=1, + num_single_layers=num_single_layers, num_refiner_layers=1, patch_size=1, patch_size_t=1, diff --git a/tests/pipelines/test_pipelines_common.py b/tests/pipelines/test_pipelines_common.py index c8ba038b5b3b..0af5bf2f5a73 100644 --- a/tests/pipelines/test_pipelines_common.py +++ b/tests/pipelines/test_pipelines_common.py @@ -2283,8 +2283,19 @@ class PyramidAttentionBroadcastTesterMixin: def test_pyramid_attention_broadcast_layers(self): device = "cpu" # ensure determinism for the device-dependent torch.Generator - num_layers = 2 - components = self.get_dummy_components(num_layers=num_layers) + + num_layers = 0 + num_single_layers = 0 + dummy_component_kwargs = {} + dummy_component_parameters = inspect.signature(self.get_dummy_components).parameters + if "num_layers" in dummy_component_parameters: + num_layers = 2 + dummy_component_kwargs["num_layers"] = num_layers + if "num_single_layers" in dummy_component_parameters: + num_single_layers = 2 + dummy_component_kwargs["num_single_layers"] = num_single_layers + + components = self.get_dummy_components(**dummy_component_kwargs) pipe = self.pipeline_class(**components) pipe.set_progress_bar_config(disable=None) @@ -2292,11 +2303,11 @@ def test_pyramid_attention_broadcast_layers(self): expected_hooks = 0 if self.pab_config.spatial_attention_block_skip_range is not None: - expected_hooks += num_layers + expected_hooks += num_layers + num_single_layers if self.pab_config.temporal_attention_block_skip_range is not None: - expected_hooks += num_layers + expected_hooks += num_layers + num_single_layers if self.pab_config.cross_attention_block_skip_range is not None: - expected_hooks += num_layers + expected_hooks += num_layers + num_single_layers denoiser = pipe.transformer if hasattr(pipe, "transformer") else pipe.unet count = 0 From 6cca58fdff6ba6595a05e75c52ef08be125dd172 Mon Sep 17 00:00:00 2001 From: Aryan Date: Tue, 31 Dec 2024 21:06:29 +0100 Subject: [PATCH 33/61] fix flux test --- tests/pipelines/flux/test_pipeline_flux.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tests/pipelines/flux/test_pipeline_flux.py b/tests/pipelines/flux/test_pipeline_flux.py index 11009e4cae97..0310f8449207 100644 --- a/tests/pipelines/flux/test_pipeline_flux.py +++ b/tests/pipelines/flux/test_pipeline_flux.py @@ -34,13 +34,13 @@ class FluxPipelineFastTests( # there is no xformers processor for Flux test_xformers_attention = False - def get_dummy_components(self, num_layers: int = 1): + def get_dummy_components(self, num_layers: int = 1, num_single_layers: int = 1): torch.manual_seed(0) transformer = FluxTransformer2DModel( patch_size=1, in_channels=4, num_layers=num_layers, - num_single_layers=1, + num_single_layers=num_single_layers, attention_head_dim=16, num_attention_heads=2, joint_attention_dim=32, From d9fad00f2e7bea44a22f3c56de9231c151a1f603 Mon Sep 17 00:00:00 2001 From: Aryan Date: Thu, 2 Jan 2025 15:49:07 +0100 Subject: [PATCH 34/61] reorder --- .../pyramid_attention_broadcast_utils.py | 50 +++++++++---------- 1 file changed, 25 insertions(+), 25 deletions(-) diff --git a/src/diffusers/pipelines/pyramid_attention_broadcast_utils.py b/src/diffusers/pipelines/pyramid_attention_broadcast_utils.py index e9689361a629..3e2e4a7cb8ee 100644 --- a/src/diffusers/pipelines/pyramid_attention_broadcast_utils.py +++ b/src/diffusers/pipelines/pyramid_attention_broadcast_utils.py @@ -117,6 +117,31 @@ def reset(self): self.cache = None +class PyramidAttentionBroadcastHook(ModelHook): + r"""A hook that applies Pyramid Attention Broadcast to a given module.""" + + _is_stateful = True + + def __init__(self) -> None: + super().__init__() + + def new_forward(self, module: nn.Module, *args, **kwargs) -> Any: + args, kwargs = module._diffusers_hook.pre_forward(module, *args, **kwargs) + state: PyramidAttentionBroadcastState = module._pyramid_attention_broadcast_state + + if state.skip_callback(module): + output = module._pyramid_attention_broadcast_state.cache + else: + output = module._old_forward(*args, **kwargs) + + state.cache = output + state.iteration += 1 + return module._diffusers_hook.post_forward(module, output) + + def reset_state(self, module: nn.Module) -> None: + module._pyramid_attention_broadcast_state.reset() + + def apply_pyramid_attention_broadcast( pipeline: DiffusionPipeline, config: Optional[PyramidAttentionBroadcastConfig] = None, @@ -275,28 +300,3 @@ def _apply_pyramid_attention_broadcast_on_mochi_attention_class( ) -> bool: # The same logic as Attention class works here, so just use that for now return _apply_pyramid_attention_broadcast_on_attention_class(pipeline, name, module, config) - - -class PyramidAttentionBroadcastHook(ModelHook): - r"""A hook that applies Pyramid Attention Broadcast to a given module.""" - - _is_stateful = True - - def __init__(self) -> None: - super().__init__() - - def new_forward(self, module: nn.Module, *args, **kwargs) -> Any: - args, kwargs = module._diffusers_hook.pre_forward(module, *args, **kwargs) - state: PyramidAttentionBroadcastState = module._pyramid_attention_broadcast_state - - if state.skip_callback(module): - output = module._pyramid_attention_broadcast_state.cache - else: - output = module._old_forward(*args, **kwargs) - - state.cache = output - state.iteration += 1 - return module._diffusers_hook.post_forward(module, output) - - def reset_state(self, module: nn.Module) -> None: - module._pyramid_attention_broadcast_state.reset() From 2436b3fb0db0e4a5503ce00058f4529633bcd473 Mon Sep 17 00:00:00 2001 From: Aryan Date: Mon, 13 Jan 2025 06:45:48 +0100 Subject: [PATCH 35/61] refactor --- src/diffusers/__init__.py | 13 +- src/diffusers/hooks/__init__.py | 5 + src/diffusers/hooks/hooks.py | 162 +++++++++++++ .../pyramid_attention_broadcast.py} | 140 +++++------ src/diffusers/models/hooks.py | 229 ------------------ src/diffusers/pipelines/__init__.py | 10 - 6 files changed, 237 insertions(+), 322 deletions(-) create mode 100644 src/diffusers/hooks/__init__.py create mode 100644 src/diffusers/hooks/hooks.py rename src/diffusers/{pipelines/pyramid_attention_broadcast_utils.py => hooks/pyramid_attention_broadcast.py} (77%) delete mode 100644 src/diffusers/models/hooks.py diff --git a/src/diffusers/__init__.py b/src/diffusers/__init__.py index cdc934fe76ca..131fa1cc8821 100644 --- a/src/diffusers/__init__.py +++ b/src/diffusers/__init__.py @@ -75,6 +75,12 @@ _import_structure["utils.dummy_pt_objects"] = [name for name in dir(dummy_pt_objects) if not name.startswith("_")] else: + _import_structure["hooks"].extend( + [ + "PyramidAttentionBroadcastConfig", + "apply_pyramid_attention_broadcast", + ] + ) _import_structure["models"].extend( [ "AllegroTransformer3DModel", @@ -336,7 +342,6 @@ "PixArtAlphaPipeline", "PixArtSigmaPAGPipeline", "PixArtSigmaPipeline", - "PyramidAttentionBroadcastConfig", "ReduxImageEncoder", "SanaPAGPipeline", "SanaPipeline", @@ -423,8 +428,6 @@ "WuerstchenCombinedPipeline", "WuerstchenDecoderPipeline", "WuerstchenPriorPipeline", - "apply_pyramid_attention_broadcast", - "apply_pyramid_attention_broadcast_on_module", ] ) @@ -589,6 +592,7 @@ except OptionalDependencyNotAvailable: from .utils.dummy_pt_objects import * # noqa F403 else: + from .hooks import PyramidAttentionBroadcastConfig, apply_pyramid_attention_broadcast from .models import ( AllegroTransformer3DModel, AsymmetricAutoencoderKL, @@ -828,7 +832,6 @@ PixArtAlphaPipeline, PixArtSigmaPAGPipeline, PixArtSigmaPipeline, - PyramidAttentionBroadcastConfig, ReduxImageEncoder, SanaPAGPipeline, SanaPipeline, @@ -913,8 +916,6 @@ WuerstchenCombinedPipeline, WuerstchenDecoderPipeline, WuerstchenPriorPipeline, - apply_pyramid_attention_broadcast, - apply_pyramid_attention_broadcast_on_module, ) try: diff --git a/src/diffusers/hooks/__init__.py b/src/diffusers/hooks/__init__.py new file mode 100644 index 000000000000..148e6c8fdc97 --- /dev/null +++ b/src/diffusers/hooks/__init__.py @@ -0,0 +1,5 @@ +from ..utils import is_torch_available + + +if is_torch_available(): + from .pyramid_attention_broadcast import PyramidAttentionBroadcastConfig, apply_pyramid_attention_broadcast diff --git a/src/diffusers/hooks/hooks.py b/src/diffusers/hooks/hooks.py new file mode 100644 index 000000000000..d9f257aac6c6 --- /dev/null +++ b/src/diffusers/hooks/hooks.py @@ -0,0 +1,162 @@ +# Copyright 2024 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import functools +from typing import Any, Dict, Tuple + +import torch + +from ..utils.logging import get_logger + + +logger = get_logger(__name__) # pylint: disable=invalid-name + + +class ModelHook: + r""" + A hook that contains callbacks to be executed just before and after the forward method of a model. + """ + + _is_stateful = False + + def initialize_hook(self, module: torch.nn.Module) -> torch.nn.Module: + r""" + Hook that is executed when a model is initialized. + + Args: + module (`torch.nn.Module`): + The module attached to this hook. + """ + return module + + def deinitalize_hook(self, module: torch.nn.Module) -> torch.nn.Module: + r""" + Hook that is executed when a model is deinitalized. + + Args: + module (`torch.nn.Module`): + The module attached to this hook. + """ + module.forward = module._old_forward + del module._old_forward + return module + + def pre_forward(self, module: torch.nn.Module, *args, **kwargs) -> Tuple[Tuple[Any], Dict[str, Any]]: + r""" + Hook that is executed just before the forward method of the model. + + Args: + module (`torch.nn.Module`): + The module whose forward pass will be executed just after this event. + args (`Tuple[Any]`): + The positional arguments passed to the module. + kwargs (`Dict[Str, Any]`): + The keyword arguments passed to the module. + Returns: + `Tuple[Tuple[Any], Dict[Str, Any]]`: + A tuple with the treated `args` and `kwargs`. + """ + return args, kwargs + + def post_forward(self, module: torch.nn.Module, output: Any) -> Any: + r""" + Hook that is executed just after the forward method of the model. + + Args: + module (`torch.nn.Module`): + The module whose forward pass been executed just before this event. + output (`Any`): + The output of the module. + Returns: + `Any`: The processed `output`. + """ + return output + + def detach_hook(self, module: torch.nn.Module) -> torch.nn.Module: + r""" + Hook that is executed when the hook is detached from a module. + + Args: + module (`torch.nn.Module`): + The module detached from this hook. + """ + return module + + def reset_state(self, module: torch.nn.Module): + if self._is_stateful: + raise NotImplementedError("This hook is stateful and needs to implement the `reset_state` method.") + return module + + +class HookRegistry: + def __init__(self, module_ref: torch.nn.Module) -> None: + super().__init__() + + self.hooks: Dict[str, ModelHook] = {} + + self._module_ref = module_ref + self._hook_order = [] + + def register_hook(self, hook: ModelHook, name: str) -> None: + if name in self.hooks.keys(): + logger.warning(f"Hook with name {name} already exists, replacing it.") + + if hasattr(self._module_ref, "_old_forward"): + old_forward = self._module_ref._old_forward + else: + old_forward = self._module_ref.forward + self._module_ref._old_forward = self._module_ref.forward + + self._module_ref = hook.initialize_hook(self._module_ref) + + if hasattr(hook, "new_forward"): + new_forward = hook.new_forward + else: + + def new_forward(module, *args, **kwargs): + args, kwargs = hook.pre_forward(module, *args, **kwargs) + output = old_forward(*args, **kwargs) + return hook.post_forward(module, output) + + new_forward = functools.update_wrapper(new_forward, old_forward) + self._module_ref.forward = new_forward.__get__(self._module_ref) + + self.hooks[name] = hook + self._hook_order.append(name) + + def get_hook(self, name: str) -> ModelHook: + if name not in self.hooks.keys(): + raise ValueError(f"Hook with name {name} not found.") + return self.hooks[name] + + def remove_hook(self, name: str) -> None: + if name not in self.hooks.keys(): + raise ValueError(f"Hook with name {name} not found.") + self.hooks[name].deinitalize_hook(self._module_ref) + del self.hooks[name] + self._hook_order.remove(name) + + @classmethod + def check_if_exists_or_initialize(cls, module: torch.nn.Module) -> "HookRegistry": + if not hasattr(module, "_diffusers_hook"): + module._diffusers_hook = cls(module) + return module._diffusers_hook + + def __repr__(self) -> str: + hook_repr = "" + for i, hook_name in enumerate(self._hook_order): + hook_repr += f" ({i}) {hook_name} - ({self.hooks[hook_name].__class__.__name__})" + if i < len(self._hook_order) - 1: + hook_repr += "\n" + return f"HookRegistry(\n{hook_repr}\n)" diff --git a/src/diffusers/pipelines/pyramid_attention_broadcast_utils.py b/src/diffusers/hooks/pyramid_attention_broadcast.py similarity index 77% rename from src/diffusers/pipelines/pyramid_attention_broadcast_utils.py rename to src/diffusers/hooks/pyramid_attention_broadcast.py index 3e2e4a7cb8ee..3f0080165927 100644 --- a/src/diffusers/pipelines/pyramid_attention_broadcast_utils.py +++ b/src/diffusers/hooks/pyramid_attention_broadcast.py @@ -14,14 +14,13 @@ import re from dataclasses import dataclass -from typing import Any, Callable, Optional, Tuple +from typing import Any, Callable, Optional, Tuple, Union -import torch.nn as nn +import torch from ..models.attention_processor import Attention, MochiAttention -from ..models.hooks import ModelHook, add_hook_to_module from ..utils import logging -from .pipeline_utils import DiffusionPipeline +from .hooks import HookRegistry, ModelHook logger = logging.get_logger(__name__) # pylint: disable=invalid-name @@ -81,6 +80,8 @@ class PyramidAttentionBroadcastConfig: temporal_attention_block_identifiers: Tuple[str, ...] = _TEMPORAL_ATTENTION_BLOCK_IDENTIFIERS cross_attention_block_identifiers: Tuple[str, ...] = _CROSS_ATTENTION_BLOCK_IDENTIFIERS + current_timestep_callback: Callable[[], int] = None + # TODO(aryan): add PAB for MLP layers (very limited speedup from testing with original codebase # so not added for now) @@ -89,14 +90,6 @@ class PyramidAttentionBroadcastState: r""" State for Pyramid Attention Broadcast. - Args: - skip_callback (`Callable[[nn.Module], bool]`): - A callback function that determines whether the attention computation should be skipped or not. The - callback function should return a boolean value, where `True` indicates that the attention computation - should be skipped, and `False` indicates that the attention computation should not be skipped. The callback - function will receive a torch.nn.Module containing a `_pyramid_attention_broadcast_state` attribute that - can should be used to retrieve and update the state of PAB for the given module. - Attributes: iteration (`int`): The current iteration of the Pyramid Attention Broadcast. It is necessary to ensure that `reset_state` is @@ -106,9 +99,7 @@ class PyramidAttentionBroadcastState: attention computation is skipped. It is either a tensor or a tuple of tensors, depending on the module. """ - def __init__(self, skip_callback: Callable[[nn.Module], bool]) -> None: - self.skip_callback = skip_callback - + def __init__(self) -> None: self.iteration = 0 self.cache = None @@ -122,30 +113,33 @@ class PyramidAttentionBroadcastHook(ModelHook): _is_stateful = True - def __init__(self) -> None: + def __init__(self, skip_callback: Callable[[torch.nn.Module], bool]) -> None: super().__init__() - def new_forward(self, module: nn.Module, *args, **kwargs) -> Any: + self.skip_callback = skip_callback + + def initialize_hook(self, module): + self.state = PyramidAttentionBroadcastState() + + def new_forward(self, module: torch.nn.Module, *args, **kwargs) -> Any: args, kwargs = module._diffusers_hook.pre_forward(module, *args, **kwargs) - state: PyramidAttentionBroadcastState = module._pyramid_attention_broadcast_state - if state.skip_callback(module): + if self.skip_callback(module): output = module._pyramid_attention_broadcast_state.cache else: output = module._old_forward(*args, **kwargs) - state.cache = output - state.iteration += 1 + self.state.cache = output + self.state.iteration += 1 return module._diffusers_hook.post_forward(module, output) - def reset_state(self, module: nn.Module) -> None: - module._pyramid_attention_broadcast_state.reset() + def reset_state(self, module: torch.nn.Module) -> None: + module.state.reset() def apply_pyramid_attention_broadcast( - pipeline: DiffusionPipeline, - config: Optional[PyramidAttentionBroadcastConfig] = None, - denoiser: Optional[nn.Module] = None, + module: torch.nn.Module, + config: PyramidAttentionBroadcastConfig, ): r""" Apply [Pyramid Attention Broadcast](https://huggingface.co/papers/2408.12588) to a given pipeline. @@ -157,13 +151,10 @@ def apply_pyramid_attention_broadcast( than in the temporal and spatial layers. Applying PAB will, therefore, speedup the inference process. Args: - pipeline (`DiffusionPipeline`): - The diffusion pipeline to apply Pyramid Attention Broadcast to. + module (`torch.nn.Module`): + The module to apply Pyramid Attention Broadcast to. config (`Optional[PyramidAttentionBroadcastConfig]`, `optional`, defaults to `None`): The configuration to use for Pyramid Attention Broadcast. - denoiser (`Optional[nn.Module]`, `optional`, defaults to `None`): - The denoiser module to apply Pyramid Attention Broadcast to. If `None`, the pipeline's transformer or unet - module will be used. Example: @@ -180,8 +171,10 @@ def apply_pyramid_attention_broadcast( >>> apply_pyramid_attention_broadcast(pipe, config) ``` """ - if config is None: - config = PyramidAttentionBroadcastConfig() + if config.current_timestep_callback is None: + raise ValueError( + "The `current_timestep_callback` function must be provided in the configuration to apply Pyramid Attention Broadcast." + ) if ( config.spatial_attention_block_skip_range is None @@ -195,61 +188,32 @@ def apply_pyramid_attention_broadcast( ) config.spatial_attention_block_skip_range = 2 - # Note: For diffusers models, we know that it will be either a transformer or a unet, and we also follow - # the naming convention. The option to specify a denoiser is provided for flexibility when this function - # is used outside of the diffusers library, say for a ComfyUI custom model. In that case, the user can - # specify pipeline as None but provide the denoiser module. - if denoiser is None: - denoiser = pipeline.transformer if hasattr(pipeline, "transformer") else pipeline.unet - - for name, module in denoiser.named_modules(): - if not isinstance(module, _ATTENTION_CLASSES): + for name, submodule in module.named_modules(): + if not isinstance(submodule, _ATTENTION_CLASSES): continue - if isinstance(module, (Attention)): - _apply_pyramid_attention_broadcast_on_attention_class(pipeline, name, module, config) - if isinstance(module, MochiAttention): - _apply_pyramid_attention_broadcast_on_mochi_attention_class(pipeline, name, module, config) - - -def apply_pyramid_attention_broadcast_on_module( - module: Attention, - skip_callback: Callable[[nn.Module], bool], -): - r""" - Apply [Pyramid Attention Broadcast](https://huggingface.co/papers/2408.12588) to a given torch.nn.Module. - - Args: - module (`torch.nn.Module`): - The module to apply Pyramid Attention Broadcast to. - skip_callback (`Callable[[nn.Module], bool]`): - A callback function that determines whether the attention computation should be skipped or not. The - callback function should return a boolean value, where `True` indicates that the attention computation - should be skipped, and `False` indicates that the attention computation should not be skipped. The callback - function will receive a torch.nn.Module containing a `_pyramid_attention_broadcast_state` attribute that - can should be used to retrieve and update the state of PAB for the given module. - """ - module._pyramid_attention_broadcast_state = PyramidAttentionBroadcastState(skip_callback=skip_callback) - hook = PyramidAttentionBroadcastHook() - add_hook_to_module(module, hook, append=True) + if isinstance(submodule, Attention): + _apply_pyramid_attention_broadcast_on_attention_class(name, module, config) + if isinstance(submodule, MochiAttention): + _apply_pyramid_attention_broadcast_on_mochi_attention_class(name, module, config) def _apply_pyramid_attention_broadcast_on_attention_class( - pipeline: DiffusionPipeline, name: str, module: Attention, config: PyramidAttentionBroadcastConfig + name: str, module: Attention, config: PyramidAttentionBroadcastConfig ) -> bool: is_spatial_self_attention = ( any(re.search(identifier, name) is not None for identifier in config.spatial_attention_block_identifiers) and config.spatial_attention_block_skip_range is not None - and not module.is_cross_attention + and not getattr(module, "is_cross_attention", False) ) is_temporal_self_attention = ( any(re.search(identifier, name) is not None for identifier in config.temporal_attention_block_identifiers) and config.temporal_attention_block_skip_range is not None - and not module.is_cross_attention + and not getattr(module, "is_cross_attention", False) ) is_cross_attention = ( any(re.search(identifier, name) is not None for identifier in config.cross_attention_block_identifiers) and config.cross_attention_block_skip_range is not None - and module.is_cross_attention + and getattr(module, "is_cross_attention", False) ) block_skip_range, timestep_skip_range, block_type = None, None, None @@ -276,12 +240,12 @@ def _apply_pyramid_attention_broadcast_on_attention_class( ) return False - def skip_callback(module: nn.Module) -> bool: + def skip_callback(module: torch.nn.Module) -> bool: pab_state = module._pyramid_attention_broadcast_state if pab_state.cache is None: return False - is_within_timestep_range = timestep_skip_range[0] < pipeline._current_timestep < timestep_skip_range[1] + is_within_timestep_range = timestep_skip_range[0] < config.current_timestep_callback() < timestep_skip_range[1] if not is_within_timestep_range: # We are still not in the phase of inference where skipping attention is possible without minimal quality # loss, as described in the paper. So, the attention computation cannot be skipped @@ -291,12 +255,34 @@ def skip_callback(module: nn.Module) -> bool: return not should_compute_attention logger.debug(f"Enabling Pyramid Attention Broadcast ({block_type}) in layer: {name}") - apply_pyramid_attention_broadcast_on_module(module, skip_callback) + _apply_pyramid_attention_broadcast(module, skip_callback) return True def _apply_pyramid_attention_broadcast_on_mochi_attention_class( - pipeline: DiffusionPipeline, name: str, module: MochiAttention, config: PyramidAttentionBroadcastConfig + name: str, module: MochiAttention, config: PyramidAttentionBroadcastConfig ) -> bool: # The same logic as Attention class works here, so just use that for now - return _apply_pyramid_attention_broadcast_on_attention_class(pipeline, name, module, config) + return _apply_pyramid_attention_broadcast_on_attention_class(name, module, config) + + +def _apply_pyramid_attention_broadcast( + module: Union[Attention, MochiAttention], + skip_callback: Callable[[torch.nn.Module], bool], +): + r""" + Apply [Pyramid Attention Broadcast](https://huggingface.co/papers/2408.12588) to a given torch.nn.Module. + + Args: + module (`torch.nn.Module`): + The module to apply Pyramid Attention Broadcast to. + skip_callback (`Callable[[nn.Module], bool]`): + A callback function that determines whether the attention computation should be skipped or not. The + callback function should return a boolean value, where `True` indicates that the attention computation + should be skipped, and `False` indicates that the attention computation should not be skipped. The callback + function will receive a torch.nn.Module containing a `_pyramid_attention_broadcast_state` attribute that + can should be used to retrieve and update the state of PAB for the given module. + """ + registry = HookRegistry.check_if_exists_or_initialize(module) + hook = PyramidAttentionBroadcastHook(skip_callback) + registry.register_hook(hook, "pyramid_attention_broadcast") diff --git a/src/diffusers/models/hooks.py b/src/diffusers/models/hooks.py deleted file mode 100644 index e3e976ddb849..000000000000 --- a/src/diffusers/models/hooks.py +++ /dev/null @@ -1,229 +0,0 @@ -# Copyright 2024 The HuggingFace Team. All rights reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import functools -from typing import Any, Dict, Tuple - -import torch - - -# Reference: https://github.com/huggingface/accelerate/blob/ba7ab93f5e688466ea56908ea3b056fae2f9a023/src/accelerate/hooks.py -class ModelHook: - r""" - A hook that contains callbacks to be executed just before and after the forward method of a model. - """ - - _is_stateful = False - - def init_hook(self, module: torch.nn.Module) -> torch.nn.Module: - r""" - Hook that is executed when a model is initialized. - - Args: - module (`torch.nn.Module`): - The module attached to this hook. - """ - return module - - def pre_forward(self, module: torch.nn.Module, *args, **kwargs) -> Tuple[Tuple[Any], Dict[str, Any]]: - r""" - Hook that is executed just before the forward method of the model. - - Args: - module (`torch.nn.Module`): - The module whose forward pass will be executed just after this event. - args (`Tuple[Any]`): - The positional arguments passed to the module. - kwargs (`Dict[Str, Any]`): - The keyword arguments passed to the module. - - Returns: - `Tuple[Tuple[Any], Dict[Str, Any]]`: - A tuple with the treated `args` and `kwargs`. - """ - return args, kwargs - - def post_forward(self, module: torch.nn.Module, output: Any) -> Any: - r""" - Hook that is executed just after the forward method of the model. - - Args: - module (`torch.nn.Module`): - The module whose forward pass been executed just before this event. - output (`Any`): - The output of the module. - - Returns: - `Any`: The processed `output`. - """ - return output - - def detach_hook(self, module: torch.nn.Module) -> torch.nn.Module: - r""" - Hook that is executed when the hook is detached from a module. - - Args: - module (`torch.nn.Module`): - The module detached from this hook. - """ - return module - - def reset_state(self, module: torch.nn.Module): - if self._is_stateful: - raise NotImplementedError("This hook is stateful and needs to implement the `reset_state` method.") - - -class SequentialHook(ModelHook): - r"""A hook that can contain several hooks and iterates through them at each event.""" - - def __init__(self, *hooks): - self.hooks = hooks - - def init_hook(self, module): - for hook in self.hooks: - module = hook.init_hook(module) - return module - - def pre_forward(self, module, *args, **kwargs): - for hook in self.hooks: - args, kwargs = hook.pre_forward(module, *args, **kwargs) - return args, kwargs - - def post_forward(self, module, output): - for hook in self.hooks: - output = hook.post_forward(module, output) - return output - - def detach_hook(self, module): - for hook in self.hooks: - module = hook.detach_hook(module) - return module - - def reset_state(self, module): - for hook in self.hooks: - if hook._is_stateful: - hook.reset_state(module) - - -def add_hook_to_module(module: torch.nn.Module, hook: ModelHook, append: bool = False) -> torch.nn.Module: - r""" - Adds a hook to a given module. This will rewrite the `forward` method of the module to include the hook, to remove - this behavior and restore the original `forward` method, use `remove_hook_from_module`. - - - - If the module already contains a hook, this will replace it with the new hook passed by default. To chain two hooks - together, pass `append=True`, so it chains the current and new hook into an instance of the `SequentialHook` class. - - - - Args: - module (`torch.nn.Module`): - The module to attach a hook to. - hook (`ModelHook`): - The hook to attach. - append (`bool`, *optional*, defaults to `False`): - Whether the hook should be chained with an existing one (if module already contains a hook) or not. - - Returns: - `torch.nn.Module`: - The same module, with the hook attached (the module is modified in place, so the result can be discarded). - """ - original_hook = hook - - if append and getattr(module, "_diffusers_hook", None) is not None: - old_hook = module._diffusers_hook - remove_hook_from_module(module) - hook = SequentialHook(old_hook, hook) - - if hasattr(module, "_diffusers_hook") and hasattr(module, "_old_forward"): - # If we already put some hook on this module, we replace it with the new one. - old_forward = module._old_forward - else: - old_forward = module.forward - module._old_forward = old_forward - - module = hook.init_hook(module) - module._diffusers_hook = hook - - if hasattr(original_hook, "new_forward"): - new_forward = original_hook.new_forward - else: - - def new_forward(module, *args, **kwargs): - args, kwargs = module._diffusers_hook.pre_forward(module, *args, **kwargs) - output = module._old_forward(*args, **kwargs) - return module._diffusers_hook.post_forward(module, output) - - # Overriding a GraphModuleImpl forward freezes the forward call and later modifications on the graph will fail. - # Reference: https://pytorch.slack.com/archives/C3PDTEV8E/p1705929610405409 - if "GraphModuleImpl" in str(type(module)): - module.__class__.forward = functools.update_wrapper(functools.partial(new_forward, module), old_forward) - else: - module.forward = functools.update_wrapper(functools.partial(new_forward, module), old_forward) - - return module - - -def remove_hook_from_module(module: torch.nn.Module, recurse: bool = False) -> torch.nn.Module: - """ - Removes any hook attached to a module via `add_hook_to_module`. - - Args: - module (`torch.nn.Module`): - The module to attach a hook to. - recurse (`bool`, defaults to `False`): - Whether to remove the hooks recursively - - Returns: - `torch.nn.Module`: - The same module, with the hook detached (the module is modified in place, so the result can be discarded). - """ - - if hasattr(module, "_diffusers_hook"): - module._diffusers_hook.detach_hook(module) - delattr(module, "_diffusers_hook") - - if hasattr(module, "_old_forward"): - # Overriding a GraphModuleImpl forward freezes the forward call and later modifications on the graph will fail. - # Reference: https://pytorch.slack.com/archives/C3PDTEV8E/p1705929610405409 - if "GraphModuleImpl" in str(type(module)): - module.__class__.forward = module._old_forward - else: - module.forward = module._old_forward - delattr(module, "_old_forward") - - if recurse: - for child in module.children(): - remove_hook_from_module(child, recurse) - - return module - - -def reset_stateful_hooks(module: torch.nn.Module, recurse: bool = False): - """ - Resets the state of all stateful hooks attached to a module. - - Args: - module (`torch.nn.Module`): - The module to reset the stateful hooks from. - """ - if hasattr(module, "_diffusers_hook") and ( - module._diffusers_hook._is_stateful or isinstance(module._diffusers_hook, SequentialHook) - ): - module._diffusers_hook.reset_state(module) - - if recurse: - for child in module.children(): - reset_stateful_hooks(child, recurse) diff --git a/src/diffusers/pipelines/__init__.py b/src/diffusers/pipelines/__init__.py index 8c04136b3370..ce291e5ceb45 100644 --- a/src/diffusers/pipelines/__init__.py +++ b/src/diffusers/pipelines/__init__.py @@ -58,11 +58,6 @@ "StableDiffusionMixin", "ImagePipelineOutput", ] - _import_structure["pyramid_attention_broadcast_utils"] = [ - "PyramidAttentionBroadcastConfig", - "apply_pyramid_attention_broadcast", - "apply_pyramid_attention_broadcast_on_module", - ] _import_structure["deprecated"].extend( [ "PNDMPipeline", @@ -461,11 +456,6 @@ ImagePipelineOutput, StableDiffusionMixin, ) - from .pyramid_attention_broadcast_utils import ( - PyramidAttentionBroadcastConfig, - apply_pyramid_attention_broadcast, - apply_pyramid_attention_broadcast_on_module, - ) try: if not (is_torch_available() and is_librosa_available()): From 95c814826de6a5ffb97d203ee9d5fc6c1ddb21bb Mon Sep 17 00:00:00 2001 From: Aryan Date: Mon, 13 Jan 2025 06:46:05 +0100 Subject: [PATCH 36/61] make fix-copies --- src/diffusers/utils/dummy_pt_objects.py | 19 +++++++++++++++ .../dummy_torch_and_transformers_objects.py | 23 ------------------- 2 files changed, 19 insertions(+), 23 deletions(-) diff --git a/src/diffusers/utils/dummy_pt_objects.py b/src/diffusers/utils/dummy_pt_objects.py index 4b6ac10385cf..1be509c19014 100644 --- a/src/diffusers/utils/dummy_pt_objects.py +++ b/src/diffusers/utils/dummy_pt_objects.py @@ -2,6 +2,25 @@ from ..utils import DummyObject, requires_backends +class PyramidAttentionBroadcastConfig(metaclass=DummyObject): + _backends = ["torch"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch"]) + + @classmethod + def from_config(cls, *args, **kwargs): + requires_backends(cls, ["torch"]) + + @classmethod + def from_pretrained(cls, *args, **kwargs): + requires_backends(cls, ["torch"]) + + +def apply_pyramid_attention_broadcast(*args, **kwargs): + requires_backends(apply_pyramid_attention_broadcast, ["torch"]) + + class AllegroTransformer3DModel(metaclass=DummyObject): _backends = ["torch"] diff --git a/src/diffusers/utils/dummy_torch_and_transformers_objects.py b/src/diffusers/utils/dummy_torch_and_transformers_objects.py index 9876af2bafe9..9b36be9e0604 100644 --- a/src/diffusers/utils/dummy_torch_and_transformers_objects.py +++ b/src/diffusers/utils/dummy_torch_and_transformers_objects.py @@ -1277,21 +1277,6 @@ def from_pretrained(cls, *args, **kwargs): requires_backends(cls, ["torch", "transformers"]) -class PyramidAttentionBroadcastConfig(metaclass=DummyObject): - _backends = ["torch", "transformers"] - - def __init__(self, *args, **kwargs): - requires_backends(self, ["torch", "transformers"]) - - @classmethod - def from_config(cls, *args, **kwargs): - requires_backends(cls, ["torch", "transformers"]) - - @classmethod - def from_pretrained(cls, *args, **kwargs): - requires_backends(cls, ["torch", "transformers"]) - - class ReduxImageEncoder(metaclass=DummyObject): _backends = ["torch", "transformers"] @@ -2550,11 +2535,3 @@ def from_config(cls, *args, **kwargs): @classmethod def from_pretrained(cls, *args, **kwargs): requires_backends(cls, ["torch", "transformers"]) - - -def apply_pyramid_attention_broadcast(*args, **kwargs): - requires_backends(apply_pyramid_attention_broadcast, ["torch", "transformers"]) - - -def apply_pyramid_attention_broadcast_on_module(*args, **kwargs): - requires_backends(apply_pyramid_attention_broadcast_on_module, ["torch", "transformers"]) From 76afc6a9ab86f780125ce0101689a09ca9f34885 Mon Sep 17 00:00:00 2001 From: Aryan Date: Mon, 13 Jan 2025 06:51:02 +0100 Subject: [PATCH 37/61] update docs --- docs/source/en/api/cache.md | 2 -- 1 file changed, 2 deletions(-) diff --git a/docs/source/en/api/cache.md b/docs/source/en/api/cache.md index 5a747bec64cf..fd35ccad3cf3 100644 --- a/docs/source/en/api/cache.md +++ b/docs/source/en/api/cache.md @@ -24,5 +24,3 @@ Enable PAB with [`~PyramidAttentionBroadcastConfig`] on any pipeline. For some b [[autodoc]] PyramidAttentionBroadcastConfig [[autodoc]] apply_pyramid_attention_broadcast - -[[autodoc]] apply_pyramid_attention_broadcast_on_module From fb661678b128615354d5e3df07c8381252fc648a Mon Sep 17 00:00:00 2001 From: Aryan Date: Mon, 13 Jan 2025 07:00:00 +0100 Subject: [PATCH 38/61] fixes --- src/diffusers/__init__.py | 4 +++- src/diffusers/hooks/__init__.py | 1 + src/diffusers/hooks/hooks.py | 6 ++++++ src/diffusers/pipelines/allegro/pipeline_allegro.py | 4 ++-- src/diffusers/pipelines/cogvideo/pipeline_cogvideox.py | 4 ++-- .../pipelines/cogvideo/pipeline_cogvideox_fun_control.py | 4 ++-- .../pipelines/cogvideo/pipeline_cogvideox_image2video.py | 4 ++-- .../pipelines/cogvideo/pipeline_cogvideox_video2video.py | 4 ++-- src/diffusers/pipelines/flux/pipeline_flux.py | 4 ++-- .../pipelines/hunyuan_video/pipeline_hunyuan_video.py | 4 ++-- src/diffusers/pipelines/latte/pipeline_latte.py | 4 ++-- src/diffusers/pipelines/mochi/pipeline_mochi.py | 4 ++-- 12 files changed, 28 insertions(+), 19 deletions(-) diff --git a/src/diffusers/__init__.py b/src/diffusers/__init__.py index 131fa1cc8821..fc0ff0989a93 100644 --- a/src/diffusers/__init__.py +++ b/src/diffusers/__init__.py @@ -28,6 +28,7 @@ _import_structure = { "configuration_utils": ["ConfigMixin"], + "hooks": [], "loaders": ["FromOriginalModelMixin"], "models": [], "pipelines": [], @@ -77,6 +78,7 @@ else: _import_structure["hooks"].extend( [ + "HookRegistry", "PyramidAttentionBroadcastConfig", "apply_pyramid_attention_broadcast", ] @@ -592,7 +594,7 @@ except OptionalDependencyNotAvailable: from .utils.dummy_pt_objects import * # noqa F403 else: - from .hooks import PyramidAttentionBroadcastConfig, apply_pyramid_attention_broadcast + from .hooks import HookRegistry, PyramidAttentionBroadcastConfig, apply_pyramid_attention_broadcast from .models import ( AllegroTransformer3DModel, AsymmetricAutoencoderKL, diff --git a/src/diffusers/hooks/__init__.py b/src/diffusers/hooks/__init__.py index 148e6c8fdc97..e6492646e45e 100644 --- a/src/diffusers/hooks/__init__.py +++ b/src/diffusers/hooks/__init__.py @@ -2,4 +2,5 @@ if is_torch_available(): + from .hooks import HookRegistry from .pyramid_attention_broadcast import PyramidAttentionBroadcastConfig, apply_pyramid_attention_broadcast diff --git a/src/diffusers/hooks/hooks.py b/src/diffusers/hooks/hooks.py index d9f257aac6c6..beffc34374b9 100644 --- a/src/diffusers/hooks/hooks.py +++ b/src/diffusers/hooks/hooks.py @@ -147,6 +147,12 @@ def remove_hook(self, name: str) -> None: del self.hooks[name] self._hook_order.remove(name) + def reset_stateful_hooks(self): + for hook_name in self._hook_order: + hook = self.hooks[hook_name] + if hook._is_stateful: + hook.reset_state(self._module_ref) + @classmethod def check_if_exists_or_initialize(cls, module: torch.nn.Module) -> "HookRegistry": if not hasattr(module, "_diffusers_hook"): diff --git a/src/diffusers/pipelines/allegro/pipeline_allegro.py b/src/diffusers/pipelines/allegro/pipeline_allegro.py index 36bf2a4a15da..a5270d2fd6a4 100644 --- a/src/diffusers/pipelines/allegro/pipeline_allegro.py +++ b/src/diffusers/pipelines/allegro/pipeline_allegro.py @@ -24,9 +24,9 @@ from transformers import T5EncoderModel, T5Tokenizer from ...callbacks import MultiPipelineCallbacks, PipelineCallback +from ...hooks import HookRegistry from ...models import AllegroTransformer3DModel, AutoencoderKLAllegro from ...models.embeddings import get_3d_rotary_pos_embed_allegro -from ...models.hooks import reset_stateful_hooks from ...pipelines.pipeline_utils import DiffusionPipeline from ...schedulers import KarrasDiffusionSchedulers from ...utils import ( @@ -948,7 +948,7 @@ def __call__( # Offload all models self.maybe_free_model_hooks() - reset_stateful_hooks(self.transformer, recurse=True) + HookRegistry.check_if_exists_or_initialize(self.transformer).reset_stateful_hooks() if not return_dict: return (video,) diff --git a/src/diffusers/pipelines/cogvideo/pipeline_cogvideox.py b/src/diffusers/pipelines/cogvideo/pipeline_cogvideox.py index 5f59cb874b10..ebe64b03442b 100644 --- a/src/diffusers/pipelines/cogvideo/pipeline_cogvideox.py +++ b/src/diffusers/pipelines/cogvideo/pipeline_cogvideox.py @@ -21,10 +21,10 @@ from transformers import T5EncoderModel, T5Tokenizer from ...callbacks import MultiPipelineCallbacks, PipelineCallback +from ...hooks import HookRegistry from ...loaders import CogVideoXLoraLoaderMixin from ...models import AutoencoderKLCogVideoX, CogVideoXTransformer3DModel from ...models.embeddings import get_3d_rotary_pos_embed -from ...models.hooks import reset_stateful_hooks from ...pipelines.pipeline_utils import DiffusionPipeline from ...schedulers import CogVideoXDDIMScheduler, CogVideoXDPMScheduler from ...utils import is_torch_xla_available, logging, replace_example_docstring @@ -778,7 +778,7 @@ def __call__( # Offload all models self.maybe_free_model_hooks() - reset_stateful_hooks(self.transformer, recurse=True) + HookRegistry.check_if_exists_or_initialize(self.transformer).reset_stateful_hooks() if not return_dict: return (video,) diff --git a/src/diffusers/pipelines/cogvideo/pipeline_cogvideox_fun_control.py b/src/diffusers/pipelines/cogvideo/pipeline_cogvideox_fun_control.py index 4a5848ebbb9f..23a9166521c7 100644 --- a/src/diffusers/pipelines/cogvideo/pipeline_cogvideox_fun_control.py +++ b/src/diffusers/pipelines/cogvideo/pipeline_cogvideox_fun_control.py @@ -22,10 +22,10 @@ from transformers import T5EncoderModel, T5Tokenizer from ...callbacks import MultiPipelineCallbacks, PipelineCallback +from ...hooks import HookRegistry from ...loaders import CogVideoXLoraLoaderMixin from ...models import AutoencoderKLCogVideoX, CogVideoXTransformer3DModel from ...models.embeddings import get_3d_rotary_pos_embed -from ...models.hooks import reset_stateful_hooks from ...pipelines.pipeline_utils import DiffusionPipeline from ...schedulers import KarrasDiffusionSchedulers from ...utils import is_torch_xla_available, logging, replace_example_docstring @@ -831,7 +831,7 @@ def __call__( # Offload all models self.maybe_free_model_hooks() - reset_stateful_hooks(self.transformer, recurse=True) + HookRegistry.check_if_exists_or_initialize(self.transformer).reset_stateful_hooks() if not return_dict: return (video,) diff --git a/src/diffusers/pipelines/cogvideo/pipeline_cogvideox_image2video.py b/src/diffusers/pipelines/cogvideo/pipeline_cogvideox_image2video.py index b84bf2f7b0a9..2589705d978c 100644 --- a/src/diffusers/pipelines/cogvideo/pipeline_cogvideox_image2video.py +++ b/src/diffusers/pipelines/cogvideo/pipeline_cogvideox_image2video.py @@ -22,11 +22,11 @@ from transformers import T5EncoderModel, T5Tokenizer from ...callbacks import MultiPipelineCallbacks, PipelineCallback +from ...hooks import HookRegistry from ...image_processor import PipelineImageInput from ...loaders import CogVideoXLoraLoaderMixin from ...models import AutoencoderKLCogVideoX, CogVideoXTransformer3DModel from ...models.embeddings import get_3d_rotary_pos_embed -from ...models.hooks import reset_stateful_hooks from ...pipelines.pipeline_utils import DiffusionPipeline from ...schedulers import CogVideoXDDIMScheduler, CogVideoXDPMScheduler from ...utils import ( @@ -892,7 +892,7 @@ def __call__( # Offload all models self.maybe_free_model_hooks() - reset_stateful_hooks(self.transformer, recurse=True) + HookRegistry.check_if_exists_or_initialize(self.transformer).reset_stateful_hooks() if not return_dict: return (video,) diff --git a/src/diffusers/pipelines/cogvideo/pipeline_cogvideox_video2video.py b/src/diffusers/pipelines/cogvideo/pipeline_cogvideox_video2video.py index 51d945f8d0a0..76cd4a05caeb 100644 --- a/src/diffusers/pipelines/cogvideo/pipeline_cogvideox_video2video.py +++ b/src/diffusers/pipelines/cogvideo/pipeline_cogvideox_video2video.py @@ -22,10 +22,10 @@ from transformers import T5EncoderModel, T5Tokenizer from ...callbacks import MultiPipelineCallbacks, PipelineCallback +from ...hooks import HookRegistry from ...loaders import CogVideoXLoraLoaderMixin from ...models import AutoencoderKLCogVideoX, CogVideoXTransformer3DModel from ...models.embeddings import get_3d_rotary_pos_embed -from ...models.hooks import reset_stateful_hooks from ...pipelines.pipeline_utils import DiffusionPipeline from ...schedulers import CogVideoXDDIMScheduler, CogVideoXDPMScheduler from ...utils import is_torch_xla_available, logging, replace_example_docstring @@ -857,7 +857,7 @@ def __call__( # Offload all models self.maybe_free_model_hooks() - reset_stateful_hooks(self.transformer, recurse=True) + HookRegistry.check_if_exists_or_initialize(self.transformer).reset_stateful_hooks() if not return_dict: return (video,) diff --git a/src/diffusers/pipelines/flux/pipeline_flux.py b/src/diffusers/pipelines/flux/pipeline_flux.py index e27473ac4e13..14d45c273b71 100644 --- a/src/diffusers/pipelines/flux/pipeline_flux.py +++ b/src/diffusers/pipelines/flux/pipeline_flux.py @@ -26,10 +26,10 @@ T5TokenizerFast, ) +from ...hooks import HookRegistry from ...image_processor import PipelineImageInput, VaeImageProcessor from ...loaders import FluxIPAdapterMixin, FluxLoraLoaderMixin, FromSingleFileMixin, TextualInversionLoaderMixin from ...models import AutoencoderKL, FluxTransformer2DModel -from ...models.hooks import reset_stateful_hooks from ...schedulers import FlowMatchEulerDiscreteScheduler from ...utils import ( USE_PEFT_BACKEND, @@ -971,7 +971,7 @@ def __call__( # Offload all models self.maybe_free_model_hooks() - reset_stateful_hooks(self.transformer, recurse=True) + HookRegistry.check_if_exists_or_initialize(self.transformer).reset_stateful_hooks() if not return_dict: return (image,) diff --git a/src/diffusers/pipelines/hunyuan_video/pipeline_hunyuan_video.py b/src/diffusers/pipelines/hunyuan_video/pipeline_hunyuan_video.py index c8e30183cb81..f29e461f9182 100644 --- a/src/diffusers/pipelines/hunyuan_video/pipeline_hunyuan_video.py +++ b/src/diffusers/pipelines/hunyuan_video/pipeline_hunyuan_video.py @@ -20,9 +20,9 @@ from transformers import CLIPTextModel, CLIPTokenizer, LlamaModel, LlamaTokenizerFast from ...callbacks import MultiPipelineCallbacks, PipelineCallback +from ...hooks import HookRegistry from ...loaders import HunyuanVideoLoraLoaderMixin from ...models import AutoencoderKLHunyuanVideo, HunyuanVideoTransformer3DModel -from ...models.hooks import reset_stateful_hooks from ...schedulers import FlowMatchEulerDiscreteScheduler from ...utils import is_torch_xla_available, logging, replace_example_docstring from ...utils.torch_utils import randn_tensor @@ -692,7 +692,7 @@ def __call__( # Offload all models self.maybe_free_model_hooks() - reset_stateful_hooks(self.transformer, recurse=True) + HookRegistry.check_if_exists_or_initialize(self.transformer).reset_stateful_hooks() if not return_dict: return (video,) diff --git a/src/diffusers/pipelines/latte/pipeline_latte.py b/src/diffusers/pipelines/latte/pipeline_latte.py index 95cb114d62e5..51e575bda3ce 100644 --- a/src/diffusers/pipelines/latte/pipeline_latte.py +++ b/src/diffusers/pipelines/latte/pipeline_latte.py @@ -24,8 +24,8 @@ from transformers import T5EncoderModel, T5Tokenizer from ...callbacks import MultiPipelineCallbacks, PipelineCallback +from ...hooks import HookRegistry from ...models import AutoencoderKL, LatteTransformer3DModel -from ...models.hooks import reset_stateful_hooks from ...pipelines.pipeline_utils import DiffusionPipeline from ...schedulers import KarrasDiffusionSchedulers from ...utils import ( @@ -861,7 +861,7 @@ def __call__( # Offload all models self.maybe_free_model_hooks() - reset_stateful_hooks(self.transformer, recurse=True) + HookRegistry.check_if_exists_or_initialize(self.transformer).reset_stateful_hooks() if not return_dict: return (video,) diff --git a/src/diffusers/pipelines/mochi/pipeline_mochi.py b/src/diffusers/pipelines/mochi/pipeline_mochi.py index 6c224ae97c26..ec2a77a69256 100644 --- a/src/diffusers/pipelines/mochi/pipeline_mochi.py +++ b/src/diffusers/pipelines/mochi/pipeline_mochi.py @@ -20,9 +20,9 @@ from transformers import T5EncoderModel, T5TokenizerFast from ...callbacks import MultiPipelineCallbacks, PipelineCallback +from ...hooks import HookRegistry from ...loaders import Mochi1LoraLoaderMixin from ...models import AutoencoderKLHunyuanVideo, MochiTransformer3DModel -from ...models.hooks import reset_stateful_hooks from ...schedulers import FlowMatchEulerDiscreteScheduler from ...utils import ( is_torch_xla_available, @@ -734,7 +734,7 @@ def __call__( # Offload all models self.maybe_free_model_hooks() - reset_stateful_hooks(self.transformer, recurse=True) + HookRegistry.check_if_exists_or_initialize(self.transformer).reset_stateful_hooks() if not return_dict: return (video,) From 1040c911e4f760e40a857d533a6e72dd19962ac5 Mon Sep 17 00:00:00 2001 From: Aryan Date: Mon, 13 Jan 2025 08:12:09 +0100 Subject: [PATCH 39/61] more fixes --- src/diffusers/hooks/hooks.py | 17 ++++++++--- .../hooks/pyramid_attention_broadcast.py | 28 +++++++++++++------ src/diffusers/models/attention_processor.py | 2 -- 3 files changed, 32 insertions(+), 15 deletions(-) diff --git a/src/diffusers/hooks/hooks.py b/src/diffusers/hooks/hooks.py index beffc34374b9..d9bcddb1b2bc 100644 --- a/src/diffusers/hooks/hooks.py +++ b/src/diffusers/hooks/hooks.py @@ -121,7 +121,12 @@ def register_hook(self, hook: ModelHook, name: str) -> None: self._module_ref = hook.initialize_hook(self._module_ref) if hasattr(hook, "new_forward"): - new_forward = hook.new_forward + rewritten_forward = hook.new_forward + + def new_forward(module, *args, **kwargs): + args, kwargs = hook.pre_forward(module, *args, **kwargs) + output = rewritten_forward(module, *args, **kwargs) + return hook.post_forward(module, output) else: def new_forward(module, *args, **kwargs): @@ -129,8 +134,7 @@ def new_forward(module, *args, **kwargs): output = old_forward(*args, **kwargs) return hook.post_forward(module, output) - new_forward = functools.update_wrapper(new_forward, old_forward) - self._module_ref.forward = new_forward.__get__(self._module_ref) + self._module_ref.forward = functools.update_wrapper(functools.partial(new_forward, self._module_ref), old_forward) self.hooks[name] = hook self._hook_order.append(name) @@ -147,11 +151,16 @@ def remove_hook(self, name: str) -> None: del self.hooks[name] self._hook_order.remove(name) - def reset_stateful_hooks(self): + def reset_stateful_hooks(self, recurse: bool = True) -> None: for hook_name in self._hook_order: hook = self.hooks[hook_name] if hook._is_stateful: hook.reset_state(self._module_ref) + + if recurse: + for module in self._module_ref.modules(): + if hasattr(module, "_diffusers_hook"): + module._diffusers_hook.reset_stateful_hooks(recurse=False) @classmethod def check_if_exists_or_initialize(cls, module: torch.nn.Module) -> "HookRegistry": diff --git a/src/diffusers/hooks/pyramid_attention_broadcast.py b/src/diffusers/hooks/pyramid_attention_broadcast.py index 3f0080165927..2ba71da4d203 100644 --- a/src/diffusers/hooks/pyramid_attention_broadcast.py +++ b/src/diffusers/hooks/pyramid_attention_broadcast.py @@ -106,6 +106,14 @@ def __init__(self) -> None: def reset(self): self.iteration = 0 self.cache = None + + def __repr__(self): + cache_repr = "" + if self.cache is None: + cache_repr = "None" + else: + cache_repr = f"Tensor(shape={self.cache.shape}, dtype={self.cache.dtype})" + return f"PyramidAttentionBroadcastState(iteration={self.iteration}, cache={cache_repr})" class PyramidAttentionBroadcastHook(ModelHook): @@ -120,21 +128,21 @@ def __init__(self, skip_callback: Callable[[torch.nn.Module], bool]) -> None: def initialize_hook(self, module): self.state = PyramidAttentionBroadcastState() + return module def new_forward(self, module: torch.nn.Module, *args, **kwargs) -> Any: - args, kwargs = module._diffusers_hook.pre_forward(module, *args, **kwargs) - if self.skip_callback(module): - output = module._pyramid_attention_broadcast_state.cache + output = self.state.cache else: output = module._old_forward(*args, **kwargs) self.state.cache = output self.state.iteration += 1 - return module._diffusers_hook.post_forward(module, output) + return output def reset_state(self, module: torch.nn.Module) -> None: - module.state.reset() + self.state.reset() + return module def apply_pyramid_attention_broadcast( @@ -168,7 +176,7 @@ def apply_pyramid_attention_broadcast( >>> config = PyramidAttentionBroadcastConfig( ... spatial_attention_block_skip_range=2, spatial_attention_timestep_skip_range=(100, 800) ... ) - >>> apply_pyramid_attention_broadcast(pipe, config) + >>> apply_pyramid_attention_broadcast(pipe.transformer, config) ``` """ if config.current_timestep_callback is None: @@ -192,9 +200,9 @@ def apply_pyramid_attention_broadcast( if not isinstance(submodule, _ATTENTION_CLASSES): continue if isinstance(submodule, Attention): - _apply_pyramid_attention_broadcast_on_attention_class(name, module, config) + _apply_pyramid_attention_broadcast_on_attention_class(name, submodule, config) if isinstance(submodule, MochiAttention): - _apply_pyramid_attention_broadcast_on_mochi_attention_class(name, module, config) + _apply_pyramid_attention_broadcast_on_mochi_attention_class(name, submodule, config) def _apply_pyramid_attention_broadcast_on_attention_class( @@ -241,7 +249,9 @@ def _apply_pyramid_attention_broadcast_on_attention_class( return False def skip_callback(module: torch.nn.Module) -> bool: - pab_state = module._pyramid_attention_broadcast_state + hook: PyramidAttentionBroadcastHook = module._diffusers_hook.get_hook("pyramid_attention_broadcast") + pab_state: PyramidAttentionBroadcastState = hook.state + if pab_state.cache is None: return False diff --git a/src/diffusers/models/attention_processor.py b/src/diffusers/models/attention_processor.py index c91ddd95c861..4d7ae6bef26e 100644 --- a/src/diffusers/models/attention_processor.py +++ b/src/diffusers/models/attention_processor.py @@ -930,8 +930,6 @@ def __init__( self.out_dim = out_dim if out_dim is not None else query_dim self.out_context_dim = out_context_dim if out_context_dim else query_dim self.context_pre_only = context_pre_only - # TODO(aryan): Maybe try to improve the checks in PAB instead - self.is_cross_attention = False self.heads = out_dim // dim_head if out_dim is not None else heads From ffbabb598b2cb77f15253e5b4a57e0b88ad3f9ad Mon Sep 17 00:00:00 2001 From: Aryan Date: Mon, 13 Jan 2025 08:12:22 +0100 Subject: [PATCH 40/61] make style --- src/diffusers/hooks/hooks.py | 6 ++++-- src/diffusers/hooks/pyramid_attention_broadcast.py | 4 ++-- 2 files changed, 6 insertions(+), 4 deletions(-) diff --git a/src/diffusers/hooks/hooks.py b/src/diffusers/hooks/hooks.py index d9bcddb1b2bc..ee52e11b8b6d 100644 --- a/src/diffusers/hooks/hooks.py +++ b/src/diffusers/hooks/hooks.py @@ -134,7 +134,9 @@ def new_forward(module, *args, **kwargs): output = old_forward(*args, **kwargs) return hook.post_forward(module, output) - self._module_ref.forward = functools.update_wrapper(functools.partial(new_forward, self._module_ref), old_forward) + self._module_ref.forward = functools.update_wrapper( + functools.partial(new_forward, self._module_ref), old_forward + ) self.hooks[name] = hook self._hook_order.append(name) @@ -156,7 +158,7 @@ def reset_stateful_hooks(self, recurse: bool = True) -> None: hook = self.hooks[hook_name] if hook._is_stateful: hook.reset_state(self._module_ref) - + if recurse: for module in self._module_ref.modules(): if hasattr(module, "_diffusers_hook"): diff --git a/src/diffusers/hooks/pyramid_attention_broadcast.py b/src/diffusers/hooks/pyramid_attention_broadcast.py index 2ba71da4d203..19b9674a5a90 100644 --- a/src/diffusers/hooks/pyramid_attention_broadcast.py +++ b/src/diffusers/hooks/pyramid_attention_broadcast.py @@ -106,7 +106,7 @@ def __init__(self) -> None: def reset(self): self.iteration = 0 self.cache = None - + def __repr__(self): cache_repr = "" if self.cache is None: @@ -251,7 +251,7 @@ def _apply_pyramid_attention_broadcast_on_attention_class( def skip_callback(module: torch.nn.Module) -> bool: hook: PyramidAttentionBroadcastHook = module._diffusers_hook.get_hook("pyramid_attention_broadcast") pab_state: PyramidAttentionBroadcastState = hook.state - + if pab_state.cache is None: return False From 1b92b1dd30f4194f082a0570be6be8674a4fbb2c Mon Sep 17 00:00:00 2001 From: Aryan Date: Mon, 13 Jan 2025 08:31:29 +0100 Subject: [PATCH 41/61] update tests --- tests/pipelines/test_pipelines_common.py | 37 ++++++++++++++---------- 1 file changed, 22 insertions(+), 15 deletions(-) diff --git a/tests/pipelines/test_pipelines_common.py b/tests/pipelines/test_pipelines_common.py index f09f3f8cc426..cb5e3fe5cabe 100644 --- a/tests/pipelines/test_pipelines_common.py +++ b/tests/pipelines/test_pipelines_common.py @@ -30,6 +30,7 @@ UNet2DConditionModel, apply_pyramid_attention_broadcast, ) +from diffusers.hooks.pyramid_attention_broadcast import PyramidAttentionBroadcastHook from diffusers.image_processor import VaeImageProcessor from diffusers.loaders import FluxIPAdapterMixin, IPAdapterMixin from diffusers.models.attention_processor import AttnProcessor @@ -38,7 +39,6 @@ from diffusers.models.unets.unet_i2vgen_xl import I2VGenXLUNet from diffusers.models.unets.unet_motion_model import UNetMotionModel from diffusers.pipelines.pipeline_utils import StableDiffusionMixin -from diffusers.pipelines.pyramid_attention_broadcast_utils import PyramidAttentionBroadcastHook from diffusers.schedulers import KarrasDiffusionSchedulers from diffusers.utils import logging from diffusers.utils.import_utils import is_xformers_available @@ -2298,7 +2298,9 @@ def test_pyramid_attention_broadcast_layers(self): pipe = self.pipeline_class(**components) pipe.set_progress_bar_config(disable=None) - apply_pyramid_attention_broadcast(pipe, self.pab_config) + self.pab_config.current_timestep_callback = lambda: pipe._current_timestep + denoiser = pipe.transformer if hasattr(pipe, "transformer") else pipe.unet + apply_pyramid_attention_broadcast(denoiser, self.pab_config) expected_hooks = 0 if self.pab_config.spatial_attention_block_skip_range is not None: @@ -2312,30 +2314,30 @@ def test_pyramid_attention_broadcast_layers(self): count = 0 for module in denoiser.modules(): if hasattr(module, "_diffusers_hook"): + hook = module._diffusers_hook.get_hook("pyramid_attention_broadcast") + if hook is None: + continue count += 1 self.assertTrue( - isinstance(module._diffusers_hook, PyramidAttentionBroadcastHook), + isinstance(hook, PyramidAttentionBroadcastHook), "Hook should be of type PyramidAttentionBroadcastHook.", ) - self.assertTrue( - hasattr(module, "_pyramid_attention_broadcast_state"), - "PAB state should be initialized when enabled.", - ) - self.assertTrue( - module._pyramid_attention_broadcast_state.cache is None, "Cache should be None at initialization." - ) + self.assertTrue(hook.state.cache is None, "Cache should be None at initialization.") self.assertEqual(count, expected_hooks, "Number of hooks should match the expected number.") # Perform dummy inference step to ensure state is updated def pab_state_check_callback(pipe, i, t, kwargs): for module in denoiser.modules(): if hasattr(module, "_diffusers_hook"): + hook = module._diffusers_hook.get_hook("pyramid_attention_broadcast") + if hook is None: + continue self.assertTrue( - module._pyramid_attention_broadcast_state.cache is not None, + hook.state.cache is not None, "Cache should have updated during inference.", ) self.assertTrue( - module._pyramid_attention_broadcast_state.iteration == i + 1, + hook.state.iteration == i + 1, "Hook iteration state should have updated during inference.", ) return {} @@ -2348,12 +2350,15 @@ def pab_state_check_callback(pipe, i, t, kwargs): # After inference, reset_stateful_hooks is called within the pipeline, which should have reset the states for module in denoiser.modules(): if hasattr(module, "_diffusers_hook"): + hook = module._diffusers_hook.get_hook("pyramid_attention_broadcast") + if hook is None: + continue self.assertTrue( - module._pyramid_attention_broadcast_state.cache is None, + hook.state.cache is None, "Cache should be reset to None after inference.", ) self.assertTrue( - module._pyramid_attention_broadcast_state.iteration == 0, + hook.state.iteration == 0, "Iteration should be reset to 0 after inference.", ) @@ -2374,7 +2379,9 @@ def test_pyramid_attention_broadcast_inference(self, expected_atol: float = 0.2) original_image_slice = output.flatten() original_image_slice = np.concatenate((original_image_slice[:8], original_image_slice[-8:])) - apply_pyramid_attention_broadcast(pipe, self.pab_config) + self.pab_config.current_timestep_callback = lambda: pipe._current_timestep + denoiser = pipe.transformer if hasattr(pipe, "transformer") else pipe.unet + apply_pyramid_attention_broadcast(denoiser, self.pab_config) inputs = self.get_dummy_inputs(device) inputs["num_inference_steps"] = 4 From 88d917dc4fe537f4a142d14fd4e7e31281ffe6cd Mon Sep 17 00:00:00 2001 From: Aryan Date: Mon, 13 Jan 2025 08:31:55 +0100 Subject: [PATCH 42/61] update code example --- src/diffusers/hooks/hooks.py | 6 +++--- src/diffusers/hooks/pyramid_attention_broadcast.py | 5 ++++- 2 files changed, 7 insertions(+), 4 deletions(-) diff --git a/src/diffusers/hooks/hooks.py b/src/diffusers/hooks/hooks.py index ee52e11b8b6d..9d027b28cbda 100644 --- a/src/diffusers/hooks/hooks.py +++ b/src/diffusers/hooks/hooks.py @@ -13,7 +13,7 @@ # limitations under the License. import functools -from typing import Any, Dict, Tuple +from typing import Any, Dict, Optional, Tuple import torch @@ -141,9 +141,9 @@ def new_forward(module, *args, **kwargs): self.hooks[name] = hook self._hook_order.append(name) - def get_hook(self, name: str) -> ModelHook: + def get_hook(self, name: str) -> Optional[ModelHook]: if name not in self.hooks.keys(): - raise ValueError(f"Hook with name {name} not found.") + return None return self.hooks[name] def remove_hook(self, name: str) -> None: diff --git a/src/diffusers/hooks/pyramid_attention_broadcast.py b/src/diffusers/hooks/pyramid_attention_broadcast.py index 19b9674a5a90..e0dce4544581 100644 --- a/src/diffusers/hooks/pyramid_attention_broadcast.py +++ b/src/diffusers/hooks/pyramid_attention_broadcast.py @@ -169,12 +169,15 @@ def apply_pyramid_attention_broadcast( ```python >>> import torch >>> from diffusers import CogVideoXPipeline, PyramidAttentionBroadcastConfig, apply_pyramid_attention_broadcast + >>> from diffusers.utils import export_to_video >>> pipe = CogVideoXPipeline.from_pretrained("THUDM/CogVideoX-5b", torch_dtype=torch.bfloat16) >>> pipe.to("cuda") >>> config = PyramidAttentionBroadcastConfig( - ... spatial_attention_block_skip_range=2, spatial_attention_timestep_skip_range=(100, 800) + ... spatial_attention_block_skip_range=2, + ... spatial_attention_timestep_skip_range=(100, 800), + ... current_timestep_callback=lambda: pipe._current_timestep, ... ) >>> apply_pyramid_attention_broadcast(pipe.transformer, config) ``` From e4d8b12d15c48816e10d0387a1d7a026c426c880 Mon Sep 17 00:00:00 2001 From: Aryan Date: Mon, 13 Jan 2025 08:44:16 +0100 Subject: [PATCH 43/61] make fix-copies --- src/diffusers/utils/dummy_pt_objects.py | 15 +++++++++++++++ 1 file changed, 15 insertions(+) diff --git a/src/diffusers/utils/dummy_pt_objects.py b/src/diffusers/utils/dummy_pt_objects.py index 1be509c19014..88da30a5a7c7 100644 --- a/src/diffusers/utils/dummy_pt_objects.py +++ b/src/diffusers/utils/dummy_pt_objects.py @@ -2,6 +2,21 @@ from ..utils import DummyObject, requires_backends +class HookRegistry(metaclass=DummyObject): + _backends = ["torch"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch"]) + + @classmethod + def from_config(cls, *args, **kwargs): + requires_backends(cls, ["torch"]) + + @classmethod + def from_pretrained(cls, *args, **kwargs): + requires_backends(cls, ["torch"]) + + class PyramidAttentionBroadcastConfig(metaclass=DummyObject): _backends = ["torch"] From ae8bd9948be0246462056dd22b2b0716d4d81915 Mon Sep 17 00:00:00 2001 From: Aryan Date: Wed, 15 Jan 2025 09:22:30 +0100 Subject: [PATCH 44/61] refactor based on reviews --- .../hooks/pyramid_attention_broadcast.py | 17 +++++------------ 1 file changed, 5 insertions(+), 12 deletions(-) diff --git a/src/diffusers/hooks/pyramid_attention_broadcast.py b/src/diffusers/hooks/pyramid_attention_broadcast.py index e0dce4544581..ab782a8287df 100644 --- a/src/diffusers/hooks/pyramid_attention_broadcast.py +++ b/src/diffusers/hooks/pyramid_attention_broadcast.py @@ -201,11 +201,12 @@ def apply_pyramid_attention_broadcast( for name, submodule in module.named_modules(): if not isinstance(submodule, _ATTENTION_CLASSES): + # PAB has been implemented specific to Diffusers' Attention classes. However, this does not mean that PAB + # cannot be applied to this layer. For custom layers, users can extend this functionality and implement + # their own PAB logic similar to `_apply_pyramid_attention_broadcast_on_attention_class`. continue - if isinstance(submodule, Attention): + if isinstance(submodule, (Attention, MochiAttention)): _apply_pyramid_attention_broadcast_on_attention_class(name, submodule, config) - if isinstance(submodule, MochiAttention): - _apply_pyramid_attention_broadcast_on_mochi_attention_class(name, submodule, config) def _apply_pyramid_attention_broadcast_on_attention_class( @@ -246,8 +247,7 @@ def _apply_pyramid_attention_broadcast_on_attention_class( f'Unable to apply Pyramid Attention Broadcast to the selected layer: "{name}" because it does ' f"not match any of the required criteria for spatial, temporal or cross attention layers. Note, " f"however, that this layer may still be valid for applying PAB. Please specify the correct " - f"block identifiers in the configuration or use the specialized `apply_pyramid_attention_broadcast_on_module` " - f"function to apply PAB to this layer." + f"block identifiers in the configuration." ) return False @@ -272,13 +272,6 @@ def skip_callback(module: torch.nn.Module) -> bool: return True -def _apply_pyramid_attention_broadcast_on_mochi_attention_class( - name: str, module: MochiAttention, config: PyramidAttentionBroadcastConfig -) -> bool: - # The same logic as Attention class works here, so just use that for now - return _apply_pyramid_attention_broadcast_on_attention_class(name, module, config) - - def _apply_pyramid_attention_broadcast( module: Union[Attention, MochiAttention], skip_callback: Callable[[torch.nn.Module], bool], From a9ee5a42d66ff07e18d35ca39104e4aabc6a6e90 Mon Sep 17 00:00:00 2001 From: Aryan Date: Wed, 15 Jan 2025 09:29:17 +0100 Subject: [PATCH 45/61] use maybe_free_model_hooks --- src/diffusers/pipelines/allegro/pipeline_allegro.py | 2 -- src/diffusers/pipelines/cogvideo/pipeline_cogvideox.py | 2 -- .../pipelines/cogvideo/pipeline_cogvideox_fun_control.py | 2 -- .../pipelines/cogvideo/pipeline_cogvideox_image2video.py | 2 -- .../pipelines/cogvideo/pipeline_cogvideox_video2video.py | 2 -- src/diffusers/pipelines/flux/pipeline_flux.py | 2 -- .../pipelines/hunyuan_video/pipeline_hunyuan_video.py | 2 -- src/diffusers/pipelines/latte/pipeline_latte.py | 2 -- src/diffusers/pipelines/mochi/pipeline_mochi.py | 2 -- src/diffusers/pipelines/pipeline_utils.py | 5 +++++ 10 files changed, 5 insertions(+), 18 deletions(-) diff --git a/src/diffusers/pipelines/allegro/pipeline_allegro.py b/src/diffusers/pipelines/allegro/pipeline_allegro.py index a5270d2fd6a4..83d3e8470e63 100644 --- a/src/diffusers/pipelines/allegro/pipeline_allegro.py +++ b/src/diffusers/pipelines/allegro/pipeline_allegro.py @@ -24,7 +24,6 @@ from transformers import T5EncoderModel, T5Tokenizer from ...callbacks import MultiPipelineCallbacks, PipelineCallback -from ...hooks import HookRegistry from ...models import AllegroTransformer3DModel, AutoencoderKLAllegro from ...models.embeddings import get_3d_rotary_pos_embed_allegro from ...pipelines.pipeline_utils import DiffusionPipeline @@ -948,7 +947,6 @@ def __call__( # Offload all models self.maybe_free_model_hooks() - HookRegistry.check_if_exists_or_initialize(self.transformer).reset_stateful_hooks() if not return_dict: return (video,) diff --git a/src/diffusers/pipelines/cogvideo/pipeline_cogvideox.py b/src/diffusers/pipelines/cogvideo/pipeline_cogvideox.py index ebe64b03442b..dc688ee86160 100644 --- a/src/diffusers/pipelines/cogvideo/pipeline_cogvideox.py +++ b/src/diffusers/pipelines/cogvideo/pipeline_cogvideox.py @@ -21,7 +21,6 @@ from transformers import T5EncoderModel, T5Tokenizer from ...callbacks import MultiPipelineCallbacks, PipelineCallback -from ...hooks import HookRegistry from ...loaders import CogVideoXLoraLoaderMixin from ...models import AutoencoderKLCogVideoX, CogVideoXTransformer3DModel from ...models.embeddings import get_3d_rotary_pos_embed @@ -778,7 +777,6 @@ def __call__( # Offload all models self.maybe_free_model_hooks() - HookRegistry.check_if_exists_or_initialize(self.transformer).reset_stateful_hooks() if not return_dict: return (video,) diff --git a/src/diffusers/pipelines/cogvideo/pipeline_cogvideox_fun_control.py b/src/diffusers/pipelines/cogvideo/pipeline_cogvideox_fun_control.py index 23a9166521c7..aef5c24ab667 100644 --- a/src/diffusers/pipelines/cogvideo/pipeline_cogvideox_fun_control.py +++ b/src/diffusers/pipelines/cogvideo/pipeline_cogvideox_fun_control.py @@ -22,7 +22,6 @@ from transformers import T5EncoderModel, T5Tokenizer from ...callbacks import MultiPipelineCallbacks, PipelineCallback -from ...hooks import HookRegistry from ...loaders import CogVideoXLoraLoaderMixin from ...models import AutoencoderKLCogVideoX, CogVideoXTransformer3DModel from ...models.embeddings import get_3d_rotary_pos_embed @@ -831,7 +830,6 @@ def __call__( # Offload all models self.maybe_free_model_hooks() - HookRegistry.check_if_exists_or_initialize(self.transformer).reset_stateful_hooks() if not return_dict: return (video,) diff --git a/src/diffusers/pipelines/cogvideo/pipeline_cogvideox_image2video.py b/src/diffusers/pipelines/cogvideo/pipeline_cogvideox_image2video.py index 2589705d978c..2e43d8535e43 100644 --- a/src/diffusers/pipelines/cogvideo/pipeline_cogvideox_image2video.py +++ b/src/diffusers/pipelines/cogvideo/pipeline_cogvideox_image2video.py @@ -22,7 +22,6 @@ from transformers import T5EncoderModel, T5Tokenizer from ...callbacks import MultiPipelineCallbacks, PipelineCallback -from ...hooks import HookRegistry from ...image_processor import PipelineImageInput from ...loaders import CogVideoXLoraLoaderMixin from ...models import AutoencoderKLCogVideoX, CogVideoXTransformer3DModel @@ -892,7 +891,6 @@ def __call__( # Offload all models self.maybe_free_model_hooks() - HookRegistry.check_if_exists_or_initialize(self.transformer).reset_stateful_hooks() if not return_dict: return (video,) diff --git a/src/diffusers/pipelines/cogvideo/pipeline_cogvideox_video2video.py b/src/diffusers/pipelines/cogvideo/pipeline_cogvideox_video2video.py index 76cd4a05caeb..a53fbe651958 100644 --- a/src/diffusers/pipelines/cogvideo/pipeline_cogvideox_video2video.py +++ b/src/diffusers/pipelines/cogvideo/pipeline_cogvideox_video2video.py @@ -22,7 +22,6 @@ from transformers import T5EncoderModel, T5Tokenizer from ...callbacks import MultiPipelineCallbacks, PipelineCallback -from ...hooks import HookRegistry from ...loaders import CogVideoXLoraLoaderMixin from ...models import AutoencoderKLCogVideoX, CogVideoXTransformer3DModel from ...models.embeddings import get_3d_rotary_pos_embed @@ -857,7 +856,6 @@ def __call__( # Offload all models self.maybe_free_model_hooks() - HookRegistry.check_if_exists_or_initialize(self.transformer).reset_stateful_hooks() if not return_dict: return (video,) diff --git a/src/diffusers/pipelines/flux/pipeline_flux.py b/src/diffusers/pipelines/flux/pipeline_flux.py index 14d45c273b71..c747841a798a 100644 --- a/src/diffusers/pipelines/flux/pipeline_flux.py +++ b/src/diffusers/pipelines/flux/pipeline_flux.py @@ -26,7 +26,6 @@ T5TokenizerFast, ) -from ...hooks import HookRegistry from ...image_processor import PipelineImageInput, VaeImageProcessor from ...loaders import FluxIPAdapterMixin, FluxLoraLoaderMixin, FromSingleFileMixin, TextualInversionLoaderMixin from ...models import AutoencoderKL, FluxTransformer2DModel @@ -971,7 +970,6 @@ def __call__( # Offload all models self.maybe_free_model_hooks() - HookRegistry.check_if_exists_or_initialize(self.transformer).reset_stateful_hooks() if not return_dict: return (image,) diff --git a/src/diffusers/pipelines/hunyuan_video/pipeline_hunyuan_video.py b/src/diffusers/pipelines/hunyuan_video/pipeline_hunyuan_video.py index f29e461f9182..0f776334eae0 100644 --- a/src/diffusers/pipelines/hunyuan_video/pipeline_hunyuan_video.py +++ b/src/diffusers/pipelines/hunyuan_video/pipeline_hunyuan_video.py @@ -20,7 +20,6 @@ from transformers import CLIPTextModel, CLIPTokenizer, LlamaModel, LlamaTokenizerFast from ...callbacks import MultiPipelineCallbacks, PipelineCallback -from ...hooks import HookRegistry from ...loaders import HunyuanVideoLoraLoaderMixin from ...models import AutoencoderKLHunyuanVideo, HunyuanVideoTransformer3DModel from ...schedulers import FlowMatchEulerDiscreteScheduler @@ -692,7 +691,6 @@ def __call__( # Offload all models self.maybe_free_model_hooks() - HookRegistry.check_if_exists_or_initialize(self.transformer).reset_stateful_hooks() if not return_dict: return (video,) diff --git a/src/diffusers/pipelines/latte/pipeline_latte.py b/src/diffusers/pipelines/latte/pipeline_latte.py index 35e1f775ebc0..6ee18e0d1406 100644 --- a/src/diffusers/pipelines/latte/pipeline_latte.py +++ b/src/diffusers/pipelines/latte/pipeline_latte.py @@ -24,7 +24,6 @@ from transformers import T5EncoderModel, T5Tokenizer from ...callbacks import MultiPipelineCallbacks, PipelineCallback -from ...hooks import HookRegistry from ...models import AutoencoderKL, LatteTransformer3DModel from ...pipelines.pipeline_utils import DiffusionPipeline from ...schedulers import KarrasDiffusionSchedulers @@ -869,7 +868,6 @@ def __call__( # Offload all models self.maybe_free_model_hooks() - HookRegistry.check_if_exists_or_initialize(self.transformer).reset_stateful_hooks() if not return_dict: return (video,) diff --git a/src/diffusers/pipelines/mochi/pipeline_mochi.py b/src/diffusers/pipelines/mochi/pipeline_mochi.py index ec2a77a69256..ed3ea3e59f25 100644 --- a/src/diffusers/pipelines/mochi/pipeline_mochi.py +++ b/src/diffusers/pipelines/mochi/pipeline_mochi.py @@ -20,7 +20,6 @@ from transformers import T5EncoderModel, T5TokenizerFast from ...callbacks import MultiPipelineCallbacks, PipelineCallback -from ...hooks import HookRegistry from ...loaders import Mochi1LoraLoaderMixin from ...models import AutoencoderKLHunyuanVideo, MochiTransformer3DModel from ...schedulers import FlowMatchEulerDiscreteScheduler @@ -734,7 +733,6 @@ def __call__( # Offload all models self.maybe_free_model_hooks() - HookRegistry.check_if_exists_or_initialize(self.transformer).reset_stateful_hooks() if not return_dict: return (video,) diff --git a/src/diffusers/pipelines/pipeline_utils.py b/src/diffusers/pipelines/pipeline_utils.py index 3cafb77e5d63..022b56cc9af6 100644 --- a/src/diffusers/pipelines/pipeline_utils.py +++ b/src/diffusers/pipelines/pipeline_utils.py @@ -44,6 +44,7 @@ from .. import __version__ from ..configuration_utils import ConfigMixin +from ..hooks import HookRegistry from ..models import AutoencoderKL from ..models.attention_processor import FusedAttnProcessor2_0 from ..models.modeling_utils import _LOW_CPU_MEM_USAGE_DEFAULT, ModelMixin @@ -1138,6 +1139,10 @@ def maybe_free_model_hooks(self): is a no-op. Make sure to add this function to the end of the `__call__` function of your pipeline so that it functions correctly when applying enable_model_cpu_offload. """ + for name, component in self.components.items(): + if name in ("transformer", "unet"): + HookRegistry.check_if_exists_or_initialize(component).reset_stateful_hooks(recurse=True) + if not hasattr(self, "_all_hooks") or len(self._all_hooks) == 0: # `enable_model_cpu_offload` has not be called, so silently do nothing return From 1a596883329d8fddd930ed1f0a861c36d0c9f282 Mon Sep 17 00:00:00 2001 From: Aryan Date: Wed, 15 Jan 2025 09:50:59 +0100 Subject: [PATCH 46/61] CacheMixin --- src/diffusers/models/cache_utils.py | 48 +++++++++++++++++++ .../transformers/cogvideox_transformer_3d.py | 3 +- .../transformers/latte_transformer_3d.py | 3 +- .../transformers/transformer_allegro.py | 3 +- .../models/transformers/transformer_flux.py | 3 +- .../transformers/transformer_hunyuan_video.py | 3 +- .../models/transformers/transformer_mochi.py | 3 +- src/diffusers/pipelines/pipeline_utils.py | 13 +++-- 8 files changed, 69 insertions(+), 10 deletions(-) create mode 100644 src/diffusers/models/cache_utils.py diff --git a/src/diffusers/models/cache_utils.py b/src/diffusers/models/cache_utils.py new file mode 100644 index 000000000000..cc6cc3e290b7 --- /dev/null +++ b/src/diffusers/models/cache_utils.py @@ -0,0 +1,48 @@ +# Copyright 2024 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from typing import Union + +from ..hooks import HookRegistry, PyramidAttentionBroadcastConfig, apply_pyramid_attention_broadcast + + +CacheConfig = Union[PyramidAttentionBroadcastConfig] + + +class CacheMixin: + _cache_config: CacheConfig = None + + @property + def is_cache_enabled(self) -> bool: + return self._cache_config is not None + + def enable_cache(self, config: CacheConfig) -> None: + if isinstance(config, PyramidAttentionBroadcastConfig): + apply_pyramid_attention_broadcast(self.model, config) + else: + raise ValueError(f"Cache config {type(config)} is not supported.") + self._cache_config = config + + def disable_cache(self) -> None: + if self._cache_config is None: + raise ValueError("Caching techniques have not been enabled.") + if isinstance(self._cache_config, PyramidAttentionBroadcastConfig): + registry = HookRegistry.check_if_exists_or_initialize(self) + registry.remove_hook("pyramid_attention_broadcast") + else: + raise ValueError(f"Cache config {type(self._cache_config)} is not supported.") + self._cache_config = None + + def reset_stateful_cache(self, recurse: bool = True) -> None: + HookRegistry.check_if_exists_or_initialize(self).reset_stateful_hooks(recurse=recurse) diff --git a/src/diffusers/models/transformers/cogvideox_transformer_3d.py b/src/diffusers/models/transformers/cogvideox_transformer_3d.py index 51634780692d..f6edaceace20 100644 --- a/src/diffusers/models/transformers/cogvideox_transformer_3d.py +++ b/src/diffusers/models/transformers/cogvideox_transformer_3d.py @@ -24,6 +24,7 @@ from ...utils.torch_utils import maybe_allow_in_graph from ..attention import Attention, FeedForward from ..attention_processor import AttentionProcessor, CogVideoXAttnProcessor2_0, FusedCogVideoXAttnProcessor2_0 +from ..cache_utils import CacheMixin from ..embeddings import CogVideoXPatchEmbed, TimestepEmbedding, Timesteps from ..modeling_outputs import Transformer2DModelOutput from ..modeling_utils import ModelMixin @@ -156,7 +157,7 @@ def forward( return hidden_states, encoder_hidden_states -class CogVideoXTransformer3DModel(ModelMixin, ConfigMixin, PeftAdapterMixin): +class CogVideoXTransformer3DModel(ModelMixin, ConfigMixin, PeftAdapterMixin, CacheMixin): """ A Transformer model for video-like data in [CogVideoX](https://github.com/THUDM/CogVideo). diff --git a/src/diffusers/models/transformers/latte_transformer_3d.py b/src/diffusers/models/transformers/latte_transformer_3d.py index cc066e8fe3cc..40b78c6a71be 100644 --- a/src/diffusers/models/transformers/latte_transformer_3d.py +++ b/src/diffusers/models/transformers/latte_transformer_3d.py @@ -20,13 +20,14 @@ from ...configuration_utils import ConfigMixin, register_to_config from ...models.embeddings import PixArtAlphaTextProjection, get_1d_sincos_pos_embed_from_grid from ..attention import BasicTransformerBlock +from ..cache_utils import CacheMixin from ..embeddings import PatchEmbed from ..modeling_outputs import Transformer2DModelOutput from ..modeling_utils import ModelMixin from ..normalization import AdaLayerNormSingle -class LatteTransformer3DModel(ModelMixin, ConfigMixin): +class LatteTransformer3DModel(ModelMixin, ConfigMixin, CacheMixin): _supports_gradient_checkpointing = True """ diff --git a/src/diffusers/models/transformers/transformer_allegro.py b/src/diffusers/models/transformers/transformer_allegro.py index 81039fd49e0d..702952c6a5d4 100644 --- a/src/diffusers/models/transformers/transformer_allegro.py +++ b/src/diffusers/models/transformers/transformer_allegro.py @@ -24,6 +24,7 @@ from ...utils.torch_utils import maybe_allow_in_graph from ..attention import FeedForward from ..attention_processor import AllegroAttnProcessor2_0, Attention +from ..cache_utils import CacheMixin from ..embeddings import PatchEmbed, PixArtAlphaTextProjection from ..modeling_outputs import Transformer2DModelOutput from ..modeling_utils import ModelMixin @@ -172,7 +173,7 @@ def forward( return hidden_states -class AllegroTransformer3DModel(ModelMixin, ConfigMixin): +class AllegroTransformer3DModel(ModelMixin, ConfigMixin, CacheMixin): _supports_gradient_checkpointing = True """ diff --git a/src/diffusers/models/transformers/transformer_flux.py b/src/diffusers/models/transformers/transformer_flux.py index f5e92700b2f3..ec36c18fb032 100644 --- a/src/diffusers/models/transformers/transformer_flux.py +++ b/src/diffusers/models/transformers/transformer_flux.py @@ -35,6 +35,7 @@ from ...utils import USE_PEFT_BACKEND, is_torch_version, logging, scale_lora_layers, unscale_lora_layers from ...utils.import_utils import is_torch_npu_available from ...utils.torch_utils import maybe_allow_in_graph +from ..cache_utils import CacheMixin from ..embeddings import CombinedTimestepGuidanceTextProjEmbeddings, CombinedTimestepTextProjEmbeddings, FluxPosEmbed from ..modeling_outputs import Transformer2DModelOutput @@ -227,7 +228,7 @@ def forward( class FluxTransformer2DModel( - ModelMixin, ConfigMixin, PeftAdapterMixin, FromOriginalModelMixin, FluxTransformer2DLoadersMixin + ModelMixin, ConfigMixin, PeftAdapterMixin, FromOriginalModelMixin, FluxTransformer2DLoadersMixin, CacheMixin ): """ The Transformer model introduced in Flux. diff --git a/src/diffusers/models/transformers/transformer_hunyuan_video.py b/src/diffusers/models/transformers/transformer_hunyuan_video.py index 4495623119e5..7e72c8f6a7a1 100644 --- a/src/diffusers/models/transformers/transformer_hunyuan_video.py +++ b/src/diffusers/models/transformers/transformer_hunyuan_video.py @@ -25,6 +25,7 @@ from ...utils import USE_PEFT_BACKEND, is_torch_version, logging, scale_lora_layers, unscale_lora_layers from ..attention import FeedForward from ..attention_processor import Attention, AttentionProcessor +from ..cache_utils import CacheMixin from ..embeddings import ( CombinedTimestepGuidanceTextProjEmbeddings, CombinedTimestepTextProjEmbeddings, @@ -502,7 +503,7 @@ def forward( return hidden_states, encoder_hidden_states -class HunyuanVideoTransformer3DModel(ModelMixin, ConfigMixin, PeftAdapterMixin, FromOriginalModelMixin): +class HunyuanVideoTransformer3DModel(ModelMixin, ConfigMixin, PeftAdapterMixin, FromOriginalModelMixin, CacheMixin): r""" A Transformer model for video-like data used in [HunyuanVideo](https://huggingface.co/tencent/HunyuanVideo). diff --git a/src/diffusers/models/transformers/transformer_mochi.py b/src/diffusers/models/transformers/transformer_mochi.py index 8763ea450253..53b5a436e636 100644 --- a/src/diffusers/models/transformers/transformer_mochi.py +++ b/src/diffusers/models/transformers/transformer_mochi.py @@ -25,6 +25,7 @@ from ...utils.torch_utils import maybe_allow_in_graph from ..attention import FeedForward from ..attention_processor import MochiAttention, MochiAttnProcessor2_0 +from ..cache_utils import CacheMixin from ..embeddings import MochiCombinedTimestepCaptionEmbedding, PatchEmbed from ..modeling_outputs import Transformer2DModelOutput from ..modeling_utils import ModelMixin @@ -305,7 +306,7 @@ def forward( @maybe_allow_in_graph -class MochiTransformer3DModel(ModelMixin, ConfigMixin, PeftAdapterMixin, FromOriginalModelMixin): +class MochiTransformer3DModel(ModelMixin, ConfigMixin, PeftAdapterMixin, FromOriginalModelMixin, CacheMixin): r""" A Transformer model for video-like data introduced in [Mochi](https://huggingface.co/genmo/mochi-1-preview). diff --git a/src/diffusers/pipelines/pipeline_utils.py b/src/diffusers/pipelines/pipeline_utils.py index 022b56cc9af6..d3f064e190d9 100644 --- a/src/diffusers/pipelines/pipeline_utils.py +++ b/src/diffusers/pipelines/pipeline_utils.py @@ -1134,10 +1134,15 @@ def enable_model_cpu_offload(self, gpu_id: Optional[int] = None, device: Union[t def maybe_free_model_hooks(self): r""" - Function that offloads all components, removes all model hooks that were added when using - `enable_model_cpu_offload` and then applies them again. In case the model has not been offloaded this function - is a no-op. Make sure to add this function to the end of the `__call__` function of your pipeline so that it - functions correctly when applying enable_model_cpu_offload. + Method that performs the following: + - Offloads all components. + - Removes all model hooks that were added when using `enable_model_cpu_offload`, and then applies them again. + In case the model has not been offloaded, this function is a no-op. + - Resets stateful diffusers hooks of denoiser components if they were added with + [`~hooks.HookRegistry.register_hook`]. + + Make sure to add this function to the end of the `__call__` function of your pipeline so that it functions + correctly when applying `enable_model_cpu_offload`. """ for name, component in self.components.items(): if name in ("transformer", "unet"): From c8616a6dabff60213eb97de745b5a826798d3913 Mon Sep 17 00:00:00 2001 From: Aryan Date: Wed, 15 Jan 2025 09:56:06 +0100 Subject: [PATCH 47/61] make style --- .../hooks/pyramid_attention_broadcast.py | 16 ++++++++++++++++ src/diffusers/models/cache_utils.py | 5 ++++- 2 files changed, 20 insertions(+), 1 deletion(-) diff --git a/src/diffusers/hooks/pyramid_attention_broadcast.py b/src/diffusers/hooks/pyramid_attention_broadcast.py index ab782a8287df..c06fef39a3e3 100644 --- a/src/diffusers/hooks/pyramid_attention_broadcast.py +++ b/src/diffusers/hooks/pyramid_attention_broadcast.py @@ -85,6 +85,22 @@ class PyramidAttentionBroadcastConfig: # TODO(aryan): add PAB for MLP layers (very limited speedup from testing with original codebase # so not added for now) + def __repr__(self) -> str: + return ( + f"PyramidAttentionBroadcastConfig(" + f" spatial_attention_block_skip_range={self.spatial_attention_block_skip_range}, " + f" temporal_attention_block_skip_range={self.temporal_attention_block_skip_range}, " + f" cross_attention_block_skip_range={self.cross_attention_block_skip_range}, " + f" spatial_attention_timestep_skip_range={self.spatial_attention_timestep_skip_range}, " + f" temporal_attention_timestep_skip_range={self.temporal_attention_timestep_skip_range}, " + f" cross_attention_timestep_skip_range={self.cross_attention_timestep_skip_range}, " + f" spatial_attention_block_identifiers={self.spatial_attention_block_identifiers}, " + f" temporal_attention_block_identifiers={self.temporal_attention_block_identifiers}, " + f" cross_attention_block_identifiers={self.cross_attention_block_identifiers}, " + f" current_timestep_callback={self.current_timestep_callback}" + ")" + ) + class PyramidAttentionBroadcastState: r""" diff --git a/src/diffusers/models/cache_utils.py b/src/diffusers/models/cache_utils.py index cc6cc3e290b7..2583c296409c 100644 --- a/src/diffusers/models/cache_utils.py +++ b/src/diffusers/models/cache_utils.py @@ -32,17 +32,20 @@ def enable_cache(self, config: CacheConfig) -> None: apply_pyramid_attention_broadcast(self.model, config) else: raise ValueError(f"Cache config {type(config)} is not supported.") + self._cache_config = config def disable_cache(self) -> None: if self._cache_config is None: raise ValueError("Caching techniques have not been enabled.") + if isinstance(self._cache_config, PyramidAttentionBroadcastConfig): registry = HookRegistry.check_if_exists_or_initialize(self) registry.remove_hook("pyramid_attention_broadcast") else: raise ValueError(f"Cache config {type(self._cache_config)} is not supported.") + self._cache_config = None - def reset_stateful_cache(self, recurse: bool = True) -> None: + def _reset_stateful_cache(self, recurse: bool = True) -> None: HookRegistry.check_if_exists_or_initialize(self).reset_stateful_hooks(recurse=recurse) From 08a209d1e11e784f00d2019f892a69e01c369744 Mon Sep 17 00:00:00 2001 From: Aryan Date: Wed, 15 Jan 2025 10:08:02 +0100 Subject: [PATCH 48/61] update --- .../hooks/pyramid_attention_broadcast.py | 20 +++++++++---------- src/diffusers/models/cache_utils.py | 9 ++++++++- src/diffusers/pipelines/pipeline_utils.py | 5 ++--- 3 files changed, 20 insertions(+), 14 deletions(-) diff --git a/src/diffusers/hooks/pyramid_attention_broadcast.py b/src/diffusers/hooks/pyramid_attention_broadcast.py index c06fef39a3e3..12b81385dcba 100644 --- a/src/diffusers/hooks/pyramid_attention_broadcast.py +++ b/src/diffusers/hooks/pyramid_attention_broadcast.py @@ -88,16 +88,16 @@ class PyramidAttentionBroadcastConfig: def __repr__(self) -> str: return ( f"PyramidAttentionBroadcastConfig(" - f" spatial_attention_block_skip_range={self.spatial_attention_block_skip_range}, " - f" temporal_attention_block_skip_range={self.temporal_attention_block_skip_range}, " - f" cross_attention_block_skip_range={self.cross_attention_block_skip_range}, " - f" spatial_attention_timestep_skip_range={self.spatial_attention_timestep_skip_range}, " - f" temporal_attention_timestep_skip_range={self.temporal_attention_timestep_skip_range}, " - f" cross_attention_timestep_skip_range={self.cross_attention_timestep_skip_range}, " - f" spatial_attention_block_identifiers={self.spatial_attention_block_identifiers}, " - f" temporal_attention_block_identifiers={self.temporal_attention_block_identifiers}, " - f" cross_attention_block_identifiers={self.cross_attention_block_identifiers}, " - f" current_timestep_callback={self.current_timestep_callback}" + f" spatial_attention_block_skip_range={self.spatial_attention_block_skip_range},\n" + f" temporal_attention_block_skip_range={self.temporal_attention_block_skip_range},\n" + f" cross_attention_block_skip_range={self.cross_attention_block_skip_range},\n" + f" spatial_attention_timestep_skip_range={self.spatial_attention_timestep_skip_range},\n" + f" temporal_attention_timestep_skip_range={self.temporal_attention_timestep_skip_range},\n" + f" cross_attention_timestep_skip_range={self.cross_attention_timestep_skip_range},\n" + f" spatial_attention_block_identifiers={self.spatial_attention_block_identifiers},\n" + f" temporal_attention_block_identifiers={self.temporal_attention_block_identifiers},\n" + f" cross_attention_block_identifiers={self.cross_attention_block_identifiers},\n" + f" current_timestep_callback={self.current_timestep_callback}\n" ")" ) diff --git a/src/diffusers/models/cache_utils.py b/src/diffusers/models/cache_utils.py index 2583c296409c..da00155ab6c7 100644 --- a/src/diffusers/models/cache_utils.py +++ b/src/diffusers/models/cache_utils.py @@ -21,6 +21,13 @@ class CacheMixin: + r""" + A class for enable/disabling caching techniques on diffusion models. + + Supported caching techniques: + - [Pyramid Attention Broadcast](https://huggingface.co/papers/2408.12588) + """ + _cache_config: CacheConfig = None @property @@ -29,7 +36,7 @@ def is_cache_enabled(self) -> bool: def enable_cache(self, config: CacheConfig) -> None: if isinstance(config, PyramidAttentionBroadcastConfig): - apply_pyramid_attention_broadcast(self.model, config) + apply_pyramid_attention_broadcast(self, config) else: raise ValueError(f"Cache config {type(config)} is not supported.") diff --git a/src/diffusers/pipelines/pipeline_utils.py b/src/diffusers/pipelines/pipeline_utils.py index d3f064e190d9..c696c050f934 100644 --- a/src/diffusers/pipelines/pipeline_utils.py +++ b/src/diffusers/pipelines/pipeline_utils.py @@ -44,7 +44,6 @@ from .. import __version__ from ..configuration_utils import ConfigMixin -from ..hooks import HookRegistry from ..models import AutoencoderKL from ..models.attention_processor import FusedAttnProcessor2_0 from ..models.modeling_utils import _LOW_CPU_MEM_USAGE_DEFAULT, ModelMixin @@ -1145,8 +1144,8 @@ def maybe_free_model_hooks(self): correctly when applying `enable_model_cpu_offload`. """ for name, component in self.components.items(): - if name in ("transformer", "unet"): - HookRegistry.check_if_exists_or_initialize(component).reset_stateful_hooks(recurse=True) + if name in ("transformer", "unet") and hasattr(component, "_reset_stateful_cache"): + component._reset_stateful_cache() if not hasattr(self, "_all_hooks") or len(self._all_hooks) == 0: # `enable_model_cpu_offload` has not be called, so silently do nothing From 15e645df1564d4c45da70c07cb2da7973eb7bc83 Mon Sep 17 00:00:00 2001 From: Aryan Date: Wed, 15 Jan 2025 10:35:00 +0100 Subject: [PATCH 49/61] add current_timestep property; update docs --- docs/source/en/api/cache.md | 23 +++++++++++++++ src/diffusers/__init__.py | 2 ++ .../hooks/pyramid_attention_broadcast.py | 2 +- src/diffusers/models/__init__.py | 2 ++ src/diffusers/models/cache_utils.py | 28 ++++++++++++++++++- .../pipelines/allegro/pipeline_allegro.py | 4 +++ .../pipelines/cogvideo/pipeline_cogvideox.py | 4 +++ .../pipeline_cogvideox_fun_control.py | 4 +++ .../pipeline_cogvideox_image2video.py | 4 +++ .../pipeline_cogvideox_video2video.py | 4 +++ src/diffusers/pipelines/flux/pipeline_flux.py | 4 +++ .../hunyuan_video/pipeline_hunyuan_video.py | 4 +++ .../pipelines/latte/pipeline_latte.py | 4 +++ .../pipelines/mochi/pipeline_mochi.py | 4 +++ 14 files changed, 91 insertions(+), 2 deletions(-) diff --git a/docs/source/en/api/cache.md b/docs/source/en/api/cache.md index fd35ccad3cf3..403dbf88b431 100644 --- a/docs/source/en/api/cache.md +++ b/docs/source/en/api/cache.md @@ -19,6 +19,29 @@ Pyramid Attention Broadcast (PAB) is a method that speeds up inference in diffus Enable PAB with [`~PyramidAttentionBroadcastConfig`] on any pipeline. For some benchmarks, refer to [this](https://github.com/huggingface/diffusers/pull/9562) pull request. +```python +import torch +from diffusers import CogVideoXPipeline, PyramidAttentionBroadcastConfig + +pipe = CogVideoXPipeline.from_pretrained("THUDM/CogVideoX-5b", torch_dtype=torch.bfloat16) +pipe.to("cuda") + +# Increasing the value of `spatial_attention_timestep_skip_range[0]` or decreasing the value of +# `spatial_attention_timestep_skip_range[1]` will decrease the interval in which pyramid attention +# broadcast is active, leader to slower inference speeds. However, large intervals can lead to +# poorer quality of generated videos. +config = PyramidAttentionBroadcastConfig( + spatial_attention_block_skip_range=2, + spatial_attention_timestep_skip_range=(100, 800), + current_timestep_callback=lambda: pipe.current_timestep, +) +pipe.transformer.enable_cache(config) +``` + +### CacheMixin + +[[autodoc]] CacheMixin + ### PyramidAttentionBroadcastConfig [[autodoc]] PyramidAttentionBroadcastConfig diff --git a/src/diffusers/__init__.py b/src/diffusers/__init__.py index fc0ff0989a93..4b88273147a1 100644 --- a/src/diffusers/__init__.py +++ b/src/diffusers/__init__.py @@ -98,6 +98,7 @@ "AutoencoderKLTemporalDecoder", "AutoencoderOobleck", "AutoencoderTiny", + "CacheMixin", "CogVideoXTransformer3DModel", "CogView3PlusTransformer2DModel", "ConsistencyDecoderVAE", @@ -609,6 +610,7 @@ AutoencoderKLTemporalDecoder, AutoencoderOobleck, AutoencoderTiny, + CacheMixin, CogVideoXTransformer3DModel, CogView3PlusTransformer2DModel, ConsistencyDecoderVAE, diff --git a/src/diffusers/hooks/pyramid_attention_broadcast.py b/src/diffusers/hooks/pyramid_attention_broadcast.py index 12b81385dcba..d5ee4568d8bf 100644 --- a/src/diffusers/hooks/pyramid_attention_broadcast.py +++ b/src/diffusers/hooks/pyramid_attention_broadcast.py @@ -193,7 +193,7 @@ def apply_pyramid_attention_broadcast( >>> config = PyramidAttentionBroadcastConfig( ... spatial_attention_block_skip_range=2, ... spatial_attention_timestep_skip_range=(100, 800), - ... current_timestep_callback=lambda: pipe._current_timestep, + ... current_timestep_callback=lambda: pipe.current_timestep, ... ) >>> apply_pyramid_attention_broadcast(pipe.transformer, config) ``` diff --git a/src/diffusers/models/__init__.py b/src/diffusers/models/__init__.py index 01e67b01d91a..e351c0f8fa3b 100644 --- a/src/diffusers/models/__init__.py +++ b/src/diffusers/models/__init__.py @@ -39,6 +39,7 @@ _import_structure["autoencoders.autoencoder_tiny"] = ["AutoencoderTiny"] _import_structure["autoencoders.consistency_decoder_vae"] = ["ConsistencyDecoderVAE"] _import_structure["autoencoders.vq_model"] = ["VQModel"] + _import_structure["cache_utils"] = ["CacheMixin"] _import_structure["controlnets.controlnet"] = ["ControlNetModel"] _import_structure["controlnets.controlnet_flux"] = ["FluxControlNetModel", "FluxMultiControlNetModel"] _import_structure["controlnets.controlnet_hunyuan"] = [ @@ -108,6 +109,7 @@ ConsistencyDecoderVAE, VQModel, ) + from .cache_utils import CacheMixin from .controlnets import ( ControlNetModel, ControlNetUnionModel, diff --git a/src/diffusers/models/cache_utils.py b/src/diffusers/models/cache_utils.py index da00155ab6c7..0cc4fc8129dd 100644 --- a/src/diffusers/models/cache_utils.py +++ b/src/diffusers/models/cache_utils.py @@ -25,7 +25,7 @@ class CacheMixin: A class for enable/disabling caching techniques on diffusion models. Supported caching techniques: - - [Pyramid Attention Broadcast](https://huggingface.co/papers/2408.12588) + - [Pyramid Attention Broadcast](https://huggingface.co/papers/2408.12588) """ _cache_config: CacheConfig = None @@ -35,6 +35,32 @@ def is_cache_enabled(self) -> bool: return self._cache_config is not None def enable_cache(self, config: CacheConfig) -> None: + r""" + Enable caching techniques on the model. + + Args: + config (`Union[PyramidAttentionBroadcastConfig]`): + The configuration for applying the caching technique. Currently supported caching techniques are: + - `PyramidAttentionBroadcastConfig` + + Example: + + ```python + >>> import torch + >>> from diffusers import CogVideoXPipeline, PyramidAttentionBroadcastConfig + + >>> pipe = CogVideoXPipeline.from_pretrained("THUDM/CogVideoX-5b", torch_dtype=torch.bfloat16) + >>> pipe.to("cuda") + + >>> config = PyramidAttentionBroadcastConfig( + ... spatial_attention_block_skip_range=2, + ... spatial_attention_timestep_skip_range=(100, 800), + ... current_timestep_callback=lambda: pipe.current_timestep, + ... ) + >>> pipe.transformer.enable_cache(config) + ``` + """ + if isinstance(config, PyramidAttentionBroadcastConfig): apply_pyramid_attention_broadcast(self, config) else: diff --git a/src/diffusers/pipelines/allegro/pipeline_allegro.py b/src/diffusers/pipelines/allegro/pipeline_allegro.py index 83d3e8470e63..cb36a7a672de 100644 --- a/src/diffusers/pipelines/allegro/pipeline_allegro.py +++ b/src/diffusers/pipelines/allegro/pipeline_allegro.py @@ -683,6 +683,10 @@ def guidance_scale(self): def num_timesteps(self): return self._num_timesteps + @property + def current_timestep(self): + return self._current_timestep + @property def interrupt(self): return self._interrupt diff --git a/src/diffusers/pipelines/cogvideo/pipeline_cogvideox.py b/src/diffusers/pipelines/cogvideo/pipeline_cogvideox.py index dc688ee86160..99ae9025cd3e 100644 --- a/src/diffusers/pipelines/cogvideo/pipeline_cogvideox.py +++ b/src/diffusers/pipelines/cogvideo/pipeline_cogvideox.py @@ -494,6 +494,10 @@ def num_timesteps(self): def attention_kwargs(self): return self._attention_kwargs + @property + def current_timestep(self): + return self._current_timestep + @property def interrupt(self): return self._interrupt diff --git a/src/diffusers/pipelines/cogvideo/pipeline_cogvideox_fun_control.py b/src/diffusers/pipelines/cogvideo/pipeline_cogvideox_fun_control.py index aef5c24ab667..e37574ec9cb2 100644 --- a/src/diffusers/pipelines/cogvideo/pipeline_cogvideox_fun_control.py +++ b/src/diffusers/pipelines/cogvideo/pipeline_cogvideox_fun_control.py @@ -540,6 +540,10 @@ def num_timesteps(self): def attention_kwargs(self): return self._attention_kwargs + @property + def current_timestep(self): + return self._current_timestep + @property def interrupt(self): return self._interrupt diff --git a/src/diffusers/pipelines/cogvideo/pipeline_cogvideox_image2video.py b/src/diffusers/pipelines/cogvideo/pipeline_cogvideox_image2video.py index 2e43d8535e43..59d7c4cad547 100644 --- a/src/diffusers/pipelines/cogvideo/pipeline_cogvideox_image2video.py +++ b/src/diffusers/pipelines/cogvideo/pipeline_cogvideox_image2video.py @@ -591,6 +591,10 @@ def num_timesteps(self): def attention_kwargs(self): return self._attention_kwargs + @property + def current_timestep(self): + return self._current_timestep + @property def interrupt(self): return self._interrupt diff --git a/src/diffusers/pipelines/cogvideo/pipeline_cogvideox_video2video.py b/src/diffusers/pipelines/cogvideo/pipeline_cogvideox_video2video.py index a53fbe651958..c4dc7e574f7e 100644 --- a/src/diffusers/pipelines/cogvideo/pipeline_cogvideox_video2video.py +++ b/src/diffusers/pipelines/cogvideo/pipeline_cogvideox_video2video.py @@ -564,6 +564,10 @@ def num_timesteps(self): def attention_kwargs(self): return self._attention_kwargs + @property + def current_timestep(self): + return self._current_timestep + @property def interrupt(self): return self._interrupt diff --git a/src/diffusers/pipelines/flux/pipeline_flux.py b/src/diffusers/pipelines/flux/pipeline_flux.py index c747841a798a..aa02dc1de5da 100644 --- a/src/diffusers/pipelines/flux/pipeline_flux.py +++ b/src/diffusers/pipelines/flux/pipeline_flux.py @@ -619,6 +619,10 @@ def joint_attention_kwargs(self): def num_timesteps(self): return self._num_timesteps + @property + def current_timestep(self): + return self._current_timestep + @property def interrupt(self): return self._interrupt diff --git a/src/diffusers/pipelines/hunyuan_video/pipeline_hunyuan_video.py b/src/diffusers/pipelines/hunyuan_video/pipeline_hunyuan_video.py index 0f776334eae0..8cc77ed4c148 100644 --- a/src/diffusers/pipelines/hunyuan_video/pipeline_hunyuan_video.py +++ b/src/diffusers/pipelines/hunyuan_video/pipeline_hunyuan_video.py @@ -456,6 +456,10 @@ def num_timesteps(self): def attention_kwargs(self): return self._attention_kwargs + @property + def current_timestep(self): + return self._current_timestep + @property def interrupt(self): return self._interrupt diff --git a/src/diffusers/pipelines/latte/pipeline_latte.py b/src/diffusers/pipelines/latte/pipeline_latte.py index 6ee18e0d1406..8a960716258c 100644 --- a/src/diffusers/pipelines/latte/pipeline_latte.py +++ b/src/diffusers/pipelines/latte/pipeline_latte.py @@ -602,6 +602,10 @@ def do_classifier_free_guidance(self): def num_timesteps(self): return self._num_timesteps + @property + def current_timestep(self): + return self._current_timestep + @property def interrupt(self): return self._interrupt diff --git a/src/diffusers/pipelines/mochi/pipeline_mochi.py b/src/diffusers/pipelines/mochi/pipeline_mochi.py index ed3ea3e59f25..8283b57fc304 100644 --- a/src/diffusers/pipelines/mochi/pipeline_mochi.py +++ b/src/diffusers/pipelines/mochi/pipeline_mochi.py @@ -466,6 +466,10 @@ def num_timesteps(self): def attention_kwargs(self): return self._attention_kwargs + @property + def current_timestep(self): + return self._current_timestep + @property def interrupt(self): return self._interrupt From d6ce4ab1252f4c711044293dcc9e7d27fa41f317 Mon Sep 17 00:00:00 2001 From: Aryan Date: Wed, 15 Jan 2025 10:35:15 +0100 Subject: [PATCH 50/61] make fix-copies --- src/diffusers/utils/dummy_pt_objects.py | 15 +++++++++++++++ 1 file changed, 15 insertions(+) diff --git a/src/diffusers/utils/dummy_pt_objects.py b/src/diffusers/utils/dummy_pt_objects.py index 88da30a5a7c7..cc7e1a3ca344 100644 --- a/src/diffusers/utils/dummy_pt_objects.py +++ b/src/diffusers/utils/dummy_pt_objects.py @@ -231,6 +231,21 @@ def from_pretrained(cls, *args, **kwargs): requires_backends(cls, ["torch"]) +class CacheMixin(metaclass=DummyObject): + _backends = ["torch"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch"]) + + @classmethod + def from_config(cls, *args, **kwargs): + requires_backends(cls, ["torch"]) + + @classmethod + def from_pretrained(cls, *args, **kwargs): + requires_backends(cls, ["torch"]) + + class CogVideoXTransformer3DModel(metaclass=DummyObject): _backends = ["torch"] From 96fce86fb3eb71303bab2927be48596f72fe1b3c Mon Sep 17 00:00:00 2001 From: Aryan Date: Wed, 15 Jan 2025 11:23:26 +0100 Subject: [PATCH 51/61] update --- src/diffusers/hooks/hooks.py | 23 ++++++++++++++++------- src/diffusers/models/cache_utils.py | 4 ++-- 2 files changed, 18 insertions(+), 9 deletions(-) diff --git a/src/diffusers/hooks/hooks.py b/src/diffusers/hooks/hooks.py index 9d027b28cbda..05589f4cb835 100644 --- a/src/diffusers/hooks/hooks.py +++ b/src/diffusers/hooks/hooks.py @@ -146,12 +146,19 @@ def get_hook(self, name: str) -> Optional[ModelHook]: return None return self.hooks[name] - def remove_hook(self, name: str) -> None: - if name not in self.hooks.keys(): - raise ValueError(f"Hook with name {name} not found.") - self.hooks[name].deinitalize_hook(self._module_ref) - del self.hooks[name] - self._hook_order.remove(name) + def remove_hook(self, name: str, recurse: bool = True) -> None: + if name in self.hooks.keys(): + hook = self.hooks[name] + self._module_ref = hook.deinitalize_hook(self._module_ref) + del self.hooks[name] + self._hook_order.remove(name) + + if recurse: + for module_name, module in self._module_ref.named_modules(): + if module_name == "": + continue + if hasattr(module, "_diffusers_hook"): + module._diffusers_hook.remove_hook(name, recurse=False) def reset_stateful_hooks(self, recurse: bool = True) -> None: for hook_name in self._hook_order: @@ -160,7 +167,9 @@ def reset_stateful_hooks(self, recurse: bool = True) -> None: hook.reset_state(self._module_ref) if recurse: - for module in self._module_ref.modules(): + for module_name, module in self._module_ref.named_modules(): + if module_name == "": + continue if hasattr(module, "_diffusers_hook"): module._diffusers_hook.reset_stateful_hooks(recurse=False) diff --git a/src/diffusers/models/cache_utils.py b/src/diffusers/models/cache_utils.py index 0cc4fc8129dd..75e491d0be95 100644 --- a/src/diffusers/models/cache_utils.py +++ b/src/diffusers/models/cache_utils.py @@ -41,7 +41,7 @@ def enable_cache(self, config: CacheConfig) -> None: Args: config (`Union[PyramidAttentionBroadcastConfig]`): The configuration for applying the caching technique. Currently supported caching techniques are: - - `PyramidAttentionBroadcastConfig` + - [`~hooks.PyramidAttentionBroadcastConfig`] Example: @@ -74,7 +74,7 @@ def disable_cache(self) -> None: if isinstance(self._cache_config, PyramidAttentionBroadcastConfig): registry = HookRegistry.check_if_exists_or_initialize(self) - registry.remove_hook("pyramid_attention_broadcast") + registry.remove_hook("pyramid_attention_broadcast", recurse=True) else: raise ValueError(f"Cache config {type(self._cache_config)} is not supported.") From 107e375254961c0921aa42c799fa47d249763774 Mon Sep 17 00:00:00 2001 From: Aryan Date: Wed, 15 Jan 2025 11:23:44 +0100 Subject: [PATCH 52/61] improve tests --- src/diffusers/hooks/hooks.py | 2 +- tests/pipelines/test_pipelines_common.py | 19 ++++++++++++++++--- 2 files changed, 17 insertions(+), 4 deletions(-) diff --git a/src/diffusers/hooks/hooks.py b/src/diffusers/hooks/hooks.py index 05589f4cb835..bef4c65c41e1 100644 --- a/src/diffusers/hooks/hooks.py +++ b/src/diffusers/hooks/hooks.py @@ -152,7 +152,7 @@ def remove_hook(self, name: str, recurse: bool = True) -> None: self._module_ref = hook.deinitalize_hook(self._module_ref) del self.hooks[name] self._hook_order.remove(name) - + if recurse: for module_name, module in self._module_ref.named_modules(): if module_name == "": diff --git a/tests/pipelines/test_pipelines_common.py b/tests/pipelines/test_pipelines_common.py index db4955b5099e..95d6a1b4747b 100644 --- a/tests/pipelines/test_pipelines_common.py +++ b/tests/pipelines/test_pipelines_common.py @@ -28,7 +28,6 @@ StableDiffusionPipeline, StableDiffusionXLPipeline, UNet2DConditionModel, - apply_pyramid_attention_broadcast, ) from diffusers.hooks.pyramid_attention_broadcast import PyramidAttentionBroadcastHook from diffusers.image_processor import VaeImageProcessor @@ -2337,7 +2336,7 @@ def test_pyramid_attention_broadcast_layers(self): self.pab_config.current_timestep_callback = lambda: pipe._current_timestep denoiser = pipe.transformer if hasattr(pipe, "transformer") else pipe.unet - apply_pyramid_attention_broadcast(denoiser, self.pab_config) + denoiser.enable_cache(self.pab_config) expected_hooks = 0 if self.pab_config.spatial_attention_block_skip_range is not None: @@ -2410,15 +2409,17 @@ def test_pyramid_attention_broadcast_inference(self, expected_atol: float = 0.2) pipe = pipe.to(device) pipe.set_progress_bar_config(disable=None) + # Run inference without PAB inputs = self.get_dummy_inputs(device) inputs["num_inference_steps"] = 4 output = pipe(**inputs)[0] original_image_slice = output.flatten() original_image_slice = np.concatenate((original_image_slice[:8], original_image_slice[-8:])) + # Run inference with PAB enabled self.pab_config.current_timestep_callback = lambda: pipe._current_timestep denoiser = pipe.transformer if hasattr(pipe, "transformer") else pipe.unet - apply_pyramid_attention_broadcast(denoiser, self.pab_config) + denoiser.enable_cache(self.pab_config) inputs = self.get_dummy_inputs(device) inputs["num_inference_steps"] = 4 @@ -2426,9 +2427,21 @@ def test_pyramid_attention_broadcast_inference(self, expected_atol: float = 0.2) image_slice_pab_enabled = output.flatten() image_slice_pab_enabled = np.concatenate((image_slice_pab_enabled[:8], image_slice_pab_enabled[-8:])) + # Run inference with PAB disabled + denoiser.disable_cache() + + inputs = self.get_dummy_inputs(device) + inputs["num_inference_steps"] = 4 + output = pipe(**inputs)[0] + image_slice_pab_disabled = output.flatten() + image_slice_pab_disabled = np.concatenate((image_slice_pab_disabled[:8], image_slice_pab_disabled[-8:])) + assert np.allclose( original_image_slice, image_slice_pab_enabled, atol=expected_atol ), "PAB outputs should not differ much in specified timestep range." + assert np.allclose( + original_image_slice, image_slice_pab_disabled, atol=1e-4 + ), "Outputs from normal inference and after disabling cache should not differ." # Some models (e.g. unCLIP) are extremely likely to significantly deviate depending on which hardware is used. From 40fc7a5010f988c620fa219c44b21654202de917 Mon Sep 17 00:00:00 2001 From: Aryan Date: Wed, 15 Jan 2025 11:30:23 +0100 Subject: [PATCH 53/61] try circular import fix --- src/diffusers/models/cache_utils.py | 17 ++++++++--------- 1 file changed, 8 insertions(+), 9 deletions(-) diff --git a/src/diffusers/models/cache_utils.py b/src/diffusers/models/cache_utils.py index 75e491d0be95..c02299a89dec 100644 --- a/src/diffusers/models/cache_utils.py +++ b/src/diffusers/models/cache_utils.py @@ -12,13 +12,6 @@ # See the License for the specific language governing permissions and # limitations under the License. -from typing import Union - -from ..hooks import HookRegistry, PyramidAttentionBroadcastConfig, apply_pyramid_attention_broadcast - - -CacheConfig = Union[PyramidAttentionBroadcastConfig] - class CacheMixin: r""" @@ -28,13 +21,13 @@ class CacheMixin: - [Pyramid Attention Broadcast](https://huggingface.co/papers/2408.12588) """ - _cache_config: CacheConfig = None + _cache_config = None @property def is_cache_enabled(self) -> bool: return self._cache_config is not None - def enable_cache(self, config: CacheConfig) -> None: + def enable_cache(self, config) -> None: r""" Enable caching techniques on the model. @@ -61,6 +54,8 @@ def enable_cache(self, config: CacheConfig) -> None: ``` """ + from ..hooks import PyramidAttentionBroadcastConfig, apply_pyramid_attention_broadcast + if isinstance(config, PyramidAttentionBroadcastConfig): apply_pyramid_attention_broadcast(self, config) else: @@ -69,6 +64,8 @@ def enable_cache(self, config: CacheConfig) -> None: self._cache_config = config def disable_cache(self) -> None: + from ..hooks import HookRegistry, PyramidAttentionBroadcastConfig + if self._cache_config is None: raise ValueError("Caching techniques have not been enabled.") @@ -81,4 +78,6 @@ def disable_cache(self) -> None: self._cache_config = None def _reset_stateful_cache(self, recurse: bool = True) -> None: + from ..hooks import HookRegistry + HookRegistry.check_if_exists_or_initialize(self).reset_stateful_hooks(recurse=recurse) From 248f1039e96c7e049a0f07072b4dcc4a1142278b Mon Sep 17 00:00:00 2001 From: Aryan Date: Wed, 15 Jan 2025 11:31:03 +0100 Subject: [PATCH 54/61] apply suggestions from review --- src/diffusers/hooks/pyramid_attention_broadcast.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/src/diffusers/hooks/pyramid_attention_broadcast.py b/src/diffusers/hooks/pyramid_attention_broadcast.py index d5ee4568d8bf..6d7964395929 100644 --- a/src/diffusers/hooks/pyramid_attention_broadcast.py +++ b/src/diffusers/hooks/pyramid_attention_broadcast.py @@ -221,8 +221,7 @@ def apply_pyramid_attention_broadcast( # cannot be applied to this layer. For custom layers, users can extend this functionality and implement # their own PAB logic similar to `_apply_pyramid_attention_broadcast_on_attention_class`. continue - if isinstance(submodule, (Attention, MochiAttention)): - _apply_pyramid_attention_broadcast_on_attention_class(name, submodule, config) + _apply_pyramid_attention_broadcast_on_attention_class(name, submodule, config) def _apply_pyramid_attention_broadcast_on_attention_class( From fe939754f903833c57798826df0c8bb333a4e3a3 Mon Sep 17 00:00:00 2001 From: Aryan Date: Thu, 16 Jan 2025 21:32:58 +0100 Subject: [PATCH 55/61] address review comments --- .../hooks/pyramid_attention_broadcast.py | 67 ++++++++++--------- src/diffusers/models/cache_utils.py | 8 ++- src/diffusers/pipelines/pipeline_utils.py | 4 +- 3 files changed, 45 insertions(+), 34 deletions(-) diff --git a/src/diffusers/hooks/pyramid_attention_broadcast.py b/src/diffusers/hooks/pyramid_attention_broadcast.py index 6d7964395929..ce6991574f71 100644 --- a/src/diffusers/hooks/pyramid_attention_broadcast.py +++ b/src/diffusers/hooks/pyramid_attention_broadcast.py @@ -137,20 +137,34 @@ class PyramidAttentionBroadcastHook(ModelHook): _is_stateful = True - def __init__(self, skip_callback: Callable[[torch.nn.Module], bool]) -> None: + def __init__( + self, timestep_skip_range: Tuple[int, int], block_skip_range: int, current_timestep_callback: Callable[[], int] + ) -> None: super().__init__() - self.skip_callback = skip_callback + self.timestep_skip_range = timestep_skip_range + self.block_skip_range = block_skip_range + self.current_timestep_callback = current_timestep_callback def initialize_hook(self, module): self.state = PyramidAttentionBroadcastState() return module def new_forward(self, module: torch.nn.Module, *args, **kwargs) -> Any: - if self.skip_callback(module): - output = self.state.cache - else: + is_within_timestep_range = ( + self.timestep_skip_range[0] < self.current_timestep_callback() < self.timestep_skip_range[1] + ) + should_compute_attention = ( + self.state.cache is None + or self.state.iteration == 0 + or not is_within_timestep_range + or self.state.iteration % self.block_skip_range == 0 + ) + + if should_compute_attention: output = module._old_forward(*args, **kwargs) + else: + output = self.state.cache self.state.cache = output self.state.iteration += 1 @@ -266,30 +280,18 @@ def _apply_pyramid_attention_broadcast_on_attention_class( ) return False - def skip_callback(module: torch.nn.Module) -> bool: - hook: PyramidAttentionBroadcastHook = module._diffusers_hook.get_hook("pyramid_attention_broadcast") - pab_state: PyramidAttentionBroadcastState = hook.state - - if pab_state.cache is None: - return False - - is_within_timestep_range = timestep_skip_range[0] < config.current_timestep_callback() < timestep_skip_range[1] - if not is_within_timestep_range: - # We are still not in the phase of inference where skipping attention is possible without minimal quality - # loss, as described in the paper. So, the attention computation cannot be skipped - return False - - should_compute_attention = pab_state.iteration > 0 and pab_state.iteration % block_skip_range == 0 - return not should_compute_attention - logger.debug(f"Enabling Pyramid Attention Broadcast ({block_type}) in layer: {name}") - _apply_pyramid_attention_broadcast(module, skip_callback) + _apply_pyramid_attention_broadcast_hook( + module, timestep_skip_range, block_skip_range, config.current_timestep_callback + ) return True -def _apply_pyramid_attention_broadcast( +def _apply_pyramid_attention_broadcast_hook( module: Union[Attention, MochiAttention], - skip_callback: Callable[[torch.nn.Module], bool], + timestep_skip_range: Tuple[int, int], + block_skip_range: int, + current_timestep_callback: Callable[[], int], ): r""" Apply [Pyramid Attention Broadcast](https://huggingface.co/papers/2408.12588) to a given torch.nn.Module. @@ -297,13 +299,16 @@ def _apply_pyramid_attention_broadcast( Args: module (`torch.nn.Module`): The module to apply Pyramid Attention Broadcast to. - skip_callback (`Callable[[nn.Module], bool]`): - A callback function that determines whether the attention computation should be skipped or not. The - callback function should return a boolean value, where `True` indicates that the attention computation - should be skipped, and `False` indicates that the attention computation should not be skipped. The callback - function will receive a torch.nn.Module containing a `_pyramid_attention_broadcast_state` attribute that - can should be used to retrieve and update the state of PAB for the given module. + timestep_skip_range (`Tuple[int, int]`): + The range of timesteps to skip in the attention layer. The attention computations will be conditionally + skipped if the current timestep is within the specified range. + block_skip_range (`int`): + The number of times a specific attention broadcast is skipped before computing the attention states to + re-use. If this is set to the value `N`, the attention computation will be skipped `N - 1` times (i.e., old + attention states will be re-used) before computing the new attention states again. + current_timestep_callback (`Callable[[], int]`): + A callback function that returns the current inference timestep. """ registry = HookRegistry.check_if_exists_or_initialize(module) - hook = PyramidAttentionBroadcastHook(skip_callback) + hook = PyramidAttentionBroadcastHook(timestep_skip_range, block_skip_range, current_timestep_callback) registry.register_hook(hook, "pyramid_attention_broadcast") diff --git a/src/diffusers/models/cache_utils.py b/src/diffusers/models/cache_utils.py index c02299a89dec..f2c621b3011a 100644 --- a/src/diffusers/models/cache_utils.py +++ b/src/diffusers/models/cache_utils.py @@ -12,6 +12,11 @@ # See the License for the specific language governing permissions and # limitations under the License. +from ..utils.logging import get_logger + + +logger = get_logger(__name__) # pylint: disable=invalid-name + class CacheMixin: r""" @@ -67,7 +72,8 @@ def disable_cache(self) -> None: from ..hooks import HookRegistry, PyramidAttentionBroadcastConfig if self._cache_config is None: - raise ValueError("Caching techniques have not been enabled.") + logger.warning("Caching techniques have not been enabled, so there's nothing to disable.") + return if isinstance(self._cache_config, PyramidAttentionBroadcastConfig): registry = HookRegistry.check_if_exists_or_initialize(self) diff --git a/src/diffusers/pipelines/pipeline_utils.py b/src/diffusers/pipelines/pipeline_utils.py index c696c050f934..d0399bac0387 100644 --- a/src/diffusers/pipelines/pipeline_utils.py +++ b/src/diffusers/pipelines/pipeline_utils.py @@ -1143,8 +1143,8 @@ def maybe_free_model_hooks(self): Make sure to add this function to the end of the `__call__` function of your pipeline so that it functions correctly when applying `enable_model_cpu_offload`. """ - for name, component in self.components.items(): - if name in ("transformer", "unet") and hasattr(component, "_reset_stateful_cache"): + for component in self.components.values(): + if hasattr(component, "_reset_stateful_cache"): component._reset_stateful_cache() if not hasattr(self, "_all_hooks") or len(self._all_hooks) == 0: From 2b5999411c6466d42530f30385dc190f531c3576 Mon Sep 17 00:00:00 2001 From: Aryan Date: Fri, 17 Jan 2025 13:40:11 +0530 Subject: [PATCH 56/61] Apply suggestions from code review --- tests/pipelines/test_pipelines_common.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tests/pipelines/test_pipelines_common.py b/tests/pipelines/test_pipelines_common.py index 95d6a1b4747b..d1aec50ea3f7 100644 --- a/tests/pipelines/test_pipelines_common.py +++ b/tests/pipelines/test_pipelines_common.py @@ -2334,7 +2334,7 @@ def test_pyramid_attention_broadcast_layers(self): pipe = self.pipeline_class(**components) pipe.set_progress_bar_config(disable=None) - self.pab_config.current_timestep_callback = lambda: pipe._current_timestep + self.pab_config.current_timestep_callback = lambda: pipe.current_timestep denoiser = pipe.transformer if hasattr(pipe, "transformer") else pipe.unet denoiser.enable_cache(self.pab_config) @@ -2417,7 +2417,7 @@ def test_pyramid_attention_broadcast_inference(self, expected_atol: float = 0.2) original_image_slice = np.concatenate((original_image_slice[:8], original_image_slice[-8:])) # Run inference with PAB enabled - self.pab_config.current_timestep_callback = lambda: pipe._current_timestep + self.pab_config.current_timestep_callback = lambda: pipe.current_timestep denoiser = pipe.transformer if hasattr(pipe, "transformer") else pipe.unet denoiser.enable_cache(self.pab_config) From 8c74a7ac6379d2cc914babb4e12efadeecc12c39 Mon Sep 17 00:00:00 2001 From: Aryan Date: Wed, 22 Jan 2025 22:45:12 +0100 Subject: [PATCH 57/61] refactor hook implementation --- src/diffusers/hooks/__init__.py | 2 +- src/diffusers/hooks/hooks.py | 91 +++++++++++++------ .../hooks/pyramid_attention_broadcast.py | 2 +- 3 files changed, 64 insertions(+), 31 deletions(-) diff --git a/src/diffusers/hooks/__init__.py b/src/diffusers/hooks/__init__.py index 3e12a9bec24a..e745b1320e84 100644 --- a/src/diffusers/hooks/__init__.py +++ b/src/diffusers/hooks/__init__.py @@ -2,6 +2,6 @@ if is_torch_available(): - from .hooks import HookRegistry + from .hooks import HookRegistry, ModelHook from .layerwise_casting import apply_layerwise_casting, apply_layerwise_casting_hook from .pyramid_attention_broadcast import PyramidAttentionBroadcastConfig, apply_pyramid_attention_broadcast diff --git a/src/diffusers/hooks/hooks.py b/src/diffusers/hooks/hooks.py index bef4c65c41e1..f3968e853476 100644 --- a/src/diffusers/hooks/hooks.py +++ b/src/diffusers/hooks/hooks.py @@ -13,6 +13,7 @@ # limitations under the License. import functools +import gc from typing import Any, Dict, Optional, Tuple import torch @@ -30,6 +31,9 @@ class ModelHook: _is_stateful = False + def __init__(self): + self.fn_ref: "FunctionReference" = None + def initialize_hook(self, module: torch.nn.Module) -> torch.nn.Module: r""" Hook that is executed when a model is initialized. @@ -48,8 +52,6 @@ def deinitalize_hook(self, module: torch.nn.Module) -> torch.nn.Module: module (`torch.nn.Module`): The module attached to this hook. """ - module.forward = module._old_forward - del module._old_forward return module def pre_forward(self, module: torch.nn.Module, *args, **kwargs) -> Tuple[Tuple[Any], Dict[str, Any]]: @@ -99,6 +101,14 @@ def reset_state(self, module: torch.nn.Module): return module +class FunctionReference: + def __init__(self) -> None: + self.pre_forward = None + self.post_forward = None + self.old_forward = None + self.overwritten_forward = None + + class HookRegistry: def __init__(self, module_ref: torch.nn.Module) -> None: super().__init__() @@ -107,51 +117,68 @@ def __init__(self, module_ref: torch.nn.Module) -> None: self._module_ref = module_ref self._hook_order = [] + self._fn_refs = [] def register_hook(self, hook: ModelHook, name: str) -> None: if name in self.hooks.keys(): logger.warning(f"Hook with name {name} already exists, replacing it.") - if hasattr(self._module_ref, "_old_forward"): - old_forward = self._module_ref._old_forward - else: - old_forward = self._module_ref.forward - self._module_ref._old_forward = self._module_ref.forward - self._module_ref = hook.initialize_hook(self._module_ref) - if hasattr(hook, "new_forward"): - rewritten_forward = hook.new_forward - + def create_new_forward(function_reference: FunctionReference): def new_forward(module, *args, **kwargs): - args, kwargs = hook.pre_forward(module, *args, **kwargs) - output = rewritten_forward(module, *args, **kwargs) - return hook.post_forward(module, output) - else: + args, kwargs = function_reference.pre_forward(module, *args, **kwargs) + output = function_reference.old_forward(*args, **kwargs) + return function_reference.post_forward(module, output) - def new_forward(module, *args, **kwargs): - args, kwargs = hook.pre_forward(module, *args, **kwargs) - output = old_forward(*args, **kwargs) - return hook.post_forward(module, output) + return new_forward + + forward = self._module_ref.forward + fn_ref = FunctionReference() + fn_ref.pre_forward = hook.pre_forward + fn_ref.post_forward = hook.post_forward + fn_ref.old_forward = forward + + if hasattr(hook, "new_forward"): + fn_ref.overwritten_forward = forward + fn_ref.old_forward = functools.update_wrapper( + functools.partial(hook.new_forward, self._module_ref), hook.new_forward + ) + + rewritten_forward = create_new_forward(fn_ref) self._module_ref.forward = functools.update_wrapper( - functools.partial(new_forward, self._module_ref), old_forward + functools.partial(rewritten_forward, self._module_ref), rewritten_forward ) + hook.fn_ref = fn_ref self.hooks[name] = hook self._hook_order.append(name) + self._fn_refs.append(fn_ref) def get_hook(self, name: str) -> Optional[ModelHook]: - if name not in self.hooks.keys(): - return None - return self.hooks[name] + return self.hooks.get(name, None) def remove_hook(self, name: str, recurse: bool = True) -> None: + num_hooks = len(self._hook_order) if name in self.hooks.keys(): hook = self.hooks[name] + index = self._hook_order.index(name) + fn_ref = self._fn_refs[index] + + old_forward = fn_ref.old_forward + if fn_ref.overwritten_forward is not None: + old_forward = fn_ref.overwritten_forward + + if index == num_hooks - 1: + self._module_ref.forward = old_forward + else: + self._fn_refs[index + 1].old_forward = old_forward + self._module_ref = hook.deinitalize_hook(self._module_ref) del self.hooks[name] - self._hook_order.remove(name) + self._hook_order.pop(index) + self._fn_refs.pop(index) if recurse: for module_name, module in self._module_ref.named_modules(): @@ -160,8 +187,10 @@ def remove_hook(self, name: str, recurse: bool = True) -> None: if hasattr(module, "_diffusers_hook"): module._diffusers_hook.remove_hook(name, recurse=False) + gc.collect() + def reset_stateful_hooks(self, recurse: bool = True) -> None: - for hook_name in self._hook_order: + for hook_name in reversed(self._hook_order): hook = self.hooks[hook_name] if hook._is_stateful: hook.reset_state(self._module_ref) @@ -180,9 +209,13 @@ def check_if_exists_or_initialize(cls, module: torch.nn.Module) -> "HookRegistry return module._diffusers_hook def __repr__(self) -> str: - hook_repr = "" + registry_repr = "" for i, hook_name in enumerate(self._hook_order): - hook_repr += f" ({i}) {hook_name} - ({self.hooks[hook_name].__class__.__name__})" + if self.hooks[hook_name].__class__.__repr__ is not object.__repr__: + hook_repr = self.hooks[hook_name].__repr__() + else: + hook_repr = self.hooks[hook_name].__class__.__name__ + registry_repr += f" ({i}) {hook_name} - {hook_repr}" if i < len(self._hook_order) - 1: - hook_repr += "\n" - return f"HookRegistry(\n{hook_repr}\n)" + registry_repr += "\n" + return f"HookRegistry(\n{registry_repr}\n)" diff --git a/src/diffusers/hooks/pyramid_attention_broadcast.py b/src/diffusers/hooks/pyramid_attention_broadcast.py index ce6991574f71..49a75cfdc2e8 100644 --- a/src/diffusers/hooks/pyramid_attention_broadcast.py +++ b/src/diffusers/hooks/pyramid_attention_broadcast.py @@ -162,7 +162,7 @@ def new_forward(self, module: torch.nn.Module, *args, **kwargs) -> Any: ) if should_compute_attention: - output = module._old_forward(*args, **kwargs) + output = self.fn_ref.overwritten_forward(*args, **kwargs) else: output = self.state.cache From 3f3e26a3143cf8646e312e0434fd997c69840ccd Mon Sep 17 00:00:00 2001 From: Aryan Date: Wed, 22 Jan 2025 22:45:31 +0100 Subject: [PATCH 58/61] add test suite for hooks --- tests/hooks/test_hooks.py | 384 ++++++++++++++++++++++++++++++++++++++ 1 file changed, 384 insertions(+) create mode 100644 tests/hooks/test_hooks.py diff --git a/tests/hooks/test_hooks.py b/tests/hooks/test_hooks.py new file mode 100644 index 000000000000..65fea530d1ef --- /dev/null +++ b/tests/hooks/test_hooks.py @@ -0,0 +1,384 @@ +# Copyright 2024 HuggingFace Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import gc +import unittest + +import torch + +from diffusers.hooks import HookRegistry, ModelHook +from diffusers.training_utils import free_memory +from diffusers.utils.logging import get_logger +from diffusers.utils.testing_utils import CaptureLogger, torch_device + + +logger = get_logger(__name__) # pylint: disable=invalid-name + + +class DummyBlock(torch.nn.Module): + def __init__(self, in_features: int, hidden_features: int, out_features: int) -> None: + super().__init__() + + self.proj_in = torch.nn.Linear(in_features, hidden_features) + self.activation = torch.nn.ReLU() + self.proj_out = torch.nn.Linear(hidden_features, out_features) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + x = self.proj_in(x) + x = self.activation(x) + x = self.proj_out(x) + return x + + +class DummyModel(torch.nn.Module): + def __init__(self, in_features: int, hidden_features: int, out_features: int, num_layers: int) -> None: + super().__init__() + + self.linear_1 = torch.nn.Linear(in_features, hidden_features) + self.activation = torch.nn.ReLU() + self.blocks = torch.nn.ModuleList( + [DummyBlock(hidden_features, hidden_features, hidden_features) for _ in range(num_layers)] + ) + self.linear_2 = torch.nn.Linear(hidden_features, out_features) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + x = self.linear_1(x) + x = self.activation(x) + for block in self.blocks: + x = block(x) + x = self.linear_2(x) + return x + + +class AddHook(ModelHook): + def __init__(self, value: int): + super().__init__() + self.value = value + + def pre_forward(self, module: torch.nn.Module, *args, **kwargs): + logger.debug("AddHook pre_forward") + args = ((x + self.value) if torch.is_tensor(x) else x for x in args) + return args, kwargs + + def post_forward(self, module, output): + logger.debug("AddHook post_forward") + return output + + +class MultiplyHook(ModelHook): + def __init__(self, value: int): + super().__init__() + self.value = value + + def pre_forward(self, module, *args, **kwargs): + logger.debug("MultiplyHook pre_forward") + args = ((x * self.value) if torch.is_tensor(x) else x for x in args) + return args, kwargs + + def post_forward(self, module, output): + logger.debug("MultiplyHook post_forward") + return output + + def __repr__(self): + return f"MultiplyHook(value={self.value})" + + +class StatefulAddHook(ModelHook): + _is_stateful = True + + def __init__(self, value: int): + super().__init__() + self.value = value + self.increment = 0 + + def pre_forward(self, module, *args, **kwargs): + logger.debug("StatefulAddHook pre_forward") + add_value = self.value + self.increment + self.increment += 1 + args = ((x + add_value) if torch.is_tensor(x) else x for x in args) + return args, kwargs + + def reset_state(self, module): + self.increment = 0 + + +class SkipLayerHook(ModelHook): + def __init__(self, skip_layer: bool): + super().__init__() + self.skip_layer = skip_layer + + def pre_forward(self, module, *args, **kwargs): + logger.debug("SkipLayerHook pre_forward") + return args, kwargs + + def new_forward(self, module, *args, **kwargs): + logger.debug("SkipLayerHook new_forward") + if self.skip_layer: + return args[0] + return self.fn_ref.overwritten_forward(*args, **kwargs) + + def post_forward(self, module, output): + logger.debug("SkipLayerHook post_forward") + return output + + +class HookTests(unittest.TestCase): + in_features = 4 + hidden_features = 8 + out_features = 4 + num_layers = 2 + + def setUp(self): + params = self.get_module_parameters() + self.model = DummyModel(**params) + self.model.to(torch_device) + + def tearDown(self): + super().tearDown() + + del self.model + gc.collect() + free_memory() + + def get_module_parameters(self): + return { + "in_features": self.in_features, + "hidden_features": self.hidden_features, + "out_features": self.out_features, + "num_layers": self.num_layers, + } + + def get_generator(self): + return torch.manual_seed(0) + + def test_hook_registry(self): + registry = HookRegistry.check_if_exists_or_initialize(self.model) + registry.register_hook(AddHook(1), "add_hook") + registry.register_hook(MultiplyHook(2), "multiply_hook") + + registry_repr = repr(registry) + expected_repr = ( + "HookRegistry(\n" " (0) add_hook - AddHook\n" " (1) multiply_hook - MultiplyHook(value=2)\n" ")" + ) + + self.assertEqual(len(registry.hooks), 2) + self.assertEqual(registry._hook_order, ["add_hook", "multiply_hook"]) + self.assertEqual(len(registry._fn_refs), 2) + self.assertEqual(registry_repr, expected_repr) + + registry.remove_hook("add_hook") + + self.assertEqual(len(registry.hooks), 1) + self.assertEqual(registry._hook_order, ["multiply_hook"]) + self.assertEqual(len(registry._fn_refs), 1) + + def test_stateful_hook(self): + registry = HookRegistry.check_if_exists_or_initialize(self.model) + registry.register_hook(StatefulAddHook(1), "stateful_add_hook") + + self.assertEqual(registry.hooks["stateful_add_hook"].increment, 0) + + input = torch.randn(1, 4, device=torch_device, generator=self.get_generator()) + num_repeats = 3 + + for i in range(num_repeats): + result = self.model(input) + if i == 0: + output1 = result + + self.assertEqual(registry.get_hook("stateful_add_hook").increment, num_repeats) + + registry.reset_stateful_hooks() + output2 = self.model(input) + + self.assertEqual(registry.get_hook("stateful_add_hook").increment, 1) + self.assertTrue(torch.allclose(output1, output2)) + + def test_inference(self): + registry = HookRegistry.check_if_exists_or_initialize(self.model) + registry.register_hook(AddHook(1), "add_hook") + registry.register_hook(MultiplyHook(2), "multiply_hook") + + input = torch.randn(1, 4, device=torch_device, generator=self.get_generator()) + output1 = self.model(input).mean().detach().cpu().item() + + registry.remove_hook("multiply_hook") + new_input = input * 2 + output2 = self.model(new_input).mean().detach().cpu().item() + + registry.remove_hook("add_hook") + new_input = input * 2 + 1 + output3 = self.model(new_input).mean().detach().cpu().item() + + self.assertAlmostEqual(output1, output2, places=5) + self.assertAlmostEqual(output1, output3, places=5) + + def test_skip_layer_hook(self): + registry = HookRegistry.check_if_exists_or_initialize(self.model) + registry.register_hook(SkipLayerHook(skip_layer=True), "skip_layer_hook") + + input = torch.zeros(1, 4, device=torch_device) + output = self.model(input).mean().detach().cpu().item() + self.assertEqual(output, 0.0) + + registry.remove_hook("skip_layer_hook") + registry.register_hook(SkipLayerHook(skip_layer=False), "skip_layer_hook") + output = self.model(input).mean().detach().cpu().item() + self.assertNotEqual(output, 0.0) + + def test_skip_layer_internal_block(self): + registry = HookRegistry.check_if_exists_or_initialize(self.model.linear_1) + input = torch.zeros(1, 4, device=torch_device) + + registry.register_hook(SkipLayerHook(skip_layer=True), "skip_layer_hook") + with self.assertRaises(RuntimeError) as cm: + self.model(input).mean().detach().cpu().item() + self.assertIn("mat1 and mat2 shapes cannot be multiplied", str(cm.exception)) + + registry.remove_hook("skip_layer_hook") + output = self.model(input).mean().detach().cpu().item() + self.assertNotEqual(output, 0.0) + + registry = HookRegistry.check_if_exists_or_initialize(self.model.blocks[1]) + registry.register_hook(SkipLayerHook(skip_layer=True), "skip_layer_hook") + output = self.model(input).mean().detach().cpu().item() + self.assertNotEqual(output, 0.0) + + def test_invocation_order_stateful_first(self): + registry = HookRegistry.check_if_exists_or_initialize(self.model) + registry.register_hook(StatefulAddHook(1), "add_hook") + registry.register_hook(AddHook(2), "add_hook_2") + registry.register_hook(MultiplyHook(3), "multiply_hook") + + input = torch.randn(1, 4, device=torch_device, generator=self.get_generator()) + + logger = get_logger(__name__) + logger.setLevel("DEBUG") + + with CaptureLogger(logger) as cap_logger: + self.model(input) + output = cap_logger.out.replace(" ", "").replace("\n", "") + expected_invocation_order_log = ( + ( + "MultiplyHook pre_forward\n" + "AddHook pre_forward\n" + "StatefulAddHook pre_forward\n" + "AddHook post_forward\n" + "MultiplyHook post_forward\n" + ) + .replace(" ", "") + .replace("\n", "") + ) + self.assertEqual(output, expected_invocation_order_log) + + registry.remove_hook("add_hook") + with CaptureLogger(logger) as cap_logger: + self.model(input) + output = cap_logger.out.replace(" ", "").replace("\n", "") + expected_invocation_order_log = ( + ( + "MultiplyHook pre_forward\n" + "AddHook pre_forward\n" + "AddHook post_forward\n" + "MultiplyHook post_forward\n" + ) + .replace(" ", "") + .replace("\n", "") + ) + self.assertEqual(output, expected_invocation_order_log) + + def test_invocation_order_stateful_middle(self): + registry = HookRegistry.check_if_exists_or_initialize(self.model) + registry.register_hook(AddHook(2), "add_hook") + registry.register_hook(StatefulAddHook(1), "add_hook_2") + registry.register_hook(MultiplyHook(3), "multiply_hook") + + input = torch.randn(1, 4, device=torch_device, generator=self.get_generator()) + + logger = get_logger(__name__) + logger.setLevel("DEBUG") + + with CaptureLogger(logger) as cap_logger: + self.model(input) + output = cap_logger.out.replace(" ", "").replace("\n", "") + expected_invocation_order_log = ( + ( + "MultiplyHook pre_forward\n" + "StatefulAddHook pre_forward\n" + "AddHook pre_forward\n" + "AddHook post_forward\n" + "MultiplyHook post_forward\n" + ) + .replace(" ", "") + .replace("\n", "") + ) + self.assertEqual(output, expected_invocation_order_log) + + registry.remove_hook("add_hook") + with CaptureLogger(logger) as cap_logger: + self.model(input) + output = cap_logger.out.replace(" ", "").replace("\n", "") + expected_invocation_order_log = ( + ("MultiplyHook pre_forward\nStatefulAddHook pre_forward\nMultiplyHook post_forward\n") + .replace(" ", "") + .replace("\n", "") + ) + self.assertEqual(output, expected_invocation_order_log) + + registry.remove_hook("add_hook_2") + with CaptureLogger(logger) as cap_logger: + self.model(input) + output = cap_logger.out.replace(" ", "").replace("\n", "") + expected_invocation_order_log = ( + ("MultiplyHook pre_forward\nMultiplyHook post_forward\n").replace(" ", "").replace("\n", "") + ) + self.assertEqual(output, expected_invocation_order_log) + + def test_invocation_order_stateful_last(self): + registry = HookRegistry.check_if_exists_or_initialize(self.model) + registry.register_hook(AddHook(1), "add_hook") + registry.register_hook(MultiplyHook(2), "multiply_hook") + registry.register_hook(StatefulAddHook(3), "add_hook_2") + + input = torch.randn(1, 4, device=torch_device, generator=self.get_generator()) + + logger = get_logger(__name__) + logger.setLevel("DEBUG") + + with CaptureLogger(logger) as cap_logger: + self.model(input) + output = cap_logger.out.replace(" ", "").replace("\n", "") + expected_invocation_order_log = ( + ( + "StatefulAddHook pre_forward\n" + "MultiplyHook pre_forward\n" + "AddHook pre_forward\n" + "AddHook post_forward\n" + "MultiplyHook post_forward\n" + ) + .replace(" ", "") + .replace("\n", "") + ) + self.assertEqual(output, expected_invocation_order_log) + + registry.remove_hook("add_hook") + with CaptureLogger(logger) as cap_logger: + self.model(input) + output = cap_logger.out.replace(" ", "").replace("\n", "") + expected_invocation_order_log = ( + ("StatefulAddHook pre_forward\nMultiplyHook pre_forward\nMultiplyHook post_forward\n") + .replace(" ", "") + .replace("\n", "") + ) + self.assertEqual(output, expected_invocation_order_log) From 83d221f75903581c019eef184b8b7c79b46d2da4 Mon Sep 17 00:00:00 2001 From: Aryan Date: Tue, 28 Jan 2025 03:51:53 +0530 Subject: [PATCH 59/61] PAB Refactor (#10667) * update * update * update --------- Co-authored-by: DN6 --- src/diffusers/hooks/hooks.py | 74 ++++++++++++------- .../hooks/pyramid_attention_broadcast.py | 2 +- tests/hooks/test_hooks.py | 4 +- 3 files changed, 48 insertions(+), 32 deletions(-) diff --git a/src/diffusers/hooks/hooks.py b/src/diffusers/hooks/hooks.py index f3968e853476..c1358ac201cf 100644 --- a/src/diffusers/hooks/hooks.py +++ b/src/diffusers/hooks/hooks.py @@ -32,7 +32,7 @@ class ModelHook: _is_stateful = False def __init__(self): - self.fn_ref: "FunctionReference" = None + self.fn_ref: "HookFunctionReference" = None def initialize_hook(self, module: torch.nn.Module) -> torch.nn.Module: r""" @@ -101,12 +101,27 @@ def reset_state(self, module: torch.nn.Module): return module -class FunctionReference: +class HookFunctionReference: def __init__(self) -> None: + """A container class that maintains mutable references to forward pass functions in a hook chain. + + Its mutable nature allows the hook system to modify the execution chain dynamically without rebuilding the + entire forward pass structure. + + Attributes: + pre_forward: A callable that processes inputs before the main forward pass. + post_forward: A callable that processes outputs after the main forward pass. + forward: The current forward function in the hook chain. + original_forward: The original forward function, stored when a hook provides a custom new_forward. + + The class enables hook removal by allowing updates to the forward chain through reference modification rather + than requiring reconstruction of the entire chain. When a hook is removed, only the relevant references need to + be updated, preserving the execution order of the remaining hooks. + """ self.pre_forward = None self.post_forward = None - self.old_forward = None - self.overwritten_forward = None + self.forward = None + self.original_forward = None class HookRegistry: @@ -125,24 +140,24 @@ def register_hook(self, hook: ModelHook, name: str) -> None: self._module_ref = hook.initialize_hook(self._module_ref) - def create_new_forward(function_reference: FunctionReference): + def create_new_forward(function_reference: HookFunctionReference): def new_forward(module, *args, **kwargs): args, kwargs = function_reference.pre_forward(module, *args, **kwargs) - output = function_reference.old_forward(*args, **kwargs) + output = function_reference.forward(*args, **kwargs) return function_reference.post_forward(module, output) return new_forward forward = self._module_ref.forward - fn_ref = FunctionReference() + fn_ref = HookFunctionReference() fn_ref.pre_forward = hook.pre_forward fn_ref.post_forward = hook.post_forward - fn_ref.old_forward = forward + fn_ref.forward = forward if hasattr(hook, "new_forward"): - fn_ref.overwritten_forward = forward - fn_ref.old_forward = functools.update_wrapper( + fn_ref.original_forward = forward + fn_ref.forward = functools.update_wrapper( functools.partial(hook.new_forward, self._module_ref), hook.new_forward ) @@ -160,25 +175,28 @@ def get_hook(self, name: str) -> Optional[ModelHook]: return self.hooks.get(name, None) def remove_hook(self, name: str, recurse: bool = True) -> None: - num_hooks = len(self._hook_order) - if name in self.hooks.keys(): - hook = self.hooks[name] - index = self._hook_order.index(name) - fn_ref = self._fn_refs[index] - - old_forward = fn_ref.old_forward - if fn_ref.overwritten_forward is not None: - old_forward = fn_ref.overwritten_forward + if name not in self.hooks.keys(): + logger.warning(f"hook: {name} was not found in HookRegistry") + return - if index == num_hooks - 1: - self._module_ref.forward = old_forward - else: - self._fn_refs[index + 1].old_forward = old_forward - - self._module_ref = hook.deinitalize_hook(self._module_ref) - del self.hooks[name] - self._hook_order.pop(index) - self._fn_refs.pop(index) + num_hooks = len(self._hook_order) + hook = self.hooks[name] + index = self._hook_order.index(name) + fn_ref = self._fn_refs[index] + + old_forward = fn_ref.forward + if fn_ref.original_forward is not None: + old_forward = fn_ref.original_forward + + if index == num_hooks - 1: + self._module_ref.forward = old_forward + else: + self._fn_refs[index + 1].forward = old_forward + + self._module_ref = hook.deinitalize_hook(self._module_ref) + del self.hooks[name] + self._hook_order.pop(index) + self._fn_refs.pop(index) if recurse: for module_name, module in self._module_ref.named_modules(): diff --git a/src/diffusers/hooks/pyramid_attention_broadcast.py b/src/diffusers/hooks/pyramid_attention_broadcast.py index 49a75cfdc2e8..9f8597d52f8c 100644 --- a/src/diffusers/hooks/pyramid_attention_broadcast.py +++ b/src/diffusers/hooks/pyramid_attention_broadcast.py @@ -162,7 +162,7 @@ def new_forward(self, module: torch.nn.Module, *args, **kwargs) -> Any: ) if should_compute_attention: - output = self.fn_ref.overwritten_forward(*args, **kwargs) + output = self.fn_ref.original_forward(*args, **kwargs) else: output = self.state.cache diff --git a/tests/hooks/test_hooks.py b/tests/hooks/test_hooks.py index 65fea530d1ef..74bd43c52315 100644 --- a/tests/hooks/test_hooks.py +++ b/tests/hooks/test_hooks.py @@ -126,7 +126,7 @@ def new_forward(self, module, *args, **kwargs): logger.debug("SkipLayerHook new_forward") if self.skip_layer: return args[0] - return self.fn_ref.overwritten_forward(*args, **kwargs) + return self.fn_ref.original_forward(*args, **kwargs) def post_forward(self, module, output): logger.debug("SkipLayerHook post_forward") @@ -174,14 +174,12 @@ def test_hook_registry(self): self.assertEqual(len(registry.hooks), 2) self.assertEqual(registry._hook_order, ["add_hook", "multiply_hook"]) - self.assertEqual(len(registry._fn_refs), 2) self.assertEqual(registry_repr, expected_repr) registry.remove_hook("add_hook") self.assertEqual(len(registry.hooks), 1) self.assertEqual(registry._hook_order, ["multiply_hook"]) - self.assertEqual(len(registry._fn_refs), 1) def test_stateful_hook(self): registry = HookRegistry.check_if_exists_or_initialize(self.model) From 3d269ad81fa2b48b8c3613f5a6ac5824ec6aacb6 Mon Sep 17 00:00:00 2001 From: Aryan Date: Mon, 27 Jan 2025 23:24:49 +0100 Subject: [PATCH 60/61] update --- src/diffusers/hooks/hooks.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/src/diffusers/hooks/hooks.py b/src/diffusers/hooks/hooks.py index c1358ac201cf..576f4fc944ad 100644 --- a/src/diffusers/hooks/hooks.py +++ b/src/diffusers/hooks/hooks.py @@ -13,7 +13,6 @@ # limitations under the License. import functools -import gc from typing import Any, Dict, Optional, Tuple import torch @@ -136,7 +135,10 @@ def __init__(self, module_ref: torch.nn.Module) -> None: def register_hook(self, hook: ModelHook, name: str) -> None: if name in self.hooks.keys(): - logger.warning(f"Hook with name {name} already exists, replacing it.") + raise ValueError( + f"Hook with name {name} already exists in the registry. Please use a different name or " + f"first remove the existing hook and then add a new one." + ) self._module_ref = hook.initialize_hook(self._module_ref) @@ -205,8 +207,6 @@ def remove_hook(self, name: str, recurse: bool = True) -> None: if hasattr(module, "_diffusers_hook"): module._diffusers_hook.remove_hook(name, recurse=False) - gc.collect() - def reset_stateful_hooks(self, recurse: bool = True) -> None: for hook_name in reversed(self._hook_order): hook = self.hooks[hook_name] From 5535fd69051c30fc1880da9f82d4912ce6860bdc Mon Sep 17 00:00:00 2001 From: Aryan Date: Tue, 28 Jan 2025 00:17:22 +0100 Subject: [PATCH 61/61] fix remove hook behaviour --- src/diffusers/hooks/hooks.py | 41 +++++++++++++++++------------------- 1 file changed, 19 insertions(+), 22 deletions(-) diff --git a/src/diffusers/hooks/hooks.py b/src/diffusers/hooks/hooks.py index 576f4fc944ad..3b2e4ed91c2f 100644 --- a/src/diffusers/hooks/hooks.py +++ b/src/diffusers/hooks/hooks.py @@ -177,28 +177,25 @@ def get_hook(self, name: str) -> Optional[ModelHook]: return self.hooks.get(name, None) def remove_hook(self, name: str, recurse: bool = True) -> None: - if name not in self.hooks.keys(): - logger.warning(f"hook: {name} was not found in HookRegistry") - return - - num_hooks = len(self._hook_order) - hook = self.hooks[name] - index = self._hook_order.index(name) - fn_ref = self._fn_refs[index] - - old_forward = fn_ref.forward - if fn_ref.original_forward is not None: - old_forward = fn_ref.original_forward - - if index == num_hooks - 1: - self._module_ref.forward = old_forward - else: - self._fn_refs[index + 1].forward = old_forward - - self._module_ref = hook.deinitalize_hook(self._module_ref) - del self.hooks[name] - self._hook_order.pop(index) - self._fn_refs.pop(index) + if name in self.hooks.keys(): + num_hooks = len(self._hook_order) + hook = self.hooks[name] + index = self._hook_order.index(name) + fn_ref = self._fn_refs[index] + + old_forward = fn_ref.forward + if fn_ref.original_forward is not None: + old_forward = fn_ref.original_forward + + if index == num_hooks - 1: + self._module_ref.forward = old_forward + else: + self._fn_refs[index + 1].forward = old_forward + + self._module_ref = hook.deinitalize_hook(self._module_ref) + del self.hooks[name] + self._hook_order.pop(index) + self._fn_refs.pop(index) if recurse: for module_name, module in self._module_ref.named_modules():