@@ -166,6 +166,7 @@ def __init__(
166
166
self ._chunk_size = None
167
167
self ._chunk_dim = 0
168
168
169
+ # Copied from diffusers.models.attention.BasicTransformerBlock.set_chunk_feed_forward
169
170
def set_chunk_feed_forward (self , chunk_size : Optional [int ], dim : int = 0 ):
170
171
# Sets chunk feed-forward
171
172
self ._chunk_size = chunk_size
@@ -529,3 +530,45 @@ def forward(
529
530
if not return_dict :
530
531
return (output ,)
531
532
return Transformer2DModelOutput (sample = output )
533
+
534
+ # Copied from diffusers.models.unets.unet_3d_condition.UNet3DConditionModel.enable_forward_chunking
535
+ def enable_forward_chunking (self , chunk_size : Optional [int ] = None , dim : int = 0 ) -> None :
536
+ """
537
+ Sets the attention processor to use [feed forward
538
+ chunking](https://huggingface.co/blog/reformer#2-chunked-feed-forward-layers).
539
+
540
+ Parameters:
541
+ chunk_size (`int`, *optional*):
542
+ The chunk size of the feed-forward layers. If not specified, will run feed-forward layer individually
543
+ over each tensor of dim=`dim`.
544
+ dim (`int`, *optional*, defaults to `0`):
545
+ The dimension over which the feed-forward computation should be chunked. Choose between dim=0 (batch)
546
+ or dim=1 (sequence length).
547
+ """
548
+ if dim not in [0 , 1 ]:
549
+ raise ValueError (f"Make sure to set `dim` to either 0 or 1, not { dim } " )
550
+
551
+ # By default chunk size is 1
552
+ chunk_size = chunk_size or 1
553
+
554
+ def fn_recursive_feed_forward (module : torch .nn .Module , chunk_size : int , dim : int ):
555
+ if hasattr (module , "set_chunk_feed_forward" ):
556
+ module .set_chunk_feed_forward (chunk_size = chunk_size , dim = dim )
557
+
558
+ for child in module .children ():
559
+ fn_recursive_feed_forward (child , chunk_size , dim )
560
+
561
+ for module in self .children ():
562
+ fn_recursive_feed_forward (module , chunk_size , dim )
563
+
564
+ # Copied from diffusers.models.unets.unet_3d_condition.UNet3DConditionModel.disable_forward_chunking
565
+ def disable_forward_chunking (self ):
566
+ def fn_recursive_feed_forward (module : torch .nn .Module , chunk_size : int , dim : int ):
567
+ if hasattr (module , "set_chunk_feed_forward" ):
568
+ module .set_chunk_feed_forward (chunk_size = chunk_size , dim = dim )
569
+
570
+ for child in module .children ():
571
+ fn_recursive_feed_forward (child , chunk_size , dim )
572
+
573
+ for module in self .children ():
574
+ fn_recursive_feed_forward (module , None , 0 )
0 commit comments