
Commit 0d1a1f8
Merge branch 'main' into layerwise-upcasting
2 parents: f1fa123 + e780c05

File tree: 4 files changed, +19 -8 lines

src/diffusers/models/transformers/pixart_transformer_2d.py
Lines changed: 9 additions & 1 deletion

@@ -19,7 +19,7 @@
 from ...configuration_utils import ConfigMixin, register_to_config
 from ...utils import is_torch_version, logging
 from ..attention import BasicTransformerBlock
-from ..attention_processor import Attention, AttentionProcessor, FusedAttnProcessor2_0
+from ..attention_processor import Attention, AttentionProcessor, AttnProcessor, FusedAttnProcessor2_0
 from ..embeddings import PatchEmbed, PixArtAlphaTextProjection
 from ..modeling_outputs import Transformer2DModelOutput
 from ..modeling_utils import ModelMixin

@@ -248,6 +248,14 @@ def fn_recursive_attn_processor(name: str, module: torch.nn.Module, processor):
         for name, module in self.named_children():
             fn_recursive_attn_processor(name, module, processor)

+    def set_default_attn_processor(self):
+        """
+        Disables custom attention processors and sets the default attention implementation.
+
+        Safe to just use `AttnProcessor()` as PixArt doesn't have any exotic attention processors in default model.
+        """
+        self.set_attn_processor(AttnProcessor())
+
     # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.fuse_qkv_projections
     def fuse_qkv_projections(self):
         """

src/diffusers/models/transformers/transformer_flux.py
Lines changed: 1 addition & 0 deletions

@@ -251,6 +251,7 @@ class FluxTransformer2DModel(ModelMixin, ConfigMixin, PeftAdapterMixin, FromOrig
     """

     _supports_gradient_checkpointing = True
+    _no_split_modules = ["FluxTransformerBlock", "FluxSingleTransformerBlock"]

     @register_to_config
     def __init__(
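
`_no_split_modules` tells accelerate's device-map inference which blocks must not be sharded across devices. A minimal sketch of what this enables, assuming a diffusers/accelerate version with model-level `device_map` support; the checkpoint id is illustrative:

from diffusers import FluxTransformer2DModel

# With _no_split_modules declared, the inferred device map keeps each
# FluxTransformerBlock / FluxSingleTransformerBlock whole on one device
# instead of splitting a block's weights across devices.
transformer = FluxTransformer2DModel.from_pretrained(
    "black-forest-labs/FLUX.1-dev",  # illustrative checkpoint
    subfolder="transformer",
    device_map="auto",
)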

src/diffusers/pipelines/flux/pipeline_flux.py
Lines changed: 7 additions & 7 deletions

@@ -677,6 +677,13 @@ def __call__(
         num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0)
         self._num_timesteps = len(timesteps)

+        # handle guidance
+        if self.transformer.config.guidance_embeds:
+            guidance = torch.full([1], guidance_scale, device=device, dtype=torch.float32)
+            guidance = guidance.expand(latents.shape[0])
+        else:
+            guidance = None
+
         # 6. Denoising loop
         with self.progress_bar(total=num_inference_steps) as progress_bar:
             for i, t in enumerate(timesteps):
@@ -686,13 +693,6 @@ def __call__(
                 # broadcast to batch dimension in a way that's compatible with ONNX/Core ML
                 timestep = t.expand(latents.shape[0]).to(latents.dtype)

-                # handle guidance
-                if self.transformer.config.guidance_embeds:
-                    guidance = torch.tensor([guidance_scale], device=device)
-                    guidance = guidance.expand(latents.shape[0])
-                else:
-                    guidance = None
-
                 noise_pred = self.transformer(
                     hidden_states=latents,
                     # YiYi notes: divide it by 1000 for now because we scale it by 1000 in the transformer model (we should not keep it but I want to keep the inputs same for the model for testing)
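
The guidance tensor depends only on guidance_scale and the batch size, not on the current timestep, so the change hoists it out of the denoising loop and pins its dtype via torch.full. A standalone sketch of the hoisted logic; the batch size and scale here are made-up stand-ins:

import torch

batch_size = 2          # stand-in for latents.shape[0]
guidance_scale = 3.5    # stand-in for the pipeline argument
device = "cuda" if torch.cuda.is_available() else "cpu"

# Built once before the loop; expand() broadcasts without copying memory.
guidance = torch.full([1], guidance_scale, device=device, dtype=torch.float32)
guidance = guidance.expand(batch_size)  # shape: (batch_size,)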

tests/models/transformers/test_models_transformer_flux.py
Lines changed: 2 additions & 0 deletions

@@ -29,6 +29,8 @@
 class FluxTransformerTests(ModelTesterMixin, unittest.TestCase):
     model_class = FluxTransformer2DModel
     main_input_name = "hidden_states"
+    # We override the items here because the transformer under consideration is small.
+    model_split_percents = [0.7, 0.6, 0.6]

     @property
     def dummy_input(self):
