
Commit 14f7b54

[Hunyuan DiT] feat: enable fusing qkv projections when doing attention (huggingface#8396)
* feat: introduce qkv fusion for Hunyuan
* fix copies
1 parent 07cd200 commit 14f7b54

2 files changed: +140 −2 lines

src/diffusers/models/transformers/hunyuan_transformer_2d.py

Lines changed: 106 additions & 2 deletions
```diff
@@ -11,7 +11,7 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-from typing import Optional
+from typing import Dict, Optional, Union
 
 import torch
 import torch.nn.functional as F
@@ -21,7 +21,7 @@
 from ...utils import logging
 from ...utils.torch_utils import maybe_allow_in_graph
 from ..attention import FeedForward
-from ..attention_processor import Attention, HunyuanAttnProcessor2_0
+from ..attention_processor import Attention, AttentionProcessor, HunyuanAttnProcessor2_0
 from ..embeddings import (
     HunyuanCombinedTimestepTextSizeStyleEmbedding,
     PatchEmbed,
@@ -321,6 +321,110 @@ def __init__(
         self.norm_out = AdaLayerNormContinuous(self.inner_dim, self.inner_dim, elementwise_affine=False, eps=1e-6)
         self.proj_out = nn.Linear(self.inner_dim, patch_size * patch_size * self.out_channels, bias=True)
 
+    # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.fuse_qkv_projections
+    def fuse_qkv_projections(self):
+        """
+        Enables fused QKV projections. For self-attention modules, all projection matrices (i.e., query, key, value)
+        are fused. For cross-attention modules, key and value projection matrices are fused.
+
+        <Tip warning={true}>
+
+        This API is 🧪 experimental.
+
+        </Tip>
+        """
+        self.original_attn_processors = None
+
+        for _, attn_processor in self.attn_processors.items():
+            if "Added" in str(attn_processor.__class__.__name__):
+                raise ValueError("`fuse_qkv_projections()` is not supported for models having added KV projections.")
+
+        self.original_attn_processors = self.attn_processors
+
+        for module in self.modules():
+            if isinstance(module, Attention):
+                module.fuse_projections(fuse=True)
+
+    # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.unfuse_qkv_projections
+    def unfuse_qkv_projections(self):
+        """Disables the fused QKV projection if enabled.
+
+        <Tip warning={true}>
+
+        This API is 🧪 experimental.
+
+        </Tip>
+
+        """
+        if self.original_attn_processors is not None:
+            self.set_attn_processor(self.original_attn_processors)
+
+    @property
+    # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.attn_processors
+    def attn_processors(self) -> Dict[str, AttentionProcessor]:
+        r"""
+        Returns:
+            `dict` of attention processors: A dictionary containing all attention processors used in the model with
+            indexed by its weight name.
+        """
+        # set recursively
+        processors = {}
+
+        def fn_recursive_add_processors(name: str, module: torch.nn.Module, processors: Dict[str, AttentionProcessor]):
+            if hasattr(module, "get_processor"):
+                processors[f"{name}.processor"] = module.get_processor(return_deprecated_lora=True)
+
+            for sub_name, child in module.named_children():
+                fn_recursive_add_processors(f"{name}.{sub_name}", child, processors)
+
+            return processors
+
+        for name, module in self.named_children():
+            fn_recursive_add_processors(name, module, processors)
+
+        return processors
+
+    # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.set_attn_processor
+    def set_attn_processor(self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]]):
+        r"""
+        Sets the attention processor to use to compute attention.
+
+        Parameters:
+            processor (`dict` of `AttentionProcessor` or only `AttentionProcessor`):
+                The instantiated processor class or a dictionary of processor classes that will be set as the
+                processor for **all** `Attention` layers.
+
+                If `processor` is a dict, the key needs to define the path to the corresponding cross attention
+                processor. This is strongly recommended when setting trainable attention processors.
+
+        """
+        count = len(self.attn_processors.keys())
+
+        if isinstance(processor, dict) and len(processor) != count:
+            raise ValueError(
+                f"A dict of processors was passed, but the number of processors {len(processor)} does not match the"
+                f" number of attention layers: {count}. Please make sure to pass {count} processor classes."
+            )
+
+        def fn_recursive_attn_processor(name: str, module: torch.nn.Module, processor):
+            if hasattr(module, "set_processor"):
+                if not isinstance(processor, dict):
+                    module.set_processor(processor)
+                else:
+                    module.set_processor(processor.pop(f"{name}.processor"))
+
+            for sub_name, child in module.named_children():
+                fn_recursive_attn_processor(f"{name}.{sub_name}", child, processor)
+
+        for name, module in self.named_children():
+            fn_recursive_attn_processor(name, module, processor)
+
+    def set_default_attn_processor(self):
+        """
+        Disables custom attention processors and sets the default attention implementation.
+        """
+        self.set_attn_processor(HunyuanAttnProcessor2_0())
+
     def forward(
         self,
         hidden_states,
```
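
For reference, a minimal usage sketch of the new methods on a loaded pipeline (not part of this commit; the checkpoint id is an assumption, and any HunyuanDiT checkpoint that loads with `HunyuanDiTPipeline` works the same way):

```python
import torch
from diffusers import HunyuanDiTPipeline

pipe = HunyuanDiTPipeline.from_pretrained(
    "Tencent-Hunyuan/HunyuanDiT-Diffusers",  # assumed checkpoint id
    torch_dtype=torch.float16,
).to("cuda")

# Fuse the q/k/v (and cross-attention k/v) projection matrices before running inference.
pipe.transformer.fuse_qkv_projections()
image = pipe(prompt="a small cactus wearing a straw hat").images[0]

# Restore the original, unfused attention processors.
pipe.transformer.unfuse_qkv_projections()
```

Fusion replaces the three separate q/k/v linear layers of each self-attention block with one wider projection, so each block runs fewer (larger) matmuls; `unfuse_qkv_projections()` restores the processors saved in `original_attn_processors`.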

tests/pipelines/hunyuan_dit/test_hunyuan_dit.py

Lines changed: 34 additions & 0 deletions
```diff
@@ -228,6 +228,40 @@ def test_save_load_optional_components(self):
         max_diff = np.abs(to_np(output) - to_np(output_loaded)).max()
         self.assertLess(max_diff, 1e-4)
 
+    def test_fused_qkv_projections(self):
+        device = "cpu"  # ensure determinism for the device-dependent torch.Generator
+        components = self.get_dummy_components()
+        pipe = self.pipeline_class(**components)
+        pipe = pipe.to(device)
+        pipe.set_progress_bar_config(disable=None)
+
+        inputs = self.get_dummy_inputs(device)
+        inputs["return_dict"] = False
+        image = pipe(**inputs)[0]
+        original_image_slice = image[0, -3:, -3:, -1]
+
+        pipe.transformer.fuse_qkv_projections()
+        inputs = self.get_dummy_inputs(device)
+        inputs["return_dict"] = False
+        image_fused = pipe(**inputs)[0]
+        image_slice_fused = image_fused[0, -3:, -3:, -1]
+
+        pipe.transformer.unfuse_qkv_projections()
+        inputs = self.get_dummy_inputs(device)
+        inputs["return_dict"] = False
+        image_disabled = pipe(**inputs)[0]
+        image_slice_disabled = image_disabled[0, -3:, -3:, -1]
+
+        assert np.allclose(
+            original_image_slice, image_slice_fused, atol=1e-2, rtol=1e-2
+        ), "Fusion of QKV projections shouldn't affect the outputs."
+        assert np.allclose(
+            image_slice_fused, image_slice_disabled, atol=1e-2, rtol=1e-2
+        ), "Outputs, with QKV projection fusion enabled, shouldn't change when fused QKV projections are disabled."
+        assert np.allclose(
+            original_image_slice, image_slice_disabled, atol=1e-2, rtol=1e-2
+        ), "Original outputs should match when fused QKV projections are disabled."
+
 
 @slow
 @require_torch_gpu
```
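
Why the test expects all three image slices to match: fusing QKV only re-packages the same linear maps into one wider projection, so outputs should be numerically unchanged up to the test's 1e-2 tolerance. A standalone sketch of that equivalence (illustrative only, not taken from the diff):

```python
import torch

torch.manual_seed(0)
dim = 8
x = torch.randn(2, dim)

# Three independent projection weights, as in separate to_q / to_k / to_v layers.
w_q, w_k, w_v = (torch.randn(dim, dim) for _ in range(3))

# Separate projections (bias omitted for brevity).
q, k, v = x @ w_q.T, x @ w_k.T, x @ w_v.T

# Fused projection: one matmul against the concatenated weight, then a split.
w_fused = torch.cat([w_q, w_k, w_v], dim=0)  # shape (3 * dim, dim)
q_f, k_f, v_f = (x @ w_fused.T).split(dim, dim=-1)

assert torch.allclose(q, q_f) and torch.allclose(k, k_f) and torch.allclose(v, v_f)
```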
