diff --git a/src/diffusers/models/unets/unet_kandinsky3.py b/src/diffusers/models/unets/unet_kandinsky3.py index 73bf0020b481..23e6be0f11e5 100644 --- a/src/diffusers/models/unets/unet_kandinsky3.py +++ b/src/diffusers/models/unets/unet_kandinsky3.py @@ -395,12 +395,13 @@ def __init__(self, groups, normalized_shape, context_dim): self.context_mlp[1].bias.data.zero_() def forward(self, x, context): - context = self.context_mlp(context) - - for _ in range(len(x.shape[2:])): - context = context.unsqueeze(-1) - - scale, shift = context.chunk(2, dim=1) + context_out = self.context_mlp(context) + # Append one trailing singleton dim for every dimension of x beyond the first two, using a single view call, + # so that scale and shift (the two halves of dim 1, split below) broadcast against self.norm(x). + # Presumably context_out is (B, 2*C) and becomes (B, 2*C, 1, ..., 1); confirm against context_mlp. + target_shape = list(context_out.shape) + [1] * (x.dim() - 2) + context_out = context_out.view(*target_shape) + scale, shift = context_out.chunk(2, dim=1) x = self.norm(x) * (scale + 1.0) + shift return x