Skip to content

Commit a8ad666

Browse files
authored
[Hunyuan] feat: support chunked ff. (huggingface#8397)
feat: support chunked ff.
1 parent 14f7b54 commit a8ad666

File tree

2 files changed

+63
-0
lines changed

2 files changed

+63
-0
lines changed

src/diffusers/models/transformers/hunyuan_transformer_2d.py

Lines changed: 43 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -166,6 +166,7 @@ def __init__(
166166
self._chunk_size = None
167167
self._chunk_dim = 0
168168

169+
# Copied from diffusers.models.attention.BasicTransformerBlock.set_chunk_feed_forward
169170
def set_chunk_feed_forward(self, chunk_size: Optional[int], dim: int = 0):
170171
# Sets chunk feed-forward
171172
self._chunk_size = chunk_size
@@ -529,3 +530,45 @@ def forward(
529530
if not return_dict:
530531
return (output,)
531532
return Transformer2DModelOutput(sample=output)
533+
534+
# Copied from diffusers.models.unets.unet_3d_condition.UNet3DConditionModel.enable_forward_chunking
535+
def enable_forward_chunking(self, chunk_size: Optional[int] = None, dim: int = 0) -> None:
536+
"""
537+
Sets the attention processor to use [feed forward
538+
chunking](https://huggingface.co/blog/reformer#2-chunked-feed-forward-layers).
539+
540+
Parameters:
541+
chunk_size (`int`, *optional*):
542+
The chunk size of the feed-forward layers. If not specified, will run feed-forward layer individually
543+
over each tensor of dim=`dim`.
544+
dim (`int`, *optional*, defaults to `0`):
545+
The dimension over which the feed-forward computation should be chunked. Choose between dim=0 (batch)
546+
or dim=1 (sequence length).
547+
"""
548+
if dim not in [0, 1]:
549+
raise ValueError(f"Make sure to set `dim` to either 0 or 1, not {dim}")
550+
551+
# By default chunk size is 1
552+
chunk_size = chunk_size or 1
553+
554+
def fn_recursive_feed_forward(module: torch.nn.Module, chunk_size: int, dim: int):
555+
if hasattr(module, "set_chunk_feed_forward"):
556+
module.set_chunk_feed_forward(chunk_size=chunk_size, dim=dim)
557+
558+
for child in module.children():
559+
fn_recursive_feed_forward(child, chunk_size, dim)
560+
561+
for module in self.children():
562+
fn_recursive_feed_forward(module, chunk_size, dim)
563+
564+
# Copied from diffusers.models.unets.unet_3d_condition.UNet3DConditionModel.disable_forward_chunking
565+
def disable_forward_chunking(self):
566+
def fn_recursive_feed_forward(module: torch.nn.Module, chunk_size: int, dim: int):
567+
if hasattr(module, "set_chunk_feed_forward"):
568+
module.set_chunk_feed_forward(chunk_size=chunk_size, dim=dim)
569+
570+
for child in module.children():
571+
fn_recursive_feed_forward(child, chunk_size, dim)
572+
573+
for module in self.children():
574+
fn_recursive_feed_forward(module, None, 0)

tests/pipelines/hunyuan_dit/test_hunyuan_dit.py

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -228,6 +228,26 @@ def test_save_load_optional_components(self):
228228
max_diff = np.abs(to_np(output) - to_np(output_loaded)).max()
229229
self.assertLess(max_diff, 1e-4)
230230

231+
def test_feed_forward_chunking(self):
232+
device = "cpu"
233+
234+
components = self.get_dummy_components()
235+
pipe = self.pipeline_class(**components)
236+
pipe.to(device)
237+
pipe.set_progress_bar_config(disable=None)
238+
239+
inputs = self.get_dummy_inputs(device)
240+
image = pipe(**inputs).images
241+
image_slice_no_chunking = image[0, -3:, -3:, -1]
242+
243+
pipe.transformer.enable_forward_chunking(chunk_size=1, dim=0)
244+
inputs = self.get_dummy_inputs(device)
245+
image = pipe(**inputs).images
246+
image_slice_chunking = image[0, -3:, -3:, -1]
247+
248+
max_diff = np.abs(to_np(image_slice_no_chunking) - to_np(image_slice_chunking)).max()
249+
self.assertLess(max_diff, 1e-4)
250+
231251
def test_fused_qkv_projections(self):
232252
device = "cpu" # ensure determinism for the device-dependent torch.Generator
233253
components = self.get_dummy_components()

0 commit comments

Comments
 (0)