@@ -217,15 +217,13 @@ def optimized_attention(qkv, num_heads):
 
 
 class SelfAttention(nn.Module):
-    ATTENTION_MODES = ("xformers", "torch", "torch-hb", "math", "debug")
 
     def __init__(
         self,
         dim: int,
         num_heads: int = 8,
         qkv_bias: bool = False,
         qk_scale: Optional[float] = None,
-        attn_mode: str = "xformers",
         pre_only: bool = False,
         qk_norm: Optional[str] = None,
         rmsnorm: bool = False,
@@ -239,8 +237,6 @@ def __init__(
         self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias, dtype=dtype, device=device)
         if not pre_only:
             self.proj = nn.Linear(dim, dim, dtype=dtype, device=device)
-        assert attn_mode in self.ATTENTION_MODES
-        self.attn_mode = attn_mode
         self.pre_only = pre_only
 
         if qk_norm == "rms":
@@ -294,7 +290,7 @@ def post_attention(self, x: torch.Tensor) -> torch.Tensor:
 
     def forward(self, x: torch.Tensor) -> torch.Tensor:
         (q, k, v) = self.pre_attention(x)
-        x = attention(q, k, v, self.num_heads, self.attn_mode)
+        x = attention(q, k, v, self.num_heads)
         x = self.post_attention(x)
         return x
 
@@ -391,14 +387,11 @@ def forward(self, x):
 class DismantledBlock(nn.Module):
     """A DiT block with gated adaptive layer norm (adaLN) conditioning."""
 
-    ATTENTION_MODES = ("xformers", "torch", "torch-hb", "math", "debug")
-
     def __init__(
         self,
         hidden_size: int,
         num_heads: int,
         mlp_ratio: float = 4.0,
-        attn_mode: str = "xformers",
         qkv_bias: bool = False,
         pre_only: bool = False,
         rmsnorm: bool = False,
@@ -411,7 +404,6 @@ def __init__(
         **block_kwargs,
     ):
         super().__init__()
-        assert attn_mode in self.ATTENTION_MODES
         if not rmsnorm:
             self.norm1 = nn.LayerNorm(
                 hidden_size,
@@ -426,7 +418,6 @@ def __init__(
             dim=hidden_size,
             num_heads=num_heads,
             qkv_bias=qkv_bias,
-            attn_mode=attn_mode,
             pre_only=pre_only,
             qk_norm=qk_norm,
             rmsnorm=rmsnorm,
@@ -441,7 +432,6 @@ def __init__(
                 dim=hidden_size,
                 num_heads=num_heads,
                 qkv_bias=qkv_bias,
-                attn_mode=attn_mode,
                 pre_only=False,
                 qk_norm=qk_norm,
                 rmsnorm=rmsnorm,
@@ -716,7 +706,6 @@ def __init__(
         adm_in_channels: Optional[int] = None,
         context_embedder_config: Optional[Dict] = None,
         register_length: int = 0,
-        attn_mode: str = "torch",
         rmsnorm: bool = False,
         scale_mod_only: bool = False,
         swiglu: bool = False,
@@ -735,7 +724,7 @@ def __init__(
         super().__init__()
         if verbose:
             print(
-                f"mmdit initializing with: {input_size=}, {patch_size=}, {in_channels=}, {depth=}, {mlp_ratio=}, {learn_sigma=}, {adm_in_channels=}, {context_embedder_config=}, {register_length=}, {attn_mode=}, {rmsnorm=}, {scale_mod_only=}, {swiglu=}, {out_channels=}, {pos_embed_scaling_factor=}, {pos_embed_offset=}, {pos_embed_max_size=}, {num_patches=}, {qk_norm=}, {qkv_bias=}, {dtype=}, {device=}"
+                f"mmdit initializing with: {input_size=}, {patch_size=}, {in_channels=}, {depth=}, {mlp_ratio=}, {learn_sigma=}, {adm_in_channels=}, {context_embedder_config=}, {register_length=}, {rmsnorm=}, {scale_mod_only=}, {swiglu=}, {out_channels=}, {pos_embed_scaling_factor=}, {pos_embed_offset=}, {pos_embed_max_size=}, {num_patches=}, {qk_norm=}, {qkv_bias=}, {dtype=}, {device=}"
             )
         self.dtype = dtype
         self.learn_sigma = learn_sigma
@@ -805,7 +794,6 @@ def __init__(
                     num_heads,
                     mlp_ratio=mlp_ratio,
                     qkv_bias=qkv_bias,
-                    attn_mode=attn_mode,
                     pre_only=i == depth - 1,
                     rmsnorm=rmsnorm,
                     scale_mod_only=scale_mod_only,
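Note: with `attn_mode` removed from `SelfAttention`, `DismantledBlock`, and the MMDiT constructor, callers no longer select an attention backend; dispatch happens inside `attention` / `optimized_attention`. A minimal usage sketch of the simplified `SelfAttention` API follows; the tensor shapes and keyword values are illustrative assumptions and are not taken from this diff.

import torch

# Hypothetical example values; only the absence of attn_mode reflects this change.
attn = SelfAttention(dim=64, num_heads=8, qkv_bias=True)
x = torch.randn(2, 16, 64)   # (batch, tokens, dim)
out = attn(x)                # forward() now calls attention(q, k, v, self.num_heads)
assert out.shape == x.shape  # projection preserves the (batch, tokens, dim) shape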