Commit 2f6f426

[Hunyuan] allow Hunyuan DiT to run under 6GB of GPU VRAM (huggingface#8399)

* allow hunyuan dit to run under 6GB of GPU VRAM
* add a section in the docs/

1 parent a0542c1 commit 2f6f426

File tree: 2 files changed (+25, -4 lines)


docs/source/en/api/pipelines/hunyuandit.md

Lines changed: 4 additions & 0 deletions
@@ -29,6 +29,10 @@ HunyuanDiT has the following components:
 * It combines two text encoders, a bilingual CLIP and a multilingual T5 encoder
 
 
+## Memory optimization
+
+By loading the T5 text encoder in 8 bits, you can run the pipeline in just under 6 GBs of GPU VRAM. Refer to [this script](https://gist.github.com/sayakpaul/3154605f6af05b98a41081aaba5ca43e) for details.
+
 ## HunyuanDiTPipeline
 
 [[autodoc]] HunyuanDiTPipeline
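
The gist linked in the new docs section follows a two-stage pattern that the pipeline changes below make possible: encode the prompt with an 8-bit T5 while the transformer and VAE are left unloaded, free the text encoders, then denoise with only the transformer and VAE resident. What follows is a condensed sketch of that pattern, not the verbatim script; the `Tencent-Hunyuan/HunyuanDiT-Diffusers` checkpoint id, the `device_map`/`torch_dtype` kwargs, and an installed bitsandbytes are assumptions to verify against the gist:

```python
import gc

import torch
from transformers import T5EncoderModel

from diffusers import HunyuanDiTPipeline


def flush():
    # Release Python and CUDA memory between the two stages.
    gc.collect()
    torch.cuda.empty_cache()


# Stage 1: text encoders only; T5 is loaded in 8 bits via bitsandbytes.
text_encoder_2 = T5EncoderModel.from_pretrained(
    "Tencent-Hunyuan/HunyuanDiT-Diffusers",  # assumed checkpoint id
    subfolder="text_encoder_2",
    load_in_8bit=True,
    device_map="auto",
)
pipeline = HunyuanDiTPipeline.from_pretrained(
    "Tencent-Hunyuan/HunyuanDiT-Diffusers",
    text_encoder_2=text_encoder_2,
    transformer=None,  # allowed by the __init__ fallbacks in this commit
    vae=None,
    torch_dtype=torch.float16,
    device_map="balanced",  # assumed; 8-bit modules cannot be .to()-moved
)

prompt = "一个宇航员在骑马"
with torch.no_grad():
    # text_encoder_index=0 -> bilingual CLIP (the default), 1 -> multilingual T5.
    prompt_embeds, negative_prompt_embeds, prompt_attention_mask, negative_prompt_attention_mask = pipeline.encode_prompt(prompt)
    (
        prompt_embeds_2,
        negative_prompt_embeds_2,
        prompt_attention_mask_2,
        negative_prompt_attention_mask_2,
    ) = pipeline.encode_prompt(prompt, text_encoder_index=1)

del text_encoder_2, pipeline
flush()

# Stage 2: transformer + VAE only; reuse the cached embeddings.
pipeline = HunyuanDiTPipeline.from_pretrained(
    "Tencent-Hunyuan/HunyuanDiT-Diffusers",
    text_encoder=None,
    text_encoder_2=None,
    torch_dtype=torch.float16,
).to("cuda")

image = pipeline(
    prompt_embeds=prompt_embeds,
    prompt_embeds_2=prompt_embeds_2,
    negative_prompt_embeds=negative_prompt_embeds,
    negative_prompt_embeds_2=negative_prompt_embeds_2,
    prompt_attention_mask=prompt_attention_mask,
    prompt_attention_mask_2=prompt_attention_mask_2,
    negative_prompt_attention_mask=negative_prompt_attention_mask,
    negative_prompt_attention_mask_2=negative_prompt_attention_mask_2,
).images[0]
image.save("hunyuan_dit_8bit_t5.png")
```

At peak, only the 8-bit T5 (stage 1) or the fp16 transformer plus VAE (stage 2) sit in VRAM at once, which is what brings the footprint under 6GB.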

src/diffusers/pipelines/hunyuandit/pipeline_hunyuandit.py

Lines changed: 21 additions & 4 deletions
@@ -228,16 +228,22 @@ def __init__(
                 " checker. If you do not want to use the safety checker, you can pass `'safety_checker=None'` instead."
             )
 
-        self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
+        self.vae_scale_factor = (
+            2 ** (len(self.vae.config.block_out_channels) - 1) if hasattr(self, "vae") and self.vae is not None else 8
+        )
         self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor)
         self.register_to_config(requires_safety_checker=requires_safety_checker)
-        self.default_sample_size = self.transformer.config.sample_size
+        self.default_sample_size = (
+            self.transformer.config.sample_size
+            if hasattr(self, "transformer") and self.transformer is not None
+            else 128
+        )
 
     def encode_prompt(
         self,
         prompt: str,
-        device: torch.device,
-        dtype: torch.dtype,
+        device: torch.device = None,
+        dtype: torch.dtype = None,
         num_images_per_prompt: int = 1,
         do_classifier_free_guidance: bool = True,
         negative_prompt: Optional[str] = None,
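
The `hasattr`/`is not None` guards in this hunk are what let the pipeline be instantiated without its heaviest components: previously, passing `transformer=None` or `vae=None` crashed in `__init__`. The fallback constants mirror the standard HunyuanDiT configuration (a VAE downsampling factor of 8 and a transformer `sample_size` of 128) and are used regardless of the checkpoint, since the components are absent. A minimal sketch of what this enables, with the checkpoint id assumed as above:

```python
import torch
from diffusers import HunyuanDiTPipeline

# With this commit, skipping the transformer and VAE no longer raises;
# the scale factor and sample size fall back to the hardcoded defaults.
pipe = HunyuanDiTPipeline.from_pretrained(
    "Tencent-Hunyuan/HunyuanDiT-Diffusers",  # assumed checkpoint id
    transformer=None,
    vae=None,
)
assert pipe.vae_scale_factor == 8
assert pipe.default_sample_size == 128
```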
@@ -281,6 +287,17 @@ def encode_prompt(
             text_encoder_index (`int`, *optional*):
                 Index of the text encoder to use. `0` for clip and `1` for T5.
         """
+        if dtype is None:
+            if self.text_encoder_2 is not None:
+                dtype = self.text_encoder_2.dtype
+            elif self.transformer is not None:
+                dtype = self.transformer.dtype
+            else:
+                dtype = None
+
+        if device is None:
+            device = self._execution_device
+
         tokenizers = [self.tokenizer, self.tokenizer_2]
         text_encoders = [self.text_encoder, self.text_encoder_2]
 
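With `device` and `dtype` now optional, `encode_prompt` can be called as a standalone step without plumbing either through: `dtype` is inferred from `text_encoder_2` (falling back to the transformer), and `device` from the pipeline's `_execution_device`. A short sketch reusing the text-encoder-only `pipe` from the previous example; the prompt string is hypothetical:

```python
# No explicit device/dtype needed anymore; both are resolved lazily.
prompt_embeds, negative_prompt_embeds, prompt_attention_mask, negative_prompt_attention_mask = pipe.encode_prompt(
    "an astronaut riding a horse"
)
# text_encoder_index=1 selects the multilingual T5 encoder instead of CLIP.
t5_embeds, t5_negative_embeds, t5_mask, t5_negative_mask = pipe.encode_prompt(
    "an astronaut riding a horse",
    text_encoder_index=1,
)
```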
