|
11 | 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12 | 12 | # See the License for the specific language governing permissions and
|
13 | 13 | # limitations under the License.
|
14 |
| -from typing import Optional |
| 14 | +from typing import Dict, Optional, Union |
15 | 15 |
|
16 | 16 | import torch
|
17 | 17 | import torch.nn.functional as F
|
|
21 | 21 | from ...utils import logging
|
22 | 22 | from ...utils.torch_utils import maybe_allow_in_graph
|
23 | 23 | from ..attention import FeedForward
|
24 |
| -from ..attention_processor import Attention, HunyuanAttnProcessor2_0 |
| 24 | +from ..attention_processor import Attention, AttentionProcessor, HunyuanAttnProcessor2_0 |
25 | 25 | from ..embeddings import (
|
26 | 26 | HunyuanCombinedTimestepTextSizeStyleEmbedding,
|
27 | 27 | PatchEmbed,
|
@@ -321,6 +321,110 @@ def __init__(
|
321 | 321 | self.norm_out = AdaLayerNormContinuous(self.inner_dim, self.inner_dim, elementwise_affine=False, eps=1e-6)
|
322 | 322 | self.proj_out = nn.Linear(self.inner_dim, patch_size * patch_size * self.out_channels, bias=True)
|
323 | 323 |
|
| 324 | + # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.fuse_qkv_projections |
| 325 | + def fuse_qkv_projections(self): |
| 326 | + """ |
| 327 | + Enables fused QKV projections. For self-attention modules, all projection matrices (i.e., query, key, value) |
| 328 | + are fused. For cross-attention modules, key and value projection matrices are fused. |
| 329 | +
|
| 330 | + <Tip warning={true}> |
| 331 | +
|
| 332 | + This API is 🧪 experimental. |
| 333 | +
|
| 334 | + </Tip> |
| 335 | + """ |
| 336 | + self.original_attn_processors = None |
| 337 | + |
| 338 | + for _, attn_processor in self.attn_processors.items(): |
| 339 | + if "Added" in str(attn_processor.__class__.__name__): |
| 340 | + raise ValueError("`fuse_qkv_projections()` is not supported for models having added KV projections.") |
| 341 | + |
| 342 | + self.original_attn_processors = self.attn_processors |
| 343 | + |
| 344 | + for module in self.modules(): |
| 345 | + if isinstance(module, Attention): |
| 346 | + module.fuse_projections(fuse=True) |
| 347 | + |
| 348 | + # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.unfuse_qkv_projections |
| 349 | + def unfuse_qkv_projections(self): |
| 350 | + """Disables the fused QKV projection if enabled. |
| 351 | +
|
| 352 | + <Tip warning={true}> |
| 353 | +
|
| 354 | + This API is 🧪 experimental. |
| 355 | +
|
| 356 | + </Tip> |
| 357 | +
|
| 358 | + """ |
| 359 | + if self.original_attn_processors is not None: |
| 360 | + self.set_attn_processor(self.original_attn_processors) |
| 361 | + |
| 362 | + @property |
| 363 | + # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.attn_processors |
| 364 | + def attn_processors(self) -> Dict[str, AttentionProcessor]: |
| 365 | + r""" |
| 366 | + Returns: |
| 367 | + `dict` of attention processors: A dictionary containing all attention processors used in the model with |
| 368 | + indexed by its weight name. |
| 369 | + """ |
| 370 | + # set recursively |
| 371 | + processors = {} |
| 372 | + |
| 373 | + def fn_recursive_add_processors(name: str, module: torch.nn.Module, processors: Dict[str, AttentionProcessor]): |
| 374 | + if hasattr(module, "get_processor"): |
| 375 | + processors[f"{name}.processor"] = module.get_processor(return_deprecated_lora=True) |
| 376 | + |
| 377 | + for sub_name, child in module.named_children(): |
| 378 | + fn_recursive_add_processors(f"{name}.{sub_name}", child, processors) |
| 379 | + |
| 380 | + return processors |
| 381 | + |
| 382 | + for name, module in self.named_children(): |
| 383 | + fn_recursive_add_processors(name, module, processors) |
| 384 | + |
| 385 | + return processors |
| 386 | + |
| 387 | + # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.set_attn_processor |
| 388 | + def set_attn_processor(self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]]): |
| 389 | + r""" |
| 390 | + Sets the attention processor to use to compute attention. |
| 391 | +
|
| 392 | + Parameters: |
| 393 | + processor (`dict` of `AttentionProcessor` or only `AttentionProcessor`): |
| 394 | + The instantiated processor class or a dictionary of processor classes that will be set as the processor |
| 395 | + for **all** `Attention` layers. |
| 396 | +
|
| 397 | + If `processor` is a dict, the key needs to define the path to the corresponding cross attention |
| 398 | + processor. This is strongly recommended when setting trainable attention processors. |
| 399 | +
|
| 400 | + """ |
| 401 | + count = len(self.attn_processors.keys()) |
| 402 | + |
| 403 | + if isinstance(processor, dict) and len(processor) != count: |
| 404 | + raise ValueError( |
| 405 | + f"A dict of processors was passed, but the number of processors {len(processor)} does not match the" |
| 406 | + f" number of attention layers: {count}. Please make sure to pass {count} processor classes." |
| 407 | + ) |
| 408 | + |
| 409 | + def fn_recursive_attn_processor(name: str, module: torch.nn.Module, processor): |
| 410 | + if hasattr(module, "set_processor"): |
| 411 | + if not isinstance(processor, dict): |
| 412 | + module.set_processor(processor) |
| 413 | + else: |
| 414 | + module.set_processor(processor.pop(f"{name}.processor")) |
| 415 | + |
| 416 | + for sub_name, child in module.named_children(): |
| 417 | + fn_recursive_attn_processor(f"{name}.{sub_name}", child, processor) |
| 418 | + |
| 419 | + for name, module in self.named_children(): |
| 420 | + fn_recursive_attn_processor(name, module, processor) |
| 421 | + |
| 422 | + def set_default_attn_processor(self): |
| 423 | + """ |
| 424 | + Disables custom attention processors and sets the default attention implementation. |
| 425 | + """ |
| 426 | + self.set_attn_processor(HunyuanAttnProcessor2_0()) |
| 427 | + |
324 | 428 | def forward(
|
325 | 429 | self,
|
326 | 430 | hidden_states,
|
|
0 commit comments