Commit 56e090d: remove attn_mode and xformers
Parent: 7df9edb

2 files changed: 7 additions, 31 deletions

mmditx.py (2 additions, 14 deletions)
@@ -217,15 +217,13 @@ def optimized_attention(qkv, num_heads):
 
 
 class SelfAttention(nn.Module):
-    ATTENTION_MODES = ("xformers", "torch", "torch-hb", "math", "debug")
 
     def __init__(
         self,
         dim: int,
         num_heads: int = 8,
         qkv_bias: bool = False,
         qk_scale: Optional[float] = None,
-        attn_mode: str = "xformers",
         pre_only: bool = False,
         qk_norm: Optional[str] = None,
         rmsnorm: bool = False,
@@ -239,8 +237,6 @@ def __init__(
         self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias, dtype=dtype, device=device)
         if not pre_only:
             self.proj = nn.Linear(dim, dim, dtype=dtype, device=device)
-        assert attn_mode in self.ATTENTION_MODES
-        self.attn_mode = attn_mode
         self.pre_only = pre_only
 
         if qk_norm == "rms":
@@ -294,7 +290,7 @@ def post_attention(self, x: torch.Tensor) -> torch.Tensor:
 
     def forward(self, x: torch.Tensor) -> torch.Tensor:
         (q, k, v) = self.pre_attention(x)
-        x = attention(q, k, v, self.num_heads, self.attn_mode)
+        x = attention(q, k, v, self.num_heads)
         x = self.post_attention(x)
         return x
 
@@ -391,14 +387,11 @@ def forward(self, x):
 class DismantledBlock(nn.Module):
     """A DiT block with gated adaptive layer norm (adaLN) conditioning."""
 
-    ATTENTION_MODES = ("xformers", "torch", "torch-hb", "math", "debug")
-
     def __init__(
         self,
         hidden_size: int,
         num_heads: int,
         mlp_ratio: float = 4.0,
-        attn_mode: str = "xformers",
         qkv_bias: bool = False,
         pre_only: bool = False,
         rmsnorm: bool = False,
@@ -411,7 +404,6 @@ def __init__(
         **block_kwargs,
     ):
         super().__init__()
-        assert attn_mode in self.ATTENTION_MODES
         if not rmsnorm:
             self.norm1 = nn.LayerNorm(
                 hidden_size,
@@ -426,7 +418,6 @@ def __init__(
             dim=hidden_size,
             num_heads=num_heads,
             qkv_bias=qkv_bias,
-            attn_mode=attn_mode,
             pre_only=pre_only,
             qk_norm=qk_norm,
             rmsnorm=rmsnorm,
@@ -441,7 +432,6 @@ def __init__(
                 dim=hidden_size,
                 num_heads=num_heads,
                 qkv_bias=qkv_bias,
-                attn_mode=attn_mode,
                 pre_only=False,
                 qk_norm=qk_norm,
                 rmsnorm=rmsnorm,
@@ -716,7 +706,6 @@ def __init__(
         adm_in_channels: Optional[int] = None,
         context_embedder_config: Optional[Dict] = None,
        register_length: int = 0,
-        attn_mode: str = "torch",
         rmsnorm: bool = False,
         scale_mod_only: bool = False,
         swiglu: bool = False,
@@ -735,7 +724,7 @@ def __init__(
         super().__init__()
         if verbose:
             print(
-                f"mmdit initializing with: {input_size=}, {patch_size=}, {in_channels=}, {depth=}, {mlp_ratio=}, {learn_sigma=}, {adm_in_channels=}, {context_embedder_config=}, {register_length=}, {attn_mode=}, {rmsnorm=}, {scale_mod_only=}, {swiglu=}, {out_channels=}, {pos_embed_scaling_factor=}, {pos_embed_offset=}, {pos_embed_max_size=}, {num_patches=}, {qk_norm=}, {qkv_bias=}, {dtype=}, {device=}"
+                f"mmdit initializing with: {input_size=}, {patch_size=}, {in_channels=}, {depth=}, {mlp_ratio=}, {learn_sigma=}, {adm_in_channels=}, {context_embedder_config=}, {register_length=}, {rmsnorm=}, {scale_mod_only=}, {swiglu=}, {out_channels=}, {pos_embed_scaling_factor=}, {pos_embed_offset=}, {pos_embed_max_size=}, {num_patches=}, {qk_norm=}, {qkv_bias=}, {dtype=}, {device=}"
             )
         self.dtype = dtype
         self.learn_sigma = learn_sigma
@@ -805,7 +794,6 @@ def __init__(
                     num_heads,
                     mlp_ratio=mlp_ratio,
                     qkv_bias=qkv_bias,
-                    attn_mode=attn_mode,
                     pre_only=i == depth - 1,
                     rmsnorm=rmsnorm,
                     scale_mod_only=scale_mod_only,
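
Note: with attn_mode removed, SelfAttention and DismantledBlock no longer take an attention-backend argument; every block now routes through the shared torch attention wrapper. Below is a minimal sketch of constructing a block after this commit, assuming mmditx.py is importable as shown and that the remaining constructor defaults (including dtype and device) are unchanged from the full file; the concrete sizes and keyword values are illustrative only.

    import torch
    from mmditx import SelfAttention  # assumed import path

    # attn_mode is no longer a parameter; attention always uses the torch
    # scaled_dot_product_attention path via the shared attention() helper.
    attn = SelfAttention(dim=512, num_heads=8, qkv_bias=True)

    x = torch.randn(2, 16, 512)  # (batch, tokens, dim)
    out = attn(x)                # pre_attention -> attention -> post_attention
    assert out.shape == x.shape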

other_impls.py (5 additions, 17 deletions)
@@ -9,32 +9,20 @@
 from transformers import CLIPTokenizer, T5TokenizerFast
 from einops import rearrange
 
-try:
-    import xformers.ops
-except ImportError:
-    xformers.ops = None
-    print("xformers not found, attn_mode='xformers' will not work")
-
 #################################################################################################
 ### Core/Utility
 #################################################################################################
 
 
-def attention(q, k, v, heads, mask=None, attn_mode: str = "torch"):
+def attention(q, k, v, heads, mask=None):
     """Convenience wrapper around a basic attention operation"""
     b, _, dim_head = q.shape
     dim_head //= heads
     q, k, v = map(lambda t: t.view(b, -1, heads, dim_head).transpose(1, 2), (q, k, v))
-    if attn_mode == "torch":
-        out = torch.nn.functional.scaled_dot_product_attention(
-            q, k, v, attn_mask=mask, dropout_p=0.0, is_causal=False
-        )
-        return out.transpose(1, 2).reshape(b, -1, heads * dim_head)
-    elif attn_mode == "xformers":
-        x = xformers.ops.memory_efficient_attention(q, k, v)
-        x = rearrange(x, "b h n d -> b n (h d)")
-        return x
-
+    out = torch.nn.functional.scaled_dot_product_attention(
+        q, k, v, attn_mask=mask, dropout_p=0.0, is_causal=False
+    )
+    return out.transpose(1, 2).reshape(b, -1, heads * dim_head)
 
 class Mlp(nn.Module):
     """MLP as used in Vision Transformer, MLP-Mixer and related networks"""
