Commit e780c05
[Chore] add set_default_attn_processor to pixart. (#9196)
1 parent e649678 commit e780c05

1 file changed: +9 −1 lines changed
src/diffusers/models/transformers/pixart_transformer_2d.py

Lines changed: 9 additions & 1 deletion
```diff
@@ -19,7 +19,7 @@
 from ...configuration_utils import ConfigMixin, register_to_config
 from ...utils import is_torch_version, logging
 from ..attention import BasicTransformerBlock
-from ..attention_processor import Attention, AttentionProcessor, FusedAttnProcessor2_0
+from ..attention_processor import Attention, AttentionProcessor, AttnProcessor, FusedAttnProcessor2_0
 from ..embeddings import PatchEmbed, PixArtAlphaTextProjection
 from ..modeling_outputs import Transformer2DModelOutput
 from ..modeling_utils import ModelMixin
@@ -247,6 +247,14 @@ def fn_recursive_attn_processor(name: str, module: torch.nn.Module, processor):
         for name, module in self.named_children():
             fn_recursive_attn_processor(name, module, processor)
 
+    def set_default_attn_processor(self):
+        """
+        Disables custom attention processors and sets the default attention implementation.
+
+        Safe to just use `AttnProcessor()` as PixArt doesn't have any exotic attention processors in default model.
+        """
+        self.set_attn_processor(AttnProcessor())
+
     # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.fuse_qkv_projections
     def fuse_qkv_projections(self):
         """
```
