From 36eee4022dbbf54fbe2d8c8d42dec6d5952677a8 Mon Sep 17 00:00:00 2001
From: shitao <2906698981@qq.com>
Date: Sat, 30 Nov 2024 18:01:05 +0800
Subject: [PATCH 01/55] OmniGen model.py
---
src/diffusers/models/embeddings.py | 127 +++
src/diffusers/models/normalization.py | 2 +-
.../transformers/transformer_omnigen.py | 345 ++++++++
src/diffusers/pipelines/omnigen/__init__.py | 67 ++
.../pipelines/omnigen/pipeline_omnigen.py | 774 ++++++++++++++++++
.../pipelines/omnigen/pipeline_output.py | 24 +
6 files changed, 1338 insertions(+), 1 deletion(-)
create mode 100644 src/diffusers/models/transformers/transformer_omnigen.py
create mode 100644 src/diffusers/pipelines/omnigen/__init__.py
create mode 100644 src/diffusers/pipelines/omnigen/pipeline_omnigen.py
create mode 100644 src/diffusers/pipelines/omnigen/pipeline_output.py
diff --git a/src/diffusers/models/embeddings.py b/src/diffusers/models/embeddings.py
index 80775d477c0d..1eedf55933e0 100644
--- a/src/diffusers/models/embeddings.py
+++ b/src/diffusers/models/embeddings.py
@@ -288,6 +288,91 @@ def forward(self, latent):
return (latent + pos_embed).to(latent.dtype)
+
+class OmniGenPatchEmbed(nn.Module):
+ """2D Image to Patch Embedding with support for OmniGen."""
+
+ def __init__(
+ self,
+ patch_size: int =2,
+ in_channels: int =4,
+ embed_dim: int =768,
+ bias: bool =True,
+ interpolation_scale: float =1,
+ pos_embed_max_size: int =192,
+ base_size: int =64,
+ ):
+ super().__init__()
+
+ self.output_image_proj = nn.Conv2d(
+ in_channels, embed_dim, kernel_size=(patch_size, patch_size), stride=patch_size, bias=bias
+ )
+ self.input_image_proj = nn.Conv2d(
+ in_channels, embed_dim, kernel_size=(patch_size, patch_size), stride=patch_size, bias=bias
+ )
+
+ self.patch_size = patch_size
+ self.interpolation_scale = interpolation_scale
+ self.pos_embed_max_size = pos_embed_max_size
+
+ pos_embed = get_2d_sincos_pos_embed(
+ embed_dim, self.pos_embed_max_size, base_size=base_size, interpolation_scale=self.interpolation_scale
+ )
+ self.register_buffer("pos_embed", torch.from_numpy(pos_embed).float().unsqueeze(0), persistent=True)
+
+ def cropped_pos_embed(self, height, width):
+ """Crops positional embeddings for SD3 compatibility."""
+ if self.pos_embed_max_size is None:
+ raise ValueError("`pos_embed_max_size` must be set for cropping.")
+
+ height = height // self.patch_size
+ width = width // self.patch_size
+ if height > self.pos_embed_max_size:
+ raise ValueError(
+ f"Height ({height}) cannot be greater than `pos_embed_max_size`: {self.pos_embed_max_size}."
+ )
+ if width > self.pos_embed_max_size:
+ raise ValueError(
+ f"Width ({width}) cannot be greater than `pos_embed_max_size`: {self.pos_embed_max_size}."
+ )
+
+ top = (self.pos_embed_max_size - height) // 2
+ left = (self.pos_embed_max_size - width) // 2
+ spatial_pos_embed = self.pos_embed.reshape(1, self.pos_embed_max_size, self.pos_embed_max_size, -1)
+ spatial_pos_embed = spatial_pos_embed[:, top : top + height, left : left + width, :]
+ spatial_pos_embed = spatial_pos_embed.reshape(1, -1, spatial_pos_embed.shape[-1])
+ return spatial_pos_embed
+
+ def patch_embeddings(self, latent, is_input_image: bool):
+ if is_input_image:
+ latent = self.input_image_proj(latent)
+ else:
+ latent = self.output_image_proj(latent)
+ latent = latent.flatten(2).transpose(1, 2)
+ return latent
+
+ def forward(self, latent, is_input_image: bool, padding_latent=None):
+ if isinstance(latent, list):
+ if padding_latent is None:
+ padding_latent = [None] * len(latent)
+ patched_latents, num_tokens, shapes = [], [], []
+ for sub_latent, padding in zip(latent, padding_latent):
+ height, width = sub_latent.shape[-2:]
+ sub_latent = self.patch_embeddings(sub_latent, is_input_image)
+ pos_embed = self.cropped_pos_embed(height, width)
+ sub_latent = sub_latent + pos_embed
+ if padding is not None:
+ sub_latent = torch.cat([sub_latent, padding], dim=-2)
+ patched_latents.append(sub_latent)
+ else:
+ height, width = latent.shape[-2:]
+ pos_embed = self.cropped_pos_embed(height, width)
+ latent = self.patch_embeddings(latent, is_input_image)
+ latent = latent + pos_embed
+
+ return latent
+
+
class LuminaPatchEmbed(nn.Module):
"""2D Image to Patch Embedding with support for Lumina-T2X"""
@@ -935,6 +1020,48 @@ def forward(self, timesteps):
return t_emb
+class OmniGenTimestepEmbed(nn.Module):
+ """
+ Embeds scalar timesteps into vector representations for OmniGen
+ """
+
+ def __init__(self, hidden_size, frequency_embedding_size=256):
+ super().__init__()
+ self.mlp = nn.Sequential(
+ nn.Linear(frequency_embedding_size, hidden_size, bias=True),
+ nn.SiLU(),
+ nn.Linear(hidden_size, hidden_size, bias=True),
+ )
+ self.frequency_embedding_size = frequency_embedding_size
+
+ @staticmethod
+ def timestep_embedding(t, dim, max_period=10000):
+ """
+ Create sinusoidal timestep embeddings.
+ :param t: a 1-D Tensor of N indices, one per batch element.
+ These may be fractional.
+ :param dim: the dimension of the output.
+ :param max_period: controls the minimum frequency of the embeddings.
+ :return: an (N, D) Tensor of positional embeddings.
+ """
+ # https://github.com/openai/glide-text2im/blob/main/glide_text2im/nn.py
+ half = dim // 2
+ freqs = torch.exp(
+ -math.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32) / half
+ ).to(device=t.device)
+ args = t[:, None].float() * freqs[None]
+ embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
+ if dim % 2:
+ embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1)
+ return embedding
+
+ def forward(self, t, dtype=torch.float32):
+ t_freq = self.timestep_embedding(t, self.frequency_embedding_size).to(dtype)
+ t_emb = self.mlp(t_freq)
+ return t_emb
+
+
+
class GaussianFourierProjection(nn.Module):
"""Gaussian Fourier embeddings for noise levels."""
diff --git a/src/diffusers/models/normalization.py b/src/diffusers/models/normalization.py
index 817b3fff2ea6..d19e79ad3b52 100644
--- a/src/diffusers/models/normalization.py
+++ b/src/diffusers/models/normalization.py
@@ -71,7 +71,7 @@ def forward(
if self.chunk_dim == 1:
# This is a bit weird why we have the order of "shift, scale" here and "scale, shift" in the
- # other if-branch. This branch is specific to CogVideoX for now.
+ # other if-branch. This branch is specific to CogVideoX and OmniGen for now.
shift, scale = temb.chunk(2, dim=1)
shift = shift[:, None, :]
scale = scale[:, None, :]
diff --git a/src/diffusers/models/transformers/transformer_omnigen.py b/src/diffusers/models/transformers/transformer_omnigen.py
new file mode 100644
index 000000000000..a926c6dd43d3
--- /dev/null
+++ b/src/diffusers/models/transformers/transformer_omnigen.py
@@ -0,0 +1,345 @@
+# Copyright 2024 OmniGen team and The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from typing import Dict, Optional, Tuple, Union, List
+
+import torch
+from torch import nn
+import torch.utils.checkpoint
+
+from transformers.modeling_outputs import BaseModelOutputWithPast
+from transformers import Phi3Model, Phi3Config
+from transformers.cache_utils import Cache, DynamicCache
+
+from ...configuration_utils import ConfigMixin, register_to_config
+from ...loaders import PeftAdapterMixin
+from ...utils import logging
+from ..attention_processor import AttentionProcessor
+from ..normalization import AdaLayerNorm, CogVideoXLayerNormZero
+from ..embeddings import OmniGenPatchEmbed, OmniGenTimestepEmbed
+from ..modeling_utils import ModelMixin
+
+
+logger = logging.get_logger(__name__) # pylint: disable=invalid-name
+
+
+
+class OmniGenBaseTransformer(Phi3Model):
+ """
+ Transformer used in OmniGen. The transformer block is from Ph3, and only modify the attention mask.
+ References: [OmniGen](https://arxiv.org/pdf/2409.11340)
+
+ Parameters:
+ config: Phi3Config
+ """
+
+ def forward(
+ self,
+ input_ids: torch.LongTensor = None,
+ attention_mask: Optional[torch.Tensor] = None,
+ position_ids: Optional[torch.LongTensor] = None,
+ past_key_values: Optional[List[torch.FloatTensor]] = None,
+ inputs_embeds: Optional[torch.FloatTensor] = None,
+ use_cache: Optional[bool] = None,
+ output_attentions: Optional[bool] = None,
+ output_hidden_states: Optional[bool] = None,
+ return_dict: Optional[bool] = None,
+ cache_position: Optional[torch.LongTensor] = None,
+ offload_model: Optional[bool] = False,
+ ) -> Union[Tuple, BaseModelOutputWithPast]:
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
+ output_hidden_states = (
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
+ )
+ use_cache = use_cache if use_cache is not None else self.config.use_cache
+
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+ if (input_ids is None) ^ (inputs_embeds is not None):
+ raise ValueError("You must specify exactly one of input_ids or inputs_embeds")
+
+ if self.gradient_checkpointing and self.training:
+ if use_cache:
+ logger.warning_once(
+ "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
+ )
+ use_cache = False
+
+ # kept for BC (non `Cache` `past_key_values` inputs)
+ return_legacy_cache = False
+ if use_cache and not isinstance(past_key_values, Cache):
+ return_legacy_cache = True
+ if past_key_values is None:
+ past_key_values = DynamicCache()
+ else:
+ past_key_values = DynamicCache.from_legacy_cache(past_key_values)
+ logger.warning_once(
+ "We detected that you are passing `past_key_values` as a tuple of tuples. This is deprecated and "
+ "will be removed in v4.47. Please convert your cache or use an appropriate `Cache` class "
+ "(https://huggingface.co/docs/transformers/kv_cache#legacy-cache-format)"
+ )
+
+ if attention_mask is not None and attention_mask.dim() == 3:
+ dtype = inputs_embeds.dtype
+ min_dtype = torch.finfo(dtype).min
+ attention_mask = (1 - attention_mask) * min_dtype
+ attention_mask = attention_mask.unsqueeze(1).to(inputs_embeds.dtype)
+ else:
+ raise Exception("attention_mask parameter was unavailable or invalid")
+
+ hidden_states = inputs_embeds
+
+ # decoder layers
+ all_hidden_states = () if output_hidden_states else None
+ all_self_attns = () if output_attentions else None
+ next_decoder_cache = None
+
+ layer_idx = -1
+ for decoder_layer in self.layers:
+ layer_idx += 1
+
+ if output_hidden_states:
+ all_hidden_states += (hidden_states,)
+
+ if self.gradient_checkpointing and self.training:
+ layer_outputs = self._gradient_checkpointing_func(
+ decoder_layer.__call__,
+ hidden_states,
+ attention_mask,
+ position_ids,
+ past_key_values,
+ output_attentions,
+ use_cache,
+ cache_position,
+ )
+ else:
+ layer_outputs = decoder_layer(
+ hidden_states,
+ attention_mask=attention_mask,
+ position_ids=position_ids,
+ past_key_value=past_key_values,
+ output_attentions=output_attentions,
+ use_cache=use_cache,
+ cache_position=cache_position,
+ )
+
+ hidden_states = layer_outputs[0]
+
+ if use_cache:
+ next_decoder_cache = layer_outputs[2 if output_attentions else 1]
+
+ if output_attentions:
+ all_self_attns += (layer_outputs[1],)
+
+ hidden_states = self.norm(hidden_states)
+
+ # add hidden states from the last decoder layer
+ if output_hidden_states:
+ all_hidden_states += (hidden_states,)
+
+ next_cache = next_decoder_cache if use_cache else None
+ if return_legacy_cache:
+ next_cache = next_cache.to_legacy_cache()
+
+ if not return_dict:
+ return tuple(v for v in [hidden_states, next_cache, all_hidden_states, all_self_attns] if v is not None)
+ return BaseModelOutputWithPast(
+ last_hidden_state=hidden_states,
+ past_key_values=next_cache,
+ hidden_states=all_hidden_states,
+ attentions=all_self_attns,
+ )
+
+
+class OmniGenTransformer(ModelMixin, ConfigMixin, PeftAdapterMixin):
+ """
+ The Transformer model introduced in OmniGen.
+
+ Reference: https://arxiv.org/pdf/2409.11340
+
+ Parameters:
+ patch_size (`int`, defaults to 2): Patch size to turn the input data into small patches.
+ in_channels (`int`, *optional*, defaults to 4): The number of channels in the input.
+ num_layers (`int`, *optional*, defaults to 18): The number of layers of MMDiT blocks to use.
+ num_single_layers (`int`, *optional*, defaults to 18): The number of layers of single DiT blocks to use.
+ attention_head_dim (`int`, *optional*, defaults to 64): The number of channels in each head.
+ num_attention_heads (`int`, *optional*, defaults to 18): The number of heads to use for multi-head attention.
+ joint_attention_dim (`int`, *optional*): The number of `encoder_hidden_states` dimensions to use.
+ pooled_projection_dim (`int`): Number of dimensions to use when projecting the `pooled_projections`.
+ guidance_embeds (`bool`, defaults to False): Whether to use guidance embeddings.
+ """
+ _supports_gradient_checkpointing = True
+
+ @register_to_config
+ def __init__(
+ self,
+ transformer_config: Phi3Config,
+ patch_size=2,
+ in_channels=4,
+ pos_embed_max_size: int = 192,
+ ):
+ super().__init__()
+ self.in_channels = in_channels
+ self.out_channels = in_channels
+ self.patch_size = patch_size
+ self.pos_embed_max_size = pos_embed_max_size
+
+ hidden_size = transformer_config.hidden_size
+
+ self.patch_embedding = OmniGenPatchEmbed(patch_size=patch_size, in_channels=in_channels, embed_dim=hidden_size, pos_embed_max_size=pos_embed_max_size)
+
+ self.time_token = OmniGenTimestepEmbed(hidden_size)
+ self.t_embedder = OmniGenTimestepEmbed(hidden_size)
+
+ self.norm_out = AdaLayerNorm(hidden_size, norm_elementwise_affine=False, norm_eps=1e-6, chunk_dim=1)
+ self.proj_out = nn.Linear(hidden_size, patch_size * patch_size * self.out_channels, bias=True)
+
+ self.llm = OmniGenBaseTransformer(config=transformer_config)
+ self.llm.config.use_cache = False
+
+
+ def unpatchify(self, x, h, w):
+ """
+ x: (N, T, patch_size**2 * C)
+ imgs: (N, H, W, C)
+ """
+ c = self.out_channels
+
+ x = x.reshape(
+ shape=(x.shape[0], h // self.patch_size, w // self.patch_size, self.patch_size, self.patch_size, c))
+ x = torch.einsum('nhwpqc->nchpwq', x)
+ imgs = x.reshape(shape=(x.shape[0], c, h, w))
+ return imgs
+
+
+ def prepare_condition_embeddings(self, input_ids, input_img_latents, input_image_sizes, padding_latent):
+ condition_embeds = None
+ if input_img_latents is not None:
+ input_latents = self.patch_embedding(input_img_latents, is_input_images=True, padding_latent=padding_latent)
+ if input_ids is not None:
+ condition_embeds = self.llm.embed_tokens(input_ids).clone()
+ input_img_inx = 0
+ for b_inx in input_image_sizes.keys():
+ for start_inx, end_inx in input_image_sizes[b_inx]:
+ condition_embeds[b_inx, start_inx: end_inx] = input_latents[input_img_inx]
+ input_img_inx += 1
+ if input_img_latents is not None:
+ assert input_img_inx == len(input_latents)
+ return condition_embeds
+
+ @property
+ # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.attn_processors
+ def attn_processors(self) -> Dict[str, AttentionProcessor]:
+ r"""
+ Returns:
+ `dict` of attention processors: A dictionary containing all attention processors used in the model with
+ indexed by its weight name.
+ """
+ # set recursively
+ processors = {}
+
+ def fn_recursive_add_processors(name: str, module: torch.nn.Module, processors: Dict[str, AttentionProcessor]):
+ if hasattr(module, "get_processor"):
+ processors[f"{name}.processor"] = module.get_processor()
+
+ for sub_name, child in module.named_children():
+ fn_recursive_add_processors(f"{name}.{sub_name}", child, processors)
+
+ return processors
+
+ for name, module in self.named_children():
+ fn_recursive_add_processors(name, module, processors)
+
+ return processors
+
+ # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.set_attn_processor
+ def set_attn_processor(self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]]):
+ r"""
+ Sets the attention processor to use to compute attention.
+
+ Parameters:
+ processor (`dict` of `AttentionProcessor` or only `AttentionProcessor`):
+ The instantiated processor class or a dictionary of processor classes that will be set as the processor
+ for **all** `Attention` layers.
+
+ If `processor` is a dict, the key needs to define the path to the corresponding cross attention
+ processor. This is strongly recommended when setting trainable attention processors.
+
+ """
+ count = len(self.attn_processors.keys())
+
+ if isinstance(processor, dict) and len(processor) != count:
+ raise ValueError(
+ f"A dict of processors was passed, but the number of processors {len(processor)} does not match the"
+ f" number of attention layers: {count}. Please make sure to pass {count} processor classes."
+ )
+
+ def fn_recursive_attn_processor(name: str, module: torch.nn.Module, processor):
+ if hasattr(module, "set_processor"):
+ if not isinstance(processor, dict):
+ module.set_processor(processor)
+ else:
+ module.set_processor(processor.pop(f"{name}.processor"))
+
+ for sub_name, child in module.named_children():
+ fn_recursive_attn_processor(f"{name}.{sub_name}", child, processor)
+
+ for name, module in self.named_children():
+ fn_recursive_attn_processor(name, module, processor)
+
+ def _set_gradient_checkpointing(self, module, value=False):
+ if hasattr(module, "gradient_checkpointing"):
+ module.gradient_checkpointing = value
+
+ def forward(self,
+ hidden_states,
+ timestep,
+ input_ids,
+ input_img_latents,
+ input_image_sizes,
+ attention_mask,
+ position_ids,
+ padding_latent=None,
+ past_key_values=None,
+ return_past_key_values=True,
+ offload_model: bool = False):
+
+ height, width = hidden_states.size(-2)
+ hidden_states = self.patch_embedding(hidden_states, is_input_image=False)
+ num_tokens_for_output_image = hidden_states.size(1)
+
+ time_token = self.time_token(timestep, dtype=hidden_states.dtype).unsqueeze(1)
+
+ condition_embeds = self.prepare_condition_embeddings(input_ids, input_img_latents, input_image_sizes, padding_latent)
+ if condition_embeds is not None:
+ input_emb = torch.cat([condition_embeds, time_token, hidden_states], dim=1)
+ else:
+ input_emb = torch.cat([time_token, hidden_states], dim=1)
+ output = self.llm(inputs_embeds=input_emb, attention_mask=attention_mask, position_ids=position_ids,
+ past_key_values=past_key_values, offload_model=offload_model)
+ output, past_key_values = output.last_hidden_state, output.past_key_values
+
+ image_embedding = output[:, -num_tokens_for_output_image:]
+ time_emb = self.t_embedder(timestep, dtype=hidden_states.dtype)
+ x = self.final_layer(image_embedding, time_emb)
+ latents = self.unpatchify(x, height, width)
+
+ if return_past_key_values:
+ return latents, past_key_values
+ return latents
+
+
+
+
+
diff --git a/src/diffusers/pipelines/omnigen/__init__.py b/src/diffusers/pipelines/omnigen/__init__.py
new file mode 100644
index 000000000000..3570368a5ca1
--- /dev/null
+++ b/src/diffusers/pipelines/omnigen/__init__.py
@@ -0,0 +1,67 @@
+from typing import TYPE_CHECKING
+
+from ...utils import (
+ DIFFUSERS_SLOW_IMPORT,
+ OptionalDependencyNotAvailable,
+ _LazyModule,
+ get_objects_from_module,
+ is_torch_available,
+ is_transformers_available,
+)
+
+
+_dummy_objects = {}
+_additional_imports = {}
+_import_structure = {"pipeline_output": ["FluxPipelineOutput", "FluxPriorReduxPipelineOutput"]}
+
+try:
+ if not (is_transformers_available() and is_torch_available()):
+ raise OptionalDependencyNotAvailable()
+except OptionalDependencyNotAvailable:
+ from ...utils import dummy_torch_and_transformers_objects # noqa F403
+
+ _dummy_objects.update(get_objects_from_module(dummy_torch_and_transformers_objects))
+else:
+ _import_structure["modeling_flux"] = ["ReduxImageEncoder"]
+ _import_structure["pipeline_flux"] = ["FluxPipeline"]
+ _import_structure["pipeline_flux_control"] = ["FluxControlPipeline"]
+ _import_structure["pipeline_flux_control_img2img"] = ["FluxControlImg2ImgPipeline"]
+ _import_structure["pipeline_flux_controlnet"] = ["FluxControlNetPipeline"]
+ _import_structure["pipeline_flux_controlnet_image_to_image"] = ["FluxControlNetImg2ImgPipeline"]
+ _import_structure["pipeline_flux_controlnet_inpainting"] = ["FluxControlNetInpaintPipeline"]
+ _import_structure["pipeline_flux_fill"] = ["FluxFillPipeline"]
+ _import_structure["pipeline_flux_img2img"] = ["FluxImg2ImgPipeline"]
+ _import_structure["pipeline_flux_inpaint"] = ["FluxInpaintPipeline"]
+ _import_structure["pipeline_flux_prior_redux"] = ["FluxPriorReduxPipeline"]
+if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
+ try:
+ if not (is_transformers_available() and is_torch_available()):
+ raise OptionalDependencyNotAvailable()
+ except OptionalDependencyNotAvailable:
+ from ...utils.dummy_torch_and_transformers_objects import * # noqa F403
+ else:
+ from .modeling_flux import ReduxImageEncoder
+ from .pipeline_flux import FluxPipeline
+ from .pipeline_flux_control import FluxControlPipeline
+ from .pipeline_flux_control_img2img import FluxControlImg2ImgPipeline
+ from .pipeline_flux_controlnet import FluxControlNetPipeline
+ from .pipeline_flux_controlnet_image_to_image import FluxControlNetImg2ImgPipeline
+ from .pipeline_flux_controlnet_inpainting import FluxControlNetInpaintPipeline
+ from .pipeline_flux_fill import FluxFillPipeline
+ from .pipeline_flux_img2img import FluxImg2ImgPipeline
+ from .pipeline_flux_inpaint import FluxInpaintPipeline
+ from .pipeline_flux_prior_redux import FluxPriorReduxPipeline
+else:
+ import sys
+
+ sys.modules[__name__] = _LazyModule(
+ __name__,
+ globals()["__file__"],
+ _import_structure,
+ module_spec=__spec__,
+ )
+
+ for name, value in _dummy_objects.items():
+ setattr(sys.modules[__name__], name, value)
+ for name, value in _additional_imports.items():
+ setattr(sys.modules[__name__], name, value)
diff --git a/src/diffusers/pipelines/omnigen/pipeline_omnigen.py b/src/diffusers/pipelines/omnigen/pipeline_omnigen.py
new file mode 100644
index 000000000000..f6f3752fe97d
--- /dev/null
+++ b/src/diffusers/pipelines/omnigen/pipeline_omnigen.py
@@ -0,0 +1,774 @@
+# Copyright 2024 OmniGen team and The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import inspect
+from typing import Any, Callable, Dict, List, Optional, Union
+
+import numpy as np
+import torch
+from transformers import CLIPTextModel, CLIPTokenizer, T5EncoderModel, T5TokenizerFast
+
+from ...image_processor import VaeImageProcessor
+from ...loaders import FluxLoraLoaderMixin, FromSingleFileMixin, TextualInversionLoaderMixin
+from ...models.autoencoders import AutoencoderKL
+from ...models.transformers import FluxTransformer2DModel
+from ...schedulers import FlowMatchEulerDiscreteScheduler
+from ...utils import (
+ USE_PEFT_BACKEND,
+ is_torch_xla_available,
+ logging,
+ replace_example_docstring,
+ scale_lora_layers,
+ unscale_lora_layers,
+)
+from ...utils.torch_utils import randn_tensor
+from ..pipeline_utils import DiffusionPipeline
+from .pipeline_output import OmniGenPipelineOutput
+
+
+if is_torch_xla_available():
+ import torch_xla.core.xla_model as xm
+
+ XLA_AVAILABLE = True
+else:
+ XLA_AVAILABLE = False
+
+
+logger = logging.get_logger(__name__) # pylint: disable=invalid-name
+
+EXAMPLE_DOC_STRING = """
+ Examples:
+ ```py
+ >>> import torch
+ >>> from diffusers import OmniGenPipeline
+
+ >>> pipe = OmniGenPipeline.from_pretrained("****", torch_dtype=torch.bfloat16)
+ >>> pipe.to("cuda")
+ >>> prompt = "A cat holding a sign that says hello world"
+ >>> # Depending on the variant being used, the pipeline call will slightly vary.
+ >>> # Refer to the pipeline documentation for more details.
+ >>> image = pipe(prompt, num_inference_steps=4, guidance_scale=0.0).images[0]
+ >>> image.save("flux.png")
+ ```
+"""
+
+
+
+# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps
+def retrieve_timesteps(
+ scheduler,
+ num_inference_steps: Optional[int] = None,
+ device: Optional[Union[str, torch.device]] = None,
+ timesteps: Optional[List[int]] = None,
+ sigmas: Optional[List[float]] = None,
+ **kwargs,
+):
+ r"""
+ Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles
+ custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`.
+
+ Args:
+ scheduler (`SchedulerMixin`):
+ The scheduler to get timesteps from.
+ num_inference_steps (`int`):
+ The number of diffusion steps used when generating samples with a pre-trained model. If used, `timesteps`
+ must be `None`.
+ device (`str` or `torch.device`, *optional*):
+ The device to which the timesteps should be moved to. If `None`, the timesteps are not moved.
+ timesteps (`List[int]`, *optional*):
+ Custom timesteps used to override the timestep spacing strategy of the scheduler. If `timesteps` is passed,
+ `num_inference_steps` and `sigmas` must be `None`.
+ sigmas (`List[float]`, *optional*):
+ Custom sigmas used to override the timestep spacing strategy of the scheduler. If `sigmas` is passed,
+ `num_inference_steps` and `timesteps` must be `None`.
+
+ Returns:
+ `Tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the
+ second element is the number of inference steps.
+ """
+ if timesteps is not None and sigmas is not None:
+ raise ValueError("Only one of `timesteps` or `sigmas` can be passed. Please choose one to set custom values")
+ if timesteps is not None:
+ accepts_timesteps = "timesteps" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())
+ if not accepts_timesteps:
+ raise ValueError(
+ f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
+ f" timestep schedules. Please check whether you are using the correct scheduler."
+ )
+ scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs)
+ timesteps = scheduler.timesteps
+ num_inference_steps = len(timesteps)
+ elif sigmas is not None:
+ accept_sigmas = "sigmas" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())
+ if not accept_sigmas:
+ raise ValueError(
+ f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
+ f" sigmas schedules. Please check whether you are using the correct scheduler."
+ )
+ scheduler.set_timesteps(sigmas=sigmas, device=device, **kwargs)
+ timesteps = scheduler.timesteps
+ num_inference_steps = len(timesteps)
+ else:
+ scheduler.set_timesteps(num_inference_steps, device=device, **kwargs)
+ timesteps = scheduler.timesteps
+ return timesteps, num_inference_steps
+
+
+class OmniGenPipeline(
+ DiffusionPipeline,
+ FromSingleFileMixin,
+ TextualInversionLoaderMixin,
+):
+ r"""
+ The OmniGen pipeline for multimodal-to-image generation.
+
+ Reference: https://arxiv.org/pdf/2409.11340
+
+ Args:
+ transformer ([`FluxTransformer2DModel`]):
+ Conditional Transformer (MMDiT) architecture to denoise the encoded image latents.
+ scheduler ([`FlowMatchEulerDiscreteScheduler`]):
+ A scheduler to be used in combination with `transformer` to denoise the encoded image latents.
+ vae ([`AutoencoderKL`]):
+ Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations.
+ text_encoder ([`CLIPTextModel`]):
+ [CLIP](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPTextModel), specifically
+ the [clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14) variant.
+ text_encoder_2 ([`T5EncoderModel`]):
+ [T5](https://huggingface.co/docs/transformers/en/model_doc/t5#transformers.T5EncoderModel), specifically
+ the [google/t5-v1_1-xxl](https://huggingface.co/google/t5-v1_1-xxl) variant.
+ tokenizer (`CLIPTokenizer`):
+ Tokenizer of class
+ [CLIPTokenizer](https://huggingface.co/docs/transformers/en/model_doc/clip#transformers.CLIPTokenizer).
+ tokenizer_2 (`T5TokenizerFast`):
+ Second Tokenizer of class
+ [T5TokenizerFast](https://huggingface.co/docs/transformers/en/model_doc/t5#transformers.T5TokenizerFast).
+ """
+
+ model_cpu_offload_seq = "transformer->vae"
+ _optional_components = []
+ _callback_tensor_inputs = ["latents", "prompt_embeds"]
+
+ def __init__(
+ self,
+ scheduler: FlowMatchEulerDiscreteScheduler,
+ vae: AutoencoderKL,
+ tokenizer: CLIPTokenizer,
+ transformer: FluxTransformer2DModel,
+ ):
+ super().__init__()
+
+ self.register_modules(
+ vae=vae,
+ text_encoder=text_encoder,
+ text_encoder_2=text_encoder_2,
+ tokenizer=tokenizer,
+ tokenizer_2=tokenizer_2,
+ transformer=transformer,
+ scheduler=scheduler,
+ )
+ self.vae_scale_factor = (
+ 2 ** (len(self.vae.config.block_out_channels) - 1) if hasattr(self, "vae") and self.vae is not None else 8
+ )
+ # Flux latents are turned into 2x2 patches and packed. This means the latent width and height has to be divisible
+ # by the patch size. So the vae scale factor is multiplied by the patch size to account for this
+ self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor * 2)
+ self.tokenizer_max_length = (
+ self.tokenizer.model_max_length if hasattr(self, "tokenizer") and self.tokenizer is not None else 77
+ )
+ self.default_sample_size = 128
+
+ def _get_t5_prompt_embeds(
+ self,
+ prompt: Union[str, List[str]] = None,
+ num_images_per_prompt: int = 1,
+ max_sequence_length: int = 512,
+ device: Optional[torch.device] = None,
+ dtype: Optional[torch.dtype] = None,
+ ):
+ device = device or self._execution_device
+ dtype = dtype or self.text_encoder.dtype
+
+ prompt = [prompt] if isinstance(prompt, str) else prompt
+ batch_size = len(prompt)
+
+ if isinstance(self, TextualInversionLoaderMixin):
+ prompt = self.maybe_convert_prompt(prompt, self.tokenizer_2)
+
+ text_inputs = self.tokenizer_2(
+ prompt,
+ padding="max_length",
+ max_length=max_sequence_length,
+ truncation=True,
+ return_length=False,
+ return_overflowing_tokens=False,
+ return_tensors="pt",
+ )
+ text_input_ids = text_inputs.input_ids
+ untruncated_ids = self.tokenizer_2(prompt, padding="longest", return_tensors="pt").input_ids
+
+ if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(text_input_ids, untruncated_ids):
+ removed_text = self.tokenizer_2.batch_decode(untruncated_ids[:, self.tokenizer_max_length - 1 : -1])
+ logger.warning(
+ "The following part of your input was truncated because `max_sequence_length` is set to "
+ f" {max_sequence_length} tokens: {removed_text}"
+ )
+
+ prompt_embeds = self.text_encoder_2(text_input_ids.to(device), output_hidden_states=False)[0]
+
+ dtype = self.text_encoder_2.dtype
+ prompt_embeds = prompt_embeds.to(dtype=dtype, device=device)
+
+ _, seq_len, _ = prompt_embeds.shape
+
+ # duplicate text embeddings and attention mask for each generation per prompt, using mps friendly method
+ prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
+ prompt_embeds = prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)
+
+ return prompt_embeds
+
+ def _get_clip_prompt_embeds(
+ self,
+ prompt: Union[str, List[str]],
+ num_images_per_prompt: int = 1,
+ device: Optional[torch.device] = None,
+ ):
+ device = device or self._execution_device
+
+ prompt = [prompt] if isinstance(prompt, str) else prompt
+ batch_size = len(prompt)
+
+ if isinstance(self, TextualInversionLoaderMixin):
+ prompt = self.maybe_convert_prompt(prompt, self.tokenizer)
+
+ text_inputs = self.tokenizer(
+ prompt,
+ padding="max_length",
+ max_length=self.tokenizer_max_length,
+ truncation=True,
+ return_overflowing_tokens=False,
+ return_length=False,
+ return_tensors="pt",
+ )
+
+ text_input_ids = text_inputs.input_ids
+ untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids
+ if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(text_input_ids, untruncated_ids):
+ removed_text = self.tokenizer.batch_decode(untruncated_ids[:, self.tokenizer_max_length - 1 : -1])
+ logger.warning(
+ "The following part of your input was truncated because CLIP can only handle sequences up to"
+ f" {self.tokenizer_max_length} tokens: {removed_text}"
+ )
+ prompt_embeds = self.text_encoder(text_input_ids.to(device), output_hidden_states=False)
+
+ # Use pooled output of CLIPTextModel
+ prompt_embeds = prompt_embeds.pooler_output
+ prompt_embeds = prompt_embeds.to(dtype=self.text_encoder.dtype, device=device)
+
+ # duplicate text embeddings for each generation per prompt, using mps friendly method
+ prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt)
+ prompt_embeds = prompt_embeds.view(batch_size * num_images_per_prompt, -1)
+
+ return prompt_embeds
+
+ def encode_prompt(
+ self,
+ prompt: Union[str, List[str]],
+ prompt_2: Union[str, List[str]],
+ device: Optional[torch.device] = None,
+ num_images_per_prompt: int = 1,
+ prompt_embeds: Optional[torch.FloatTensor] = None,
+ pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
+ max_sequence_length: int = 512,
+ lora_scale: Optional[float] = None,
+ ):
+ r"""
+
+ Args:
+ prompt (`str` or `List[str]`, *optional*):
+ prompt to be encoded
+ prompt_2 (`str` or `List[str]`, *optional*):
+ The prompt or prompts to be sent to the `tokenizer_2` and `text_encoder_2`. If not defined, `prompt` is
+ used in all text-encoders
+ device: (`torch.device`):
+ torch device
+ num_images_per_prompt (`int`):
+ number of images that should be generated per prompt
+ prompt_embeds (`torch.FloatTensor`, *optional*):
+ Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
+ provided, text embeddings will be generated from `prompt` input argument.
+ pooled_prompt_embeds (`torch.FloatTensor`, *optional*):
+ Pre-generated pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting.
+ If not provided, pooled text embeddings will be generated from `prompt` input argument.
+ lora_scale (`float`, *optional*):
+ A lora scale that will be applied to all LoRA layers of the text encoder if LoRA layers are loaded.
+ """
+ device = device or self._execution_device
+
+ # set lora scale so that monkey patched LoRA
+ # function of text encoder can correctly access it
+ if lora_scale is not None and isinstance(self, FluxLoraLoaderMixin):
+ self._lora_scale = lora_scale
+
+ # dynamically adjust the LoRA scale
+ if self.text_encoder is not None and USE_PEFT_BACKEND:
+ scale_lora_layers(self.text_encoder, lora_scale)
+ if self.text_encoder_2 is not None and USE_PEFT_BACKEND:
+ scale_lora_layers(self.text_encoder_2, lora_scale)
+
+ prompt = [prompt] if isinstance(prompt, str) else prompt
+
+ if prompt_embeds is None:
+ prompt_2 = prompt_2 or prompt
+ prompt_2 = [prompt_2] if isinstance(prompt_2, str) else prompt_2
+
+ # We only use the pooled prompt output from the CLIPTextModel
+ pooled_prompt_embeds = self._get_clip_prompt_embeds(
+ prompt=prompt,
+ device=device,
+ num_images_per_prompt=num_images_per_prompt,
+ )
+ prompt_embeds = self._get_t5_prompt_embeds(
+ prompt=prompt_2,
+ num_images_per_prompt=num_images_per_prompt,
+ max_sequence_length=max_sequence_length,
+ device=device,
+ )
+
+ if self.text_encoder is not None:
+ if isinstance(self, FluxLoraLoaderMixin) and USE_PEFT_BACKEND:
+ # Retrieve the original scale by scaling back the LoRA layers
+ unscale_lora_layers(self.text_encoder, lora_scale)
+
+ if self.text_encoder_2 is not None:
+ if isinstance(self, FluxLoraLoaderMixin) and USE_PEFT_BACKEND:
+ # Retrieve the original scale by scaling back the LoRA layers
+ unscale_lora_layers(self.text_encoder_2, lora_scale)
+
+ dtype = self.text_encoder.dtype if self.text_encoder is not None else self.transformer.dtype
+ text_ids = torch.zeros(prompt_embeds.shape[1], 3).to(device=device, dtype=dtype)
+
+ return prompt_embeds, pooled_prompt_embeds, text_ids
+
+ def check_inputs(
+ self,
+ prompt,
+ prompt_2,
+ height,
+ width,
+ prompt_embeds=None,
+ pooled_prompt_embeds=None,
+ callback_on_step_end_tensor_inputs=None,
+ max_sequence_length=None,
+ ):
+ if height % (self.vae_scale_factor * 2) != 0 or width % (self.vae_scale_factor * 2) != 0:
+ logger.warning(
+ f"`height` and `width` have to be divisible by {self.vae_scale_factor * 2} but are {height} and {width}. Dimensions will be resized accordingly"
+ )
+
+ if callback_on_step_end_tensor_inputs is not None and not all(
+ k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs
+ ):
+ raise ValueError(
+ f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}"
+ )
+
+ if prompt is not None and prompt_embeds is not None:
+ raise ValueError(
+ f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
+ " only forward one of the two."
+ )
+ elif prompt_2 is not None and prompt_embeds is not None:
+ raise ValueError(
+ f"Cannot forward both `prompt_2`: {prompt_2} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
+ " only forward one of the two."
+ )
+ elif prompt is None and prompt_embeds is None:
+ raise ValueError(
+ "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined."
+ )
+ elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)):
+ raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
+ elif prompt_2 is not None and (not isinstance(prompt_2, str) and not isinstance(prompt_2, list)):
+ raise ValueError(f"`prompt_2` has to be of type `str` or `list` but is {type(prompt_2)}")
+
+ if prompt_embeds is not None and pooled_prompt_embeds is None:
+ raise ValueError(
+ "If `prompt_embeds` are provided, `pooled_prompt_embeds` also have to be passed. Make sure to generate `pooled_prompt_embeds` from the same text encoder that was used to generate `prompt_embeds`."
+ )
+
+ if max_sequence_length is not None and max_sequence_length > 512:
+ raise ValueError(f"`max_sequence_length` cannot be greater than 512 but is {max_sequence_length}")
+
+ @staticmethod
+ def _prepare_latent_image_ids(batch_size, height, width, device, dtype):
+ latent_image_ids = torch.zeros(height, width, 3)
+ latent_image_ids[..., 1] = latent_image_ids[..., 1] + torch.arange(height)[:, None]
+ latent_image_ids[..., 2] = latent_image_ids[..., 2] + torch.arange(width)[None, :]
+
+ latent_image_id_height, latent_image_id_width, latent_image_id_channels = latent_image_ids.shape
+
+ latent_image_ids = latent_image_ids.reshape(
+ latent_image_id_height * latent_image_id_width, latent_image_id_channels
+ )
+
+ return latent_image_ids.to(device=device, dtype=dtype)
+
+ @staticmethod
+ def _pack_latents(latents, batch_size, num_channels_latents, height, width):
+ latents = latents.view(batch_size, num_channels_latents, height // 2, 2, width // 2, 2)
+ latents = latents.permute(0, 2, 4, 1, 3, 5)
+ latents = latents.reshape(batch_size, (height // 2) * (width // 2), num_channels_latents * 4)
+
+ return latents
+
+ @staticmethod
+ def _unpack_latents(latents, height, width, vae_scale_factor):
+ batch_size, num_patches, channels = latents.shape
+
+ # VAE applies 8x compression on images but we must also account for packing which requires
+ # latent height and width to be divisible by 2.
+ height = 2 * (int(height) // (vae_scale_factor * 2))
+ width = 2 * (int(width) // (vae_scale_factor * 2))
+
+ latents = latents.view(batch_size, height // 2, width // 2, channels // 4, 2, 2)
+ latents = latents.permute(0, 3, 1, 4, 2, 5)
+
+ latents = latents.reshape(batch_size, channels // (2 * 2), height, width)
+
+ return latents
+
+ def enable_vae_slicing(self):
+ r"""
+ Enable sliced VAE decoding. When this option is enabled, the VAE will split the input tensor in slices to
+ compute decoding in several steps. This is useful to save some memory and allow larger batch sizes.
+ """
+ self.vae.enable_slicing()
+
+ def disable_vae_slicing(self):
+ r"""
+ Disable sliced VAE decoding. If `enable_vae_slicing` was previously enabled, this method will go back to
+ computing decoding in one step.
+ """
+ self.vae.disable_slicing()
+
+ def enable_vae_tiling(self):
+ r"""
+ Enable tiled VAE decoding. When this option is enabled, the VAE will split the input tensor into tiles to
+ compute decoding and encoding in several steps. This is useful for saving a large amount of memory and to allow
+ processing larger images.
+ """
+ self.vae.enable_tiling()
+
+ def disable_vae_tiling(self):
+ r"""
+ Disable tiled VAE decoding. If `enable_vae_tiling` was previously enabled, this method will go back to
+ computing decoding in one step.
+ """
+ self.vae.disable_tiling()
+
+ def prepare_latents(
+ self,
+ batch_size,
+ num_channels_latents,
+ height,
+ width,
+ dtype,
+ device,
+ generator,
+ latents=None,
+ ):
+ # VAE applies 8x compression on images but we must also account for packing which requires
+ # latent height and width to be divisible by 2.
+ height = 2 * (int(height) // (self.vae_scale_factor * 2))
+ width = 2 * (int(width) // (self.vae_scale_factor * 2))
+
+ shape = (batch_size, num_channels_latents, height, width)
+
+ if latents is not None:
+ latent_image_ids = self._prepare_latent_image_ids(batch_size, height // 2, width // 2, device, dtype)
+ return latents.to(device=device, dtype=dtype), latent_image_ids
+
+ if isinstance(generator, list) and len(generator) != batch_size:
+ raise ValueError(
+ f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
+ f" size of {batch_size}. Make sure the batch size matches the length of the generators."
+ )
+
+ latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
+ latents = self._pack_latents(latents, batch_size, num_channels_latents, height, width)
+
+ latent_image_ids = self._prepare_latent_image_ids(batch_size, height // 2, width // 2, device, dtype)
+
+ return latents, latent_image_ids
+
+ @property
+ def guidance_scale(self):
+ return self._guidance_scale
+
+ @property
+ def joint_attention_kwargs(self):
+ return self._joint_attention_kwargs
+
+ @property
+ def num_timesteps(self):
+ return self._num_timesteps
+
+ @property
+ def interrupt(self):
+ return self._interrupt
+
+ @torch.no_grad()
+ @replace_example_docstring(EXAMPLE_DOC_STRING)
+ def __call__(
+ self,
+ prompt: Union[str, List[str]] = None,
+ prompt_2: Optional[Union[str, List[str]]] = None,
+ height: Optional[int] = None,
+ width: Optional[int] = None,
+ num_inference_steps: int = 28,
+ timesteps: List[int] = None,
+ guidance_scale: float = 3.5,
+ num_images_per_prompt: Optional[int] = 1,
+ generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
+ latents: Optional[torch.FloatTensor] = None,
+ prompt_embeds: Optional[torch.FloatTensor] = None,
+ pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
+ output_type: Optional[str] = "pil",
+ return_dict: bool = True,
+ joint_attention_kwargs: Optional[Dict[str, Any]] = None,
+ callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None,
+ callback_on_step_end_tensor_inputs: List[str] = ["latents"],
+ max_sequence_length: int = 512,
+ ):
+ r"""
+ Function invoked when calling the pipeline for generation.
+
+ Args:
+ prompt (`str` or `List[str]`, *optional*):
+ The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`.
+ instead.
+ prompt_2 (`str` or `List[str]`, *optional*):
+ The prompt or prompts to be sent to `tokenizer_2` and `text_encoder_2`. If not defined, `prompt` is
+ will be used instead
+ height (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor):
+ The height in pixels of the generated image. This is set to 1024 by default for the best results.
+ width (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor):
+ The width in pixels of the generated image. This is set to 1024 by default for the best results.
+ num_inference_steps (`int`, *optional*, defaults to 50):
+ The number of denoising steps. More denoising steps usually lead to a higher quality image at the
+ expense of slower inference.
+ timesteps (`List[int]`, *optional*):
+ Custom timesteps to use for the denoising process with schedulers which support a `timesteps` argument
+ in their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is
+ passed will be used. Must be in descending order.
+ guidance_scale (`float`, *optional*, defaults to 7.0):
+ Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
+ `guidance_scale` is defined as `w` of equation 2. of [Imagen
+ Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
+ 1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`,
+ usually at the expense of lower image quality.
+ num_images_per_prompt (`int`, *optional*, defaults to 1):
+ The number of images to generate per prompt.
+ generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
+ One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html)
+ to make generation deterministic.
+ latents (`torch.FloatTensor`, *optional*):
+ Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
+ generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
+ tensor will ge generated by sampling using the supplied random `generator`.
+ prompt_embeds (`torch.FloatTensor`, *optional*):
+ Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
+ provided, text embeddings will be generated from `prompt` input argument.
+ pooled_prompt_embeds (`torch.FloatTensor`, *optional*):
+ Pre-generated pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting.
+ If not provided, pooled text embeddings will be generated from `prompt` input argument.
+ output_type (`str`, *optional*, defaults to `"pil"`):
+ The output format of the generate image. Choose between
+ [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
+ return_dict (`bool`, *optional*, defaults to `True`):
+ Whether or not to return a [`~pipelines.flux.FluxPipelineOutput`] instead of a plain tuple.
+ joint_attention_kwargs (`dict`, *optional*):
+ A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
+ `self.processor` in
+ [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
+ callback_on_step_end (`Callable`, *optional*):
+ A function that calls at the end of each denoising steps during the inference. The function is called
+ with the following arguments: `callback_on_step_end(self: DiffusionPipeline, step: int, timestep: int,
+ callback_kwargs: Dict)`. `callback_kwargs` will include a list of all tensors as specified by
+ `callback_on_step_end_tensor_inputs`.
+ callback_on_step_end_tensor_inputs (`List`, *optional*):
+ The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list
+ will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the
+ `._callback_tensor_inputs` attribute of your pipeline class.
+ max_sequence_length (`int` defaults to 512): Maximum sequence length to use with the `prompt`.
+
+ Examples:
+
+ Returns:
+ [`~pipelines.flux.FluxPipelineOutput`] or `tuple`: [`~pipelines.flux.FluxPipelineOutput`] if `return_dict`
+ is True, otherwise a `tuple`. When returning a tuple, the first element is a list with the generated
+ images.
+ """
+
+ height = height or self.default_sample_size * self.vae_scale_factor
+ width = width or self.default_sample_size * self.vae_scale_factor
+
+ # 1. Check inputs. Raise error if not correct
+ self.check_inputs(
+ prompt,
+ prompt_2,
+ height,
+ width,
+ prompt_embeds=prompt_embeds,
+ pooled_prompt_embeds=pooled_prompt_embeds,
+ callback_on_step_end_tensor_inputs=callback_on_step_end_tensor_inputs,
+ max_sequence_length=max_sequence_length,
+ )
+
+ self._guidance_scale = guidance_scale
+ self._joint_attention_kwargs = joint_attention_kwargs
+ self._interrupt = False
+
+ # 2. Define call parameters
+ if prompt is not None and isinstance(prompt, str):
+ batch_size = 1
+ elif prompt is not None and isinstance(prompt, list):
+ batch_size = len(prompt)
+ else:
+ batch_size = prompt_embeds.shape[0]
+
+ device = self._execution_device
+
+ lora_scale = (
+ self.joint_attention_kwargs.get("scale", None) if self.joint_attention_kwargs is not None else None
+ )
+ (
+ prompt_embeds,
+ pooled_prompt_embeds,
+ text_ids,
+ ) = self.encode_prompt(
+ prompt=prompt,
+ prompt_2=prompt_2,
+ prompt_embeds=prompt_embeds,
+ pooled_prompt_embeds=pooled_prompt_embeds,
+ device=device,
+ num_images_per_prompt=num_images_per_prompt,
+ max_sequence_length=max_sequence_length,
+ lora_scale=lora_scale,
+ )
+
+ # 4. Prepare latent variables
+ num_channels_latents = self.transformer.config.in_channels // 4
+ latents, latent_image_ids = self.prepare_latents(
+ batch_size * num_images_per_prompt,
+ num_channels_latents,
+ height,
+ width,
+ prompt_embeds.dtype,
+ device,
+ generator,
+ latents,
+ )
+
+ # 5. Prepare timesteps
+ sigmas = np.linspace(1.0, 1 / num_inference_steps, num_inference_steps)
+ image_seq_len = latents.shape[1]
+ mu = calculate_shift(
+ image_seq_len,
+ self.scheduler.config.base_image_seq_len,
+ self.scheduler.config.max_image_seq_len,
+ self.scheduler.config.base_shift,
+ self.scheduler.config.max_shift,
+ )
+ timesteps, num_inference_steps = retrieve_timesteps(
+ self.scheduler,
+ num_inference_steps,
+ device,
+ timesteps,
+ sigmas,
+ mu=mu,
+ )
+ num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0)
+ self._num_timesteps = len(timesteps)
+
+ # handle guidance
+ if self.transformer.config.guidance_embeds:
+ guidance = torch.full([1], guidance_scale, device=device, dtype=torch.float32)
+ guidance = guidance.expand(latents.shape[0])
+ else:
+ guidance = None
+
+ # 6. Denoising loop
+ with self.progress_bar(total=num_inference_steps) as progress_bar:
+ for i, t in enumerate(timesteps):
+ if self.interrupt:
+ continue
+
+ # broadcast to batch dimension in a way that's compatible with ONNX/Core ML
+ timestep = t.expand(latents.shape[0]).to(latents.dtype)
+
+ noise_pred = self.transformer(
+ hidden_states=latents,
+ timestep=timestep / 1000,
+ guidance=guidance,
+ pooled_projections=pooled_prompt_embeds,
+ encoder_hidden_states=prompt_embeds,
+ txt_ids=text_ids,
+ img_ids=latent_image_ids,
+ joint_attention_kwargs=self.joint_attention_kwargs,
+ return_dict=False,
+ )[0]
+
+ # compute the previous noisy sample x_t -> x_t-1
+ latents_dtype = latents.dtype
+ latents = self.scheduler.step(noise_pred, t, latents, return_dict=False)[0]
+
+ if latents.dtype != latents_dtype:
+ if torch.backends.mps.is_available():
+ # some platforms (eg. apple mps) misbehave due to a pytorch bug: https://github.com/pytorch/pytorch/pull/99272
+ latents = latents.to(latents_dtype)
+
+ if callback_on_step_end is not None:
+ callback_kwargs = {}
+ for k in callback_on_step_end_tensor_inputs:
+ callback_kwargs[k] = locals()[k]
+ callback_outputs = callback_on_step_end(self, i, t, callback_kwargs)
+
+ latents = callback_outputs.pop("latents", latents)
+ prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds)
+
+ # call the callback, if provided
+ if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
+ progress_bar.update()
+
+ if XLA_AVAILABLE:
+ xm.mark_step()
+
+ if output_type == "latent":
+ image = latents
+
+ else:
+ latents = self._unpack_latents(latents, height, width, self.vae_scale_factor)
+ latents = (latents / self.vae.config.scaling_factor) + self.vae.config.shift_factor
+ image = self.vae.decode(latents, return_dict=False)[0]
+ image = self.image_processor.postprocess(image, output_type=output_type)
+
+ # Offload all models
+ self.maybe_free_model_hooks()
+
+ if not return_dict:
+ return (image,)
+
+ return FluxPipelineOutput(images=image)
diff --git a/src/diffusers/pipelines/omnigen/pipeline_output.py b/src/diffusers/pipelines/omnigen/pipeline_output.py
new file mode 100644
index 000000000000..40ff24199900
--- /dev/null
+++ b/src/diffusers/pipelines/omnigen/pipeline_output.py
@@ -0,0 +1,24 @@
+from dataclasses import dataclass
+from typing import List, Union
+
+import numpy as np
+import PIL.Image
+
+from ...utils import BaseOutput
+
+
+@dataclass
+class OmniGenPipelineOutput(BaseOutput):
+ """
+ Output class for Stable Diffusion pipelines.
+
+ Args:
+ images (`List[PIL.Image.Image]` or `np.ndarray`)
+ List of denoised PIL images of length `batch_size` or numpy array of shape `(batch_size, height, width,
+ num_channels)`. PIL images or numpy array present the denoised images of the diffusion pipeline.
+ """
+
+ images: Union[List[PIL.Image.Image], np.ndarray]
+
+
+
From bbe2b98e03f44e895b63555c22cbe5a44fca36b6 Mon Sep 17 00:00:00 2001
From: staoxiao <2906698981@qq.com>
Date: Sat, 30 Nov 2024 22:14:27 +0800
Subject: [PATCH 02/55] update OmniGenTransformerModel
---
src/diffusers/__init__.py | 2 +
src/diffusers/models/__init__.py | 2 +
src/diffusers/models/transformers/__init__.py | 1 +
.../transformers/transformer_omnigen.py | 5 +-
test.py | 52 +++++++++++++++++++
5 files changed, 60 insertions(+), 2 deletions(-)
create mode 100644 test.py
diff --git a/src/diffusers/__init__.py b/src/diffusers/__init__.py
index a4749af5f61b..2d5cd10fd66f 100644
--- a/src/diffusers/__init__.py
+++ b/src/diffusers/__init__.py
@@ -108,6 +108,7 @@
"MotionAdapter",
"MultiAdapter",
"MultiControlNetModel",
+ "OmniGenTransformerModel",
"PixArtTransformer2DModel",
"PriorTransformer",
"SD3ControlNetModel",
@@ -599,6 +600,7 @@
MotionAdapter,
MultiAdapter,
MultiControlNetModel,
+ OmniGenTransformerModel,
PixArtTransformer2DModel,
PriorTransformer,
SD3ControlNetModel,
diff --git a/src/diffusers/models/__init__.py b/src/diffusers/models/__init__.py
index 65e2418ac794..faf9b97e827c 100644
--- a/src/diffusers/models/__init__.py
+++ b/src/diffusers/models/__init__.py
@@ -66,6 +66,7 @@
_import_structure["transformers.transformer_mochi"] = ["MochiTransformer3DModel"]
_import_structure["transformers.transformer_sd3"] = ["SD3Transformer2DModel"]
_import_structure["transformers.transformer_temporal"] = ["TransformerTemporalModel"]
+ _import_structure["transformers.transformer_omnigen"] = ["OmniGenTransformerModel"]
_import_structure["unets.unet_1d"] = ["UNet1DModel"]
_import_structure["unets.unet_2d"] = ["UNet2DModel"]
_import_structure["unets.unet_2d_condition"] = ["UNet2DConditionModel"]
@@ -125,6 +126,7 @@
LatteTransformer3DModel,
LuminaNextDiT2DModel,
MochiTransformer3DModel,
+ OmniGenTransformerModel,
PixArtTransformer2DModel,
PriorTransformer,
SD3Transformer2DModel,
diff --git a/src/diffusers/models/transformers/__init__.py b/src/diffusers/models/transformers/__init__.py
index a2c087d708a4..7d0fd364a17d 100644
--- a/src/diffusers/models/transformers/__init__.py
+++ b/src/diffusers/models/transformers/__init__.py
@@ -20,3 +20,4 @@
from .transformer_mochi import MochiTransformer3DModel
from .transformer_sd3 import SD3Transformer2DModel
from .transformer_temporal import TransformerTemporalModel
+ from .transformer_omnigen import OmniGenTransformerModel
diff --git a/src/diffusers/models/transformers/transformer_omnigen.py b/src/diffusers/models/transformers/transformer_omnigen.py
index a926c6dd43d3..2ea18512724c 100644
--- a/src/diffusers/models/transformers/transformer_omnigen.py
+++ b/src/diffusers/models/transformers/transformer_omnigen.py
@@ -26,7 +26,7 @@
from ...loaders import PeftAdapterMixin
from ...utils import logging
from ..attention_processor import AttentionProcessor
-from ..normalization import AdaLayerNorm, CogVideoXLayerNormZero
+from ..normalization import AdaLayerNorm
from ..embeddings import OmniGenPatchEmbed, OmniGenTimestepEmbed
from ..modeling_utils import ModelMixin
@@ -162,7 +162,7 @@ def forward(
)
-class OmniGenTransformer(ModelMixin, ConfigMixin, PeftAdapterMixin):
+class OmniGenTransformerModel(ModelMixin, ConfigMixin, PeftAdapterMixin):
"""
The Transformer model introduced in OmniGen.
@@ -343,3 +343,4 @@ def forward(self,
+
diff --git a/test.py b/test.py
new file mode 100644
index 000000000000..e110b93a7b74
--- /dev/null
+++ b/test.py
@@ -0,0 +1,52 @@
+import os
+os.environ['HF_HUB_CACHE'] = '/share/shitao/downloaded_models2'
+
+from huggingface_hub import snapshot_download
+
+from diffusers.models import OmniGenTransformerModel
+from transformers import Phi3Model, Phi3Config
+
+
+from safetensors.torch import load_file
+
+model_name = "Shitao/OmniGen-v1"
+config = Phi3Config.from_pretrained("Shitao/OmniGen-v1")
+model = OmniGenTransformerModel(transformer_config=config)
+cache_folder = os.getenv('HF_HUB_CACHE')
+model_name = snapshot_download(repo_id=model_name,
+ cache_dir=cache_folder,
+ ignore_patterns=['flax_model.msgpack', 'rust_model.ot', 'tf_model.h5'])
+print(model_name)
+model_path = os.path.join(model_name, 'model.safetensors')
+ckpt = load_file(model_path, 'cpu')
+
+
+mapping_dict = {
+ "pos_embed": "patch_embedding.pos_embed",
+ "x_embedder.proj.weight": "patch_embedding.output_image_proj.weight",
+ "x_embedder.proj.bias": "patch_embedding.output_image_proj.bias",
+ "input_x_embedder.proj.weight": "patch_embedding.input_image_proj.weight",
+ "input_x_embedder.proj.bias": "patch_embedding.input_image_proj.bias",
+ "final_layer.adaLN_modulation.1.weight": "norm_out.linear.weight",
+ "final_layer.adaLN_modulation.1.bias": "norm_out.linear.bias",
+ "final_layer.linear.weight": "proj_out.weight",
+ "final_layer.linear.bias": "proj_out.bias",
+
+}
+
+new_ckpt = {}
+for k, v in ckpt.items():
+ # new_ckpt[k] = v
+ if k in mapping_dict:
+ new_ckpt[mapping_dict[k]] = v
+ else:
+ new_ckpt[k] = v
+
+
+
+model.load_state_dict(new_ckpt)
+
+
+
+
+
From b839590eb3a629cd1f87f84b8128593ddbca9adf Mon Sep 17 00:00:00 2001
From: shitao <2906698981@qq.com>
Date: Mon, 2 Dec 2024 17:34:58 +0800
Subject: [PATCH 03/55] omnigen pipeline
---
scripts/convert_omnigen_to_diffusers.py | 73 +++
src/diffusers/__init__.py | 6 +-
src/diffusers/models/__init__.py | 4 +-
src/diffusers/models/embeddings.py | 17 +-
src/diffusers/models/transformers/__init__.py | 2 +-
.../transformers/transformer_omnigen.py | 149 +++--
.../pipelines/lumina/pipeline_lumina.py | 4 +-
src/diffusers/pipelines/omnigen/__init__.py | 37 +-
.../pipelines/omnigen/kvcache_omnigen.py | 106 ++++
.../pipelines/omnigen/pipeline_omnigen.py | 533 ++++++------------
.../pipelines/omnigen/pipeline_output.py | 24 -
.../pipelines/omnigen/processor_omnigen.py | 295 ++++++++++
src/diffusers/utils/dummy_pt_objects.py | 14 +
.../dummy_torch_and_transformers_objects.py | 15 +
test.py | 4 +-
15 files changed, 832 insertions(+), 451 deletions(-)
create mode 100644 scripts/convert_omnigen_to_diffusers.py
create mode 100644 src/diffusers/pipelines/omnigen/kvcache_omnigen.py
delete mode 100644 src/diffusers/pipelines/omnigen/pipeline_output.py
create mode 100644 src/diffusers/pipelines/omnigen/processor_omnigen.py
diff --git a/scripts/convert_omnigen_to_diffusers.py b/scripts/convert_omnigen_to_diffusers.py
new file mode 100644
index 000000000000..2b02f6a064dc
--- /dev/null
+++ b/scripts/convert_omnigen_to_diffusers.py
@@ -0,0 +1,73 @@
+import argparse
+import os
+
+import torch
+from safetensors.torch import load_file
+from transformers import AutoModel, AutoTokenizer, AutoConfig
+
+from diffusers import AutoencoderKL, FlowMatchEulerDiscreteScheduler, OmniGenTransformer2DModel, OmniGenPipeline
+
+
+def main(args):
+ # checkpoint from https://huggingface.co/Shitao/OmniGen-v1
+ ckpt = load_file(args.origin_ckpt_path, device="cpu")
+
+ mapping_dict = {
+ "pos_embed": "patch_embedding.pos_embed",
+ "x_embedder.proj.weight": "patch_embedding.output_image_proj.weight",
+ "x_embedder.proj.bias": "patch_embedding.output_image_proj.bias",
+ "input_x_embedder.proj.weight": "patch_embedding.input_image_proj.weight",
+ "input_x_embedder.proj.bias": "patch_embedding.input_image_proj.bias",
+ "final_layer.adaLN_modulation.1.weight": "norm_out.linear.weight",
+ "final_layer.adaLN_modulation.1.bias": "norm_out.linear.bias",
+ "final_layer.linear.weight": "proj_out.weight",
+ "final_layer.linear.bias": "proj_out.bias",
+
+ }
+
+ converted_state_dict = {}
+ for k, v in ckpt.items():
+ # new_ckpt[k] = v
+ if k in mapping_dict:
+ converted_state_dict[mapping_dict[k]] = v
+ else:
+ converted_state_dict[k] = v
+
+ transformer_config = AutoConfig.from_pretrained(args.origin_ckpt_path)
+
+ # Lumina-Next-SFT 2B
+ transformer = OmniGenTransformer2DModel(
+ transformer_config=transformer_config,
+ patch_size=2,
+ in_channels=4,
+ pos_embed_max_size=192,
+ )
+ transformer.load_state_dict(converted_state_dict, strict=True)
+
+ num_model_params = sum(p.numel() for p in transformer.parameters())
+ print(f"Total number of transformer parameters: {num_model_params}")
+
+ scheduler = FlowMatchEulerDiscreteScheduler()
+
+ vae = AutoencoderKL.from_pretrained(args.origin_ckpt_path, torch_dtype=torch.float32)
+
+ tokenizer = AutoTokenizer.from_pretrained(args.origin_ckpt_path)
+
+
+ pipeline = OmniGenPipeline(
+ tokenizer=tokenizer, transformer=transformer, vae=vae, scheduler=scheduler
+ )
+ pipeline.save_pretrained(args.dump_path)
+
+
+if __name__ == "__main__":
+ parser = argparse.ArgumentParser()
+
+ parser.add_argument(
+ "--origin_ckpt_path", default=None, type=str, required=False, help="Path to the checkpoint to convert."
+ )
+
+ parser.add_argument("--dump_path", default=None, type=str, required=True, help="Path to the output pipeline.")
+
+ args = parser.parse_args()
+ main(args)
diff --git a/src/diffusers/__init__.py b/src/diffusers/__init__.py
index 2d5cd10fd66f..c7868cf1c304 100644
--- a/src/diffusers/__init__.py
+++ b/src/diffusers/__init__.py
@@ -108,7 +108,7 @@
"MotionAdapter",
"MultiAdapter",
"MultiControlNetModel",
- "OmniGenTransformerModel",
+ "OmniGenTransformer2DModel",
"PixArtTransformer2DModel",
"PriorTransformer",
"SD3ControlNetModel",
@@ -321,6 +321,7 @@
"MarigoldNormalsPipeline",
"MochiPipeline",
"MusicLDMPipeline",
+ "OmniGenPipeline",
"PaintByExamplePipeline",
"PIAPipeline",
"PixArtAlphaPipeline",
@@ -600,7 +601,7 @@
MotionAdapter,
MultiAdapter,
MultiControlNetModel,
- OmniGenTransformerModel,
+ OmniGenTransformer2DModel,
PixArtTransformer2DModel,
PriorTransformer,
SD3ControlNetModel,
@@ -792,6 +793,7 @@
MarigoldNormalsPipeline,
MochiPipeline,
MusicLDMPipeline,
+ OmniGenPipeline,
PaintByExamplePipeline,
PIAPipeline,
PixArtAlphaPipeline,
diff --git a/src/diffusers/models/__init__.py b/src/diffusers/models/__init__.py
index faf9b97e827c..3ce71406059b 100644
--- a/src/diffusers/models/__init__.py
+++ b/src/diffusers/models/__init__.py
@@ -66,7 +66,7 @@
_import_structure["transformers.transformer_mochi"] = ["MochiTransformer3DModel"]
_import_structure["transformers.transformer_sd3"] = ["SD3Transformer2DModel"]
_import_structure["transformers.transformer_temporal"] = ["TransformerTemporalModel"]
- _import_structure["transformers.transformer_omnigen"] = ["OmniGenTransformerModel"]
+ _import_structure["transformers.transformer_omnigen"] = ["OmniGenTransformer2DModel"]
_import_structure["unets.unet_1d"] = ["UNet1DModel"]
_import_structure["unets.unet_2d"] = ["UNet2DModel"]
_import_structure["unets.unet_2d_condition"] = ["UNet2DConditionModel"]
@@ -126,7 +126,7 @@
LatteTransformer3DModel,
LuminaNextDiT2DModel,
MochiTransformer3DModel,
- OmniGenTransformerModel,
+ OmniGenTransformer2DModel,
PixArtTransformer2DModel,
PriorTransformer,
SD3Transformer2DModel,
diff --git a/src/diffusers/models/embeddings.py b/src/diffusers/models/embeddings.py
index 1eedf55933e0..9681a84b8878 100644
--- a/src/diffusers/models/embeddings.py
+++ b/src/diffusers/models/embeddings.py
@@ -351,7 +351,20 @@ def patch_embeddings(self, latent, is_input_image: bool):
latent = latent.flatten(2).transpose(1, 2)
return latent
- def forward(self, latent, is_input_image: bool, padding_latent=None):
+ def forward(self,
+ latent: torch.Tensor,
+ is_input_image: bool,
+ padding_latent: torch.Tensor = None
+ ):
+ """
+ Args:
+ latent:
+ is_input_image:
+ padding_latent: When sizes of target images are inconsistent, use `padding_latent` to maintain consistent sequence length.
+
+ Returns: torch.Tensor
+
+ """
if isinstance(latent, list):
if padding_latent is None:
padding_latent = [None] * len(latent)
@@ -362,7 +375,7 @@ def forward(self, latent, is_input_image: bool, padding_latent=None):
pos_embed = self.cropped_pos_embed(height, width)
sub_latent = sub_latent + pos_embed
if padding is not None:
- sub_latent = torch.cat([sub_latent, padding], dim=-2)
+ sub_latent = torch.cat([sub_latent, padding.to(sub_latent.device)], dim=-2)
patched_latents.append(sub_latent)
else:
height, width = latent.shape[-2:]
diff --git a/src/diffusers/models/transformers/__init__.py b/src/diffusers/models/transformers/__init__.py
index 7d0fd364a17d..9770ded5e31e 100644
--- a/src/diffusers/models/transformers/__init__.py
+++ b/src/diffusers/models/transformers/__init__.py
@@ -20,4 +20,4 @@
from .transformer_mochi import MochiTransformer3DModel
from .transformer_sd3 import SD3Transformer2DModel
from .transformer_temporal import TransformerTemporalModel
- from .transformer_omnigen import OmniGenTransformerModel
+ from .transformer_omnigen import OmniGenTransformer2DModel
diff --git a/src/diffusers/models/transformers/transformer_omnigen.py b/src/diffusers/models/transformers/transformer_omnigen.py
index 2ea18512724c..2fa97146f1d3 100644
--- a/src/diffusers/models/transformers/transformer_omnigen.py
+++ b/src/diffusers/models/transformers/transformer_omnigen.py
@@ -12,29 +12,47 @@
# See the License for the specific language governing permissions and
# limitations under the License.
-from typing import Dict, Optional, Tuple, Union, List
+from typing import Any, Dict, Optional, Tuple, Union, List
+from dataclasses import dataclass
import torch
-from torch import nn
import torch.utils.checkpoint
-
-from transformers.modeling_outputs import BaseModelOutputWithPast
+from torch import nn
+from transformers.cache_utils import DynamicCache
from transformers import Phi3Model, Phi3Config
from transformers.cache_utils import Cache, DynamicCache
+from transformers.modeling_outputs import BaseModelOutputWithPast
from ...configuration_utils import ConfigMixin, register_to_config
from ...loaders import PeftAdapterMixin
-from ...utils import logging
+from ...utils import USE_PEFT_BACKEND, logging, scale_lora_layers
from ..attention_processor import AttentionProcessor
-from ..normalization import AdaLayerNorm
from ..embeddings import OmniGenPatchEmbed, OmniGenTimestepEmbed
+from ..modeling_outputs import Transformer2DModelOutput
from ..modeling_utils import ModelMixin
+from ..normalization import AdaLayerNorm
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
+@dataclass
+class OmniGen2DModelOutput(Transformer2DModelOutput):
+ """
+ The output of [`Transformer2DModel`].
+
+ Args:
+ sample (`torch.Tensor` of shape `(batch_size, num_channels, height, width)` or `(batch size, num_vector_embeds - 1, num_latent_pixels)` if [`Transformer2DModel`] is discrete):
+ The hidden states output conditioned on the `encoder_hidden_states` input. If discrete, returns probability
+ distributions for the unnoised latent pixels.
+ past_key_values (`transformers.cache_utils.DynamicCache`)
+ """
+
+ sample: "torch.Tensor" # noqa: F821
+ past_key_values: "DynamicCache"
+
+
class OmniGenBaseTransformer(Phi3Model):
"""
Transformer used in OmniGen. The transformer block is from Ph3, and only modify the attention mask.
@@ -44,6 +62,37 @@ class OmniGenBaseTransformer(Phi3Model):
config: Phi3Config
"""
+ def prefetch_layer(self, layer_idx: int, device: torch.device):
+ "Starts prefetching the next layer cache"
+ with torch.cuda.stream(self.prefetch_stream):
+ # Prefetch next layer tensors to GPU
+ for name, param in self.layers[layer_idx].named_parameters():
+ param.data = param.data.to(device, non_blocking=True)
+
+ def evict_previous_layer(self, layer_idx: int):
+ "Moves the previous layer cache to the CPU"
+ prev_layer_idx = layer_idx - 1
+ for name, param in self.layers[prev_layer_idx].named_parameters():
+ param.data = param.data.to("cpu", non_blocking=True)
+
+ def get_offload_layer(self, layer_idx: int, device: torch.device):
+ # init stream
+ if not hasattr(self, "prefetch_stream"):
+ self.prefetch_stream = torch.cuda.Stream()
+
+ # delete previous layer
+ # main stream sync shouldn't be necessary since all computation on iter i-1 is finished by iter i
+ # torch.cuda.current_stream().synchronize()
+ # avoid extra eviction of last layer
+ if layer_idx > 0:
+ self.evict_previous_layer(layer_idx)
+
+ # make sure the current layer is ready
+ self.prefetch_stream.synchronize()
+
+ # load next layer
+ self.prefetch_layer((layer_idx + 1) % len(self.layers), device)
+
def forward(
self,
input_ids: torch.LongTensor = None,
@@ -56,7 +105,7 @@ def forward(
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
cache_position: Optional[torch.LongTensor] = None,
- offload_model: Optional[bool] = False,
+ offload_transformer_block: Optional[bool] = False,
) -> Union[Tuple, BaseModelOutputWithPast]:
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
output_hidden_states = (
@@ -124,6 +173,13 @@ def forward(
cache_position,
)
else:
+ if offload_transformer_block and not self.training:
+ if not not torch.cuda.is_available():
+ logger.warning_once(
+ "We don't detecte any available GPU, so diable `offload_transformer_block`"
+ )
+ else:
+ self.get_offload_layer(layer_idx, device=inputs_embeds.device)
layer_outputs = decoder_layer(
hidden_states,
attention_mask=attention_mask,
@@ -162,7 +218,7 @@ def forward(
)
-class OmniGenTransformerModel(ModelMixin, ConfigMixin, PeftAdapterMixin):
+class OmniGenTransformer2DModel(ModelMixin, ConfigMixin, PeftAdapterMixin):
"""
The Transformer model introduced in OmniGen.
@@ -197,7 +253,10 @@ def __init__(
hidden_size = transformer_config.hidden_size
- self.patch_embedding = OmniGenPatchEmbed(patch_size=patch_size, in_channels=in_channels, embed_dim=hidden_size, pos_embed_max_size=pos_embed_max_size)
+ self.patch_embedding = OmniGenPatchEmbed(patch_size=patch_size,
+ in_channels=in_channels,
+ embed_dim=hidden_size,
+ pos_embed_max_size=pos_embed_max_size)
self.time_token = OmniGenTimestepEmbed(hidden_size)
self.t_embedder = OmniGenTimestepEmbed(hidden_size)
@@ -208,7 +267,6 @@ def __init__(
self.llm = OmniGenBaseTransformer(config=transformer_config)
self.llm.config.use_cache = False
-
def unpatchify(self, x, h, w):
"""
x: (N, T, patch_size**2 * C)
@@ -222,11 +280,10 @@ def unpatchify(self, x, h, w):
imgs = x.reshape(shape=(x.shape[0], c, h, w))
return imgs
-
- def prepare_condition_embeddings(self, input_ids, input_img_latents, input_image_sizes, padding_latent):
+ def prepare_condition_embeddings(self, input_ids, input_img_latents, input_image_sizes):
condition_embeds = None
if input_img_latents is not None:
- input_latents = self.patch_embedding(input_img_latents, is_input_images=True, padding_latent=padding_latent)
+ input_latents = self.patch_embedding(input_img_latents, is_input_images=True)
if input_ids is not None:
condition_embeds = self.llm.embed_tokens(input_ids).clone()
input_img_inx = 0
@@ -303,44 +360,54 @@ def _set_gradient_checkpointing(self, module, value=False):
module.gradient_checkpointing = value
def forward(self,
- hidden_states,
- timestep,
- input_ids,
- input_img_latents,
- input_image_sizes,
- attention_mask,
- position_ids,
- padding_latent=None,
- past_key_values=None,
- return_past_key_values=True,
- offload_model: bool = False):
-
- height, width = hidden_states.size(-2)
+ hidden_states: torch.Tensor,
+ timestep: Union[int, float, torch.LongTensor],
+ condition_tokens: torch.Tensor,
+ attention_mask: torch.Tensor,
+ position_ids: torch.Tensor,
+ past_key_values: DynamicCache = None,
+ offload_transformer_block: bool = False,
+ attention_kwargs: Optional[Dict[str, Any]] = None,
+ return_dict: bool = True,
+ ):
+
+ if attention_kwargs is not None:
+ attention_kwargs = attention_kwargs.copy()
+ lora_scale = attention_kwargs.pop("scale", 1.0)
+ else:
+ lora_scale = 1.0
+
+ if USE_PEFT_BACKEND:
+ # weight the lora layers by setting `lora_scale` for each PEFT layer
+ scale_lora_layers(self, lora_scale)
+ else:
+ if attention_kwargs is not None and attention_kwargs.get("scale", None) is not None:
+ logger.warning(
+ "Passing `scale` via `attention_kwargs` when not using the PEFT backend is ineffective."
+ )
+
+ height, width = hidden_states.size(-2)
hidden_states = self.patch_embedding(hidden_states, is_input_image=False)
num_tokens_for_output_image = hidden_states.size(1)
time_token = self.time_token(timestep, dtype=hidden_states.dtype).unsqueeze(1)
- condition_embeds = self.prepare_condition_embeddings(input_ids, input_img_latents, input_image_sizes, padding_latent)
- if condition_embeds is not None:
- input_emb = torch.cat([condition_embeds, time_token, hidden_states], dim=1)
+ if condition_tokens is not None:
+ input_emb = torch.cat([condition_tokens, time_token, hidden_states], dim=1)
else:
input_emb = torch.cat([time_token, hidden_states], dim=1)
- output = self.llm(inputs_embeds=input_emb, attention_mask=attention_mask, position_ids=position_ids,
- past_key_values=past_key_values, offload_model=offload_model)
+ output = self.llm(inputs_embeds=input_emb,
+ attention_mask=attention_mask,
+ position_ids=position_ids,
+ past_key_values=past_key_values,
+ offload_transformer_block=offload_transformer_block)
output, past_key_values = output.last_hidden_state, output.past_key_values
image_embedding = output[:, -num_tokens_for_output_image:]
time_emb = self.t_embedder(timestep, dtype=hidden_states.dtype)
x = self.final_layer(image_embedding, time_emb)
- latents = self.unpatchify(x, height, width)
-
- if return_past_key_values:
- return latents, past_key_values
- return latents
-
-
-
-
-
+ output = self.unpatchify(x, height, width)
+ if not return_dict:
+ return (output, past_key_values)
+ return OmniGen2DModelOutput(sample=output, past_key_values=past_key_values)
diff --git a/src/diffusers/pipelines/lumina/pipeline_lumina.py b/src/diffusers/pipelines/lumina/pipeline_lumina.py
index 018f2e8bf1bc..296ae5303b20 100644
--- a/src/diffusers/pipelines/lumina/pipeline_lumina.py
+++ b/src/diffusers/pipelines/lumina/pipeline_lumina.py
@@ -25,7 +25,7 @@
from ...image_processor import VaeImageProcessor
from ...models import AutoencoderKL
from ...models.embeddings import get_2d_rotary_pos_embed_lumina
-from ...models.transformers.lumina_nextdit2d import LuminaNextDiT2DModel
+from ...models.transformers.lumina_nextdit2d import LuminaextDiT2DModel
from ...schedulers import FlowMatchEulerDiscreteScheduler
from ...utils import (
BACKENDS_MAPPING,
@@ -777,7 +777,7 @@ def __call__(
# 4. Prepare timesteps
timesteps, num_inference_steps = retrieve_timesteps(
- self.scheduler, num_inference_steps, device, timesteps, sigmas
+ self.scheduler, num_inference_steps, device, timesteps,
)
# 5. Prepare latents.
diff --git a/src/diffusers/pipelines/omnigen/__init__.py b/src/diffusers/pipelines/omnigen/__init__.py
index 3570368a5ca1..557e7c08dc22 100644
--- a/src/diffusers/pipelines/omnigen/__init__.py
+++ b/src/diffusers/pipelines/omnigen/__init__.py
@@ -11,8 +11,8 @@
_dummy_objects = {}
-_additional_imports = {}
-_import_structure = {"pipeline_output": ["FluxPipelineOutput", "FluxPriorReduxPipelineOutput"]}
+_import_structure = {}
+
try:
if not (is_transformers_available() and is_torch_available()):
@@ -22,35 +22,20 @@
_dummy_objects.update(get_objects_from_module(dummy_torch_and_transformers_objects))
else:
- _import_structure["modeling_flux"] = ["ReduxImageEncoder"]
- _import_structure["pipeline_flux"] = ["FluxPipeline"]
- _import_structure["pipeline_flux_control"] = ["FluxControlPipeline"]
- _import_structure["pipeline_flux_control_img2img"] = ["FluxControlImg2ImgPipeline"]
- _import_structure["pipeline_flux_controlnet"] = ["FluxControlNetPipeline"]
- _import_structure["pipeline_flux_controlnet_image_to_image"] = ["FluxControlNetImg2ImgPipeline"]
- _import_structure["pipeline_flux_controlnet_inpainting"] = ["FluxControlNetInpaintPipeline"]
- _import_structure["pipeline_flux_fill"] = ["FluxFillPipeline"]
- _import_structure["pipeline_flux_img2img"] = ["FluxImg2ImgPipeline"]
- _import_structure["pipeline_flux_inpaint"] = ["FluxInpaintPipeline"]
- _import_structure["pipeline_flux_prior_redux"] = ["FluxPriorReduxPipeline"]
+ _import_structure["pipeline_omnigen"] = ["OmniGenPipeline"]
+
+
if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
try:
if not (is_transformers_available() and is_torch_available()):
raise OptionalDependencyNotAvailable()
+
except OptionalDependencyNotAvailable:
- from ...utils.dummy_torch_and_transformers_objects import * # noqa F403
+ from ...utils.dummy_torch_and_transformers_objects import *
else:
- from .modeling_flux import ReduxImageEncoder
- from .pipeline_flux import FluxPipeline
- from .pipeline_flux_control import FluxControlPipeline
- from .pipeline_flux_control_img2img import FluxControlImg2ImgPipeline
- from .pipeline_flux_controlnet import FluxControlNetPipeline
- from .pipeline_flux_controlnet_image_to_image import FluxControlNetImg2ImgPipeline
- from .pipeline_flux_controlnet_inpainting import FluxControlNetInpaintPipeline
- from .pipeline_flux_fill import FluxFillPipeline
- from .pipeline_flux_img2img import FluxImg2ImgPipeline
- from .pipeline_flux_inpaint import FluxInpaintPipeline
- from .pipeline_flux_prior_redux import FluxPriorReduxPipeline
+ from .pipeline_omnigen import OmniGenPipeline
+
+
else:
import sys
@@ -63,5 +48,3 @@
for name, value in _dummy_objects.items():
setattr(sys.modules[__name__], name, value)
- for name, value in _additional_imports.items():
- setattr(sys.modules[__name__], name, value)
diff --git a/src/diffusers/pipelines/omnigen/kvcache_omnigen.py b/src/diffusers/pipelines/omnigen/kvcache_omnigen.py
new file mode 100644
index 000000000000..0270292c130f
--- /dev/null
+++ b/src/diffusers/pipelines/omnigen/kvcache_omnigen.py
@@ -0,0 +1,106 @@
+from typing import Optional, Dict, Any, Tuple, List
+
+import torch
+from transformers.cache_utils import DynamicCache
+
+
+class OmniGenCache(DynamicCache):
+ def __init__(self,
+ num_tokens_for_img: int, offload_kv_cache: bool = False) -> None:
+ if not torch.cuda.is_available():
+ raise RuntimeError(
+ "OmniGenCache can only be used with a GPU. If there is no GPU, you need to set use_kv_cache=False, which will result in longer inference time!")
+ super().__init__()
+ self.original_device = []
+ self.prefetch_stream = torch.cuda.Stream()
+ self.num_tokens_for_img = num_tokens_for_img
+ self.offload_kv_cache = offload_kv_cache
+
+ def prefetch_layer(self, layer_idx: int):
+ "Starts prefetching the next layer cache"
+ if layer_idx < len(self):
+ with torch.cuda.stream(self.prefetch_stream):
+ # Prefetch next layer tensors to GPU
+ device = self.original_device[layer_idx]
+ self.key_cache[layer_idx] = self.key_cache[layer_idx].to(device, non_blocking=True)
+ self.value_cache[layer_idx] = self.value_cache[layer_idx].to(device, non_blocking=True)
+
+ def evict_previous_layer(self, layer_idx: int):
+ "Moves the previous layer cache to the CPU"
+ if len(self) > 2:
+ # We do it on the default stream so it occurs after all earlier computations on these tensors are done
+ if layer_idx == 0:
+ prev_layer_idx = -1
+ else:
+ prev_layer_idx = (layer_idx - 1) % len(self)
+ self.key_cache[prev_layer_idx] = self.key_cache[prev_layer_idx].to("cpu", non_blocking=True)
+ self.value_cache[prev_layer_idx] = self.value_cache[prev_layer_idx].to("cpu", non_blocking=True)
+
+ def __getitem__(self, layer_idx: int) -> List[Tuple[torch.Tensor]]:
+ "Gets the cache for this layer to the device. Prefetches the next and evicts the previous layer."
+ if layer_idx < len(self):
+ if self.offload_kv_cache:
+ # Evict the previous layer if necessary
+ torch.cuda.current_stream().synchronize()
+ self.evict_previous_layer(layer_idx)
+ # Load current layer cache to its original device if not already there
+ # original_device = self.original_device[layer_idx]
+ # self.prefetch_stream.synchronize(original_device)
+ self.prefetch_stream.synchronize()
+ key_tensor = self.key_cache[layer_idx]
+ value_tensor = self.value_cache[layer_idx]
+
+ # Prefetch the next layer
+ self.prefetch_layer((layer_idx + 1) % len(self))
+ else:
+ key_tensor = self.key_cache[layer_idx]
+ value_tensor = self.value_cache[layer_idx]
+ return (key_tensor, value_tensor)
+ else:
+ raise KeyError(f"Cache only has {len(self)} layers, attempted to access layer with index {layer_idx}")
+
+ def update(
+ self,
+ key_states: torch.Tensor,
+ value_states: torch.Tensor,
+ layer_idx: int,
+ cache_kwargs: Optional[Dict[str, Any]] = None,
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
+ """
+ Updates the cache with the new `key_states` and `value_states` for the layer `layer_idx`.
+ Parameters:
+ key_states (`torch.Tensor`):
+ The new key states to cache.
+ value_states (`torch.Tensor`):
+ The new value states to cache.
+ layer_idx (`int`):
+ The index of the layer to cache the states for.
+ cache_kwargs (`Dict[str, Any]`, `optional`):
+ Additional arguments for the cache subclass. No additional arguments are used in `OffloadedCache`.
+ Return:
+ A tuple containing the updated key and value states.
+ """
+ # Update the cache
+ if len(self.key_cache) < layer_idx:
+ raise ValueError("OffloadedCache does not support model usage where layers are skipped. Use DynamicCache.")
+ elif len(self.key_cache) == layer_idx:
+ # only cache the states for condition tokens
+ key_states = key_states[..., :-(self.num_tokens_for_img + 1), :]
+ value_states = value_states[..., :-(self.num_tokens_for_img + 1), :]
+
+ # Update the number of seen tokens
+ if layer_idx == 0:
+ self._seen_tokens += key_states.shape[-2]
+
+ self.key_cache.append(key_states)
+ self.value_cache.append(value_states)
+ self.original_device.append(key_states.device)
+ if self.offload_kv_cache:
+ self.evict_previous_layer(layer_idx)
+ return self.key_cache[layer_idx], self.value_cache[layer_idx]
+ else:
+ # only cache the states for condition tokens
+ key_tensor, value_tensor = self[layer_idx]
+ k = torch.cat([key_tensor, key_states], dim=-2)
+ v = torch.cat([value_tensor, value_states], dim=-2)
+ return k, v
diff --git a/src/diffusers/pipelines/omnigen/pipeline_omnigen.py b/src/diffusers/pipelines/omnigen/pipeline_omnigen.py
index f6f3752fe97d..9c9a67e53e9c 100644
--- a/src/diffusers/pipelines/omnigen/pipeline_omnigen.py
+++ b/src/diffusers/pipelines/omnigen/pipeline_omnigen.py
@@ -15,27 +15,23 @@
import inspect
from typing import Any, Callable, Dict, List, Optional, Union
-import numpy as np
import torch
-from transformers import CLIPTextModel, CLIPTokenizer, T5EncoderModel, T5TokenizerFast
+from transformers import LlamaTokenizer
from ...image_processor import VaeImageProcessor
-from ...loaders import FluxLoraLoaderMixin, FromSingleFileMixin, TextualInversionLoaderMixin
+from ...loaders import FromSingleFileMixin, TextualInversionLoaderMixin
from ...models.autoencoders import AutoencoderKL
-from ...models.transformers import FluxTransformer2DModel
+from ...models.transformers import OmniGenTransformer2DModel
from ...schedulers import FlowMatchEulerDiscreteScheduler
from ...utils import (
- USE_PEFT_BACKEND,
is_torch_xla_available,
logging,
replace_example_docstring,
- scale_lora_layers,
- unscale_lora_layers,
)
from ...utils.torch_utils import randn_tensor
-from ..pipeline_utils import DiffusionPipeline
-from .pipeline_output import OmniGenPipelineOutput
-
+from ..pipeline_utils import DiffusionPipeline, ImagePipelineOutput
+from .processor_omnigen import OmniGenMultiModalProcessor
+from .kvcache_omnigen import OmniGenCache
if is_torch_xla_available():
import torch_xla.core.xla_model as xm
@@ -136,24 +132,15 @@ class OmniGenPipeline(
Reference: https://arxiv.org/pdf/2409.11340
Args:
- transformer ([`FluxTransformer2DModel`]):
- Conditional Transformer (MMDiT) architecture to denoise the encoded image latents.
+ transformer ([`OmniGenTransformer2DModel`]):
+ Autoregressive Transformer architecture for OmniGen.
scheduler ([`FlowMatchEulerDiscreteScheduler`]):
A scheduler to be used in combination with `transformer` to denoise the encoded image latents.
vae ([`AutoencoderKL`]):
Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations.
- text_encoder ([`CLIPTextModel`]):
- [CLIP](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPTextModel), specifically
- the [clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14) variant.
- text_encoder_2 ([`T5EncoderModel`]):
- [T5](https://huggingface.co/docs/transformers/en/model_doc/t5#transformers.T5EncoderModel), specifically
- the [google/t5-v1_1-xxl](https://huggingface.co/google/t5-v1_1-xxl) variant.
- tokenizer (`CLIPTokenizer`):
- Tokenizer of class
- [CLIPTokenizer](https://huggingface.co/docs/transformers/en/model_doc/clip#transformers.CLIPTokenizer).
- tokenizer_2 (`T5TokenizerFast`):
- Second Tokenizer of class
- [T5TokenizerFast](https://huggingface.co/docs/transformers/en/model_doc/t5#transformers.T5TokenizerFast).
+ tokenizer (`LlamaTokenizer`):
+ Text tokenizer of class.
+ [LlamaTokenizer](https://huggingface.co/docs/transformers/main/model_doc/llama#transformers.LlamaTokenizer).
"""
model_cpu_offload_seq = "transformer->vae"
@@ -162,216 +149,110 @@ class OmniGenPipeline(
def __init__(
self,
+ transformer: OmniGenTransformer2DModel,
scheduler: FlowMatchEulerDiscreteScheduler,
vae: AutoencoderKL,
- tokenizer: CLIPTokenizer,
- transformer: FluxTransformer2DModel,
+ tokenizer: LlamaTokenizer,
):
super().__init__()
self.register_modules(
vae=vae,
- text_encoder=text_encoder,
- text_encoder_2=text_encoder_2,
tokenizer=tokenizer,
- tokenizer_2=tokenizer_2,
transformer=transformer,
scheduler=scheduler,
)
self.vae_scale_factor = (
2 ** (len(self.vae.config.block_out_channels) - 1) if hasattr(self, "vae") and self.vae is not None else 8
)
- # Flux latents are turned into 2x2 patches and packed. This means the latent width and height has to be divisible
+ # OmniGen latents are turned into 2x2 patches and packed. This means the latent width and height has to be divisible
# by the patch size. So the vae scale factor is multiplied by the patch size to account for this
self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor * 2)
+
+ self.multimodal_processor = OmniGenMultiModalProcessor(tokenizer, max_image_size=1024)
self.tokenizer_max_length = (
- self.tokenizer.model_max_length if hasattr(self, "tokenizer") and self.tokenizer is not None else 77
+ self.tokenizer.model_max_length if hasattr(self, "tokenizer") and self.tokenizer is not None else 120000
)
self.default_sample_size = 128
- def _get_t5_prompt_embeds(
- self,
- prompt: Union[str, List[str]] = None,
- num_images_per_prompt: int = 1,
- max_sequence_length: int = 512,
- device: Optional[torch.device] = None,
- dtype: Optional[torch.dtype] = None,
- ):
- device = device or self._execution_device
- dtype = dtype or self.text_encoder.dtype
-
- prompt = [prompt] if isinstance(prompt, str) else prompt
- batch_size = len(prompt)
-
- if isinstance(self, TextualInversionLoaderMixin):
- prompt = self.maybe_convert_prompt(prompt, self.tokenizer_2)
-
- text_inputs = self.tokenizer_2(
- prompt,
- padding="max_length",
- max_length=max_sequence_length,
- truncation=True,
- return_length=False,
- return_overflowing_tokens=False,
- return_tensors="pt",
- )
- text_input_ids = text_inputs.input_ids
- untruncated_ids = self.tokenizer_2(prompt, padding="longest", return_tensors="pt").input_ids
-
- if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(text_input_ids, untruncated_ids):
- removed_text = self.tokenizer_2.batch_decode(untruncated_ids[:, self.tokenizer_max_length - 1 : -1])
- logger.warning(
- "The following part of your input was truncated because `max_sequence_length` is set to "
- f" {max_sequence_length} tokens: {removed_text}"
- )
-
- prompt_embeds = self.text_encoder_2(text_input_ids.to(device), output_hidden_states=False)[0]
-
- dtype = self.text_encoder_2.dtype
- prompt_embeds = prompt_embeds.to(dtype=dtype, device=device)
-
- _, seq_len, _ = prompt_embeds.shape
-
- # duplicate text embeddings and attention mask for each generation per prompt, using mps friendly method
- prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
- prompt_embeds = prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)
-
- return prompt_embeds
-
- def _get_clip_prompt_embeds(
- self,
- prompt: Union[str, List[str]],
- num_images_per_prompt: int = 1,
- device: Optional[torch.device] = None,
- ):
- device = device or self._execution_device
-
- prompt = [prompt] if isinstance(prompt, str) else prompt
- batch_size = len(prompt)
-
- if isinstance(self, TextualInversionLoaderMixin):
- prompt = self.maybe_convert_prompt(prompt, self.tokenizer)
-
- text_inputs = self.tokenizer(
- prompt,
- padding="max_length",
- max_length=self.tokenizer_max_length,
- truncation=True,
- return_overflowing_tokens=False,
- return_length=False,
- return_tensors="pt",
- )
-
- text_input_ids = text_inputs.input_ids
- untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids
- if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(text_input_ids, untruncated_ids):
- removed_text = self.tokenizer.batch_decode(untruncated_ids[:, self.tokenizer_max_length - 1 : -1])
- logger.warning(
- "The following part of your input was truncated because CLIP can only handle sequences up to"
- f" {self.tokenizer_max_length} tokens: {removed_text}"
- )
- prompt_embeds = self.text_encoder(text_input_ids.to(device), output_hidden_states=False)
-
- # Use pooled output of CLIPTextModel
- prompt_embeds = prompt_embeds.pooler_output
- prompt_embeds = prompt_embeds.to(dtype=self.text_encoder.dtype, device=device)
-
- # duplicate text embeddings for each generation per prompt, using mps friendly method
- prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt)
- prompt_embeds = prompt_embeds.view(batch_size * num_images_per_prompt, -1)
-
- return prompt_embeds
-
- def encode_prompt(
- self,
- prompt: Union[str, List[str]],
- prompt_2: Union[str, List[str]],
- device: Optional[torch.device] = None,
- num_images_per_prompt: int = 1,
- prompt_embeds: Optional[torch.FloatTensor] = None,
- pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
- max_sequence_length: int = 512,
- lora_scale: Optional[float] = None,
+ def encod_input_iamges(
+ self,
+ input_pixel_values: List[torch.Tensor],
+ device: Optional[torch.device] = None,
):
- r"""
-
+ """
+ get the continues embedding of input images by VAE
Args:
- prompt (`str` or `List[str]`, *optional*):
- prompt to be encoded
- prompt_2 (`str` or `List[str]`, *optional*):
- The prompt or prompts to be sent to the `tokenizer_2` and `text_encoder_2`. If not defined, `prompt` is
- used in all text-encoders
- device: (`torch.device`):
- torch device
- num_images_per_prompt (`int`):
- number of images that should be generated per prompt
- prompt_embeds (`torch.FloatTensor`, *optional*):
- Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
- provided, text embeddings will be generated from `prompt` input argument.
- pooled_prompt_embeds (`torch.FloatTensor`, *optional*):
- Pre-generated pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting.
- If not provided, pooled text embeddings will be generated from `prompt` input argument.
- lora_scale (`float`, *optional*):
- A lora scale that will be applied to all LoRA layers of the text encoder if LoRA layers are loaded.
+ input_pixel_values: normlized pixel of input images
+ device:
+ Returns: torch.Tensor
"""
device = device or self._execution_device
+ dtype = self.vae.dtype
+
+ input_img_latents = []
+ for img in input_pixel_values:
+ img = self.vae.encode(img.to(device, dtype)).latent_dist.sample().mul_(self.vae.config.scaling_factor)
+ input_img_latents.append(img)
+ return input_img_latents
+
+ def get_multimodal_embeddings(self,
+ input_ids: torch.Tensor,
+ input_img_latents: List[torch.Tensor],
+ input_image_sizes: Dict,
+ device: Optional[torch.device] = None,
+ ):
+ """
+ get the multi-modal conditional embeddings
+ Args:
+ input_ids: a sequence of text id
+ input_img_latents: continues embedding of input images
+ input_image_sizes: the index of the input image in the input_ids sequence.
+ device:
- # set lora scale so that monkey patched LoRA
- # function of text encoder can correctly access it
- if lora_scale is not None and isinstance(self, FluxLoraLoaderMixin):
- self._lora_scale = lora_scale
-
- # dynamically adjust the LoRA scale
- if self.text_encoder is not None and USE_PEFT_BACKEND:
- scale_lora_layers(self.text_encoder, lora_scale)
- if self.text_encoder_2 is not None and USE_PEFT_BACKEND:
- scale_lora_layers(self.text_encoder_2, lora_scale)
-
- prompt = [prompt] if isinstance(prompt, str) else prompt
-
- if prompt_embeds is None:
- prompt_2 = prompt_2 or prompt
- prompt_2 = [prompt_2] if isinstance(prompt_2, str) else prompt_2
-
- # We only use the pooled prompt output from the CLIPTextModel
- pooled_prompt_embeds = self._get_clip_prompt_embeds(
- prompt=prompt,
- device=device,
- num_images_per_prompt=num_images_per_prompt,
- )
- prompt_embeds = self._get_t5_prompt_embeds(
- prompt=prompt_2,
- num_images_per_prompt=num_images_per_prompt,
- max_sequence_length=max_sequence_length,
- device=device,
- )
+ Returns: torch.Tensor
- if self.text_encoder is not None:
- if isinstance(self, FluxLoraLoaderMixin) and USE_PEFT_BACKEND:
- # Retrieve the original scale by scaling back the LoRA layers
- unscale_lora_layers(self.text_encoder, lora_scale)
+ """
+ device = device or self._execution_device
- if self.text_encoder_2 is not None:
- if isinstance(self, FluxLoraLoaderMixin) and USE_PEFT_BACKEND:
- # Retrieve the original scale by scaling back the LoRA layers
- unscale_lora_layers(self.text_encoder_2, lora_scale)
+ condition_tokens = None
+ if input_ids is not None:
+ condition_tokens = self.transformer.llm.embed_tokens(input_ids.to(device))
+ input_img_inx = 0
+ if input_img_latents is not None:
+ input_image_tokens = self.transformer.patch_embedding(input_img_latents,
+ is_input_images=True)
- dtype = self.text_encoder.dtype if self.text_encoder is not None else self.transformer.dtype
- text_ids = torch.zeros(prompt_embeds.shape[1], 3).to(device=device, dtype=dtype)
+ for b_inx in input_image_sizes.keys():
+ for start_inx, end_inx in input_image_sizes[b_inx]:
+ # replace the placeholder in text tokens with the image embedding.
+ condition_tokens[b_inx, start_inx: end_inx] = input_image_tokens[input_img_inx].to(
+ condition_tokens.dtype)
+ input_img_inx += 1
- return prompt_embeds, pooled_prompt_embeds, text_ids
+ return condition_tokens
def check_inputs(
self,
prompt,
- prompt_2,
+ input_images,
height,
width,
- prompt_embeds=None,
- pooled_prompt_embeds=None,
callback_on_step_end_tensor_inputs=None,
max_sequence_length=None,
):
+
+ if input_images is not None:
+ if len(input_images) != len(prompt):
+ raise ValueError(
+ f"The number of prompts: {len(prompt)} does not match the number of input images: {len(input_images)}."
+ )
+ for i in range(len(input_images)):
+ if not all(f"<|image_{k}|>" in prompt[i] for k in range(len(input_images[i]))):
+ raise ValueError(
+ f"prompt `{prompt[i]}` doesn't have enough placeholders for the input images `{input_images[i]}`"
+ )
+
if height % (self.vae_scale_factor * 2) != 0 or width % (self.vae_scale_factor * 2) != 0:
logger.warning(
f"`height` and `width` have to be divisible by {self.vae_scale_factor * 2} but are {height} and {width}. Dimensions will be resized accordingly"
@@ -384,33 +265,6 @@ def check_inputs(
f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}"
)
- if prompt is not None and prompt_embeds is not None:
- raise ValueError(
- f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
- " only forward one of the two."
- )
- elif prompt_2 is not None and prompt_embeds is not None:
- raise ValueError(
- f"Cannot forward both `prompt_2`: {prompt_2} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
- " only forward one of the two."
- )
- elif prompt is None and prompt_embeds is None:
- raise ValueError(
- "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined."
- )
- elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)):
- raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
- elif prompt_2 is not None and (not isinstance(prompt_2, str) and not isinstance(prompt_2, list)):
- raise ValueError(f"`prompt_2` has to be of type `str` or `list` but is {type(prompt_2)}")
-
- if prompt_embeds is not None and pooled_prompt_embeds is None:
- raise ValueError(
- "If `prompt_embeds` are provided, `pooled_prompt_embeds` also have to be passed. Make sure to generate `pooled_prompt_embeds` from the same text encoder that was used to generate `prompt_embeds`."
- )
-
- if max_sequence_length is not None and max_sequence_length > 512:
- raise ValueError(f"`max_sequence_length` cannot be greater than 512 but is {max_sequence_length}")
-
@staticmethod
def _prepare_latent_image_ids(batch_size, height, width, device, dtype):
latent_image_ids = torch.zeros(height, width, 3)
@@ -517,10 +371,6 @@ def prepare_latents(
def guidance_scale(self):
return self._guidance_scale
- @property
- def joint_attention_kwargs(self):
- return self._joint_attention_kwargs
-
@property
def num_timesteps(self):
return self._num_timesteps
@@ -533,35 +383,38 @@ def interrupt(self):
@replace_example_docstring(EXAMPLE_DOC_STRING)
def __call__(
self,
- prompt: Union[str, List[str]] = None,
- prompt_2: Optional[Union[str, List[str]]] = None,
+ prompt: Union[str, List[str]],
+ input_images: Optional[Union[List[str], List[List[str]]]] = None,
height: Optional[int] = None,
width: Optional[int] = None,
- num_inference_steps: int = 28,
+ num_inference_steps: int = 50,
+ max_input_image_size: int = 1024,
timesteps: List[int] = None,
- guidance_scale: float = 3.5,
+ guidance_scale: float = 2.5,
+ img_guidance_scale: float = 1.6,
+ use_kv_cache: bool = True,
+ offload_kv_cache: bool = True,
+ offload_transformer_block: bool = False,
+ use_input_image_size_as_output: bool = False,
num_images_per_prompt: Optional[int] = 1,
generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
latents: Optional[torch.FloatTensor] = None,
- prompt_embeds: Optional[torch.FloatTensor] = None,
- pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
output_type: Optional[str] = "pil",
return_dict: bool = True,
- joint_attention_kwargs: Optional[Dict[str, Any]] = None,
+ attention_kwargs: Optional[Dict[str, Any]] = None,
callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None,
callback_on_step_end_tensor_inputs: List[str] = ["latents"],
- max_sequence_length: int = 512,
+ max_sequence_length: int = 120000,
):
r"""
Function invoked when calling the pipeline for generation.
Args:
prompt (`str` or `List[str]`, *optional*):
- The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`.
- instead.
- prompt_2 (`str` or `List[str]`, *optional*):
- The prompt or prompts to be sent to `tokenizer_2` and `text_encoder_2`. If not defined, `prompt` is
- will be used instead
+ The prompt or prompts to guide the image generation.
+ If the input includes images, need to add placeholders `<|image_i|>` in the prompt to indicate the position of the i-th images.
+ input_images (`List[str]` or `List[List[str]]`, *optional*):
+ The list of input images. We will replace the "<|image_i|>" in prompt with the i-th image in list.
height (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor):
The height in pixels of the generated image. This is set to 1024 by default for the best results.
width (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor):
@@ -569,16 +422,28 @@ def __call__(
num_inference_steps (`int`, *optional*, defaults to 50):
The number of denoising steps. More denoising steps usually lead to a higher quality image at the
expense of slower inference.
+ max_input_image_size (`int`, *optional*, defaults to 1024):
+ the maximum size of input image, which will be used to crop the input image to the maximum size
timesteps (`List[int]`, *optional*):
Custom timesteps to use for the denoising process with schedulers which support a `timesteps` argument
in their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is
passed will be used. Must be in descending order.
- guidance_scale (`float`, *optional*, defaults to 7.0):
+ guidance_scale (`float`, *optional*, defaults to 2.5):
Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
`guidance_scale` is defined as `w` of equation 2. of [Imagen
Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`,
usually at the expense of lower image quality.
+ img_guidance_scale (`float`, *optional*, defaults to 1.6):
+ Defined as equation 3 in [Instrucpix2pix](https://arxiv.org/pdf/2211.09800).
+ use_kv_cache (`bool`, *optional*, defaults to True):
+ enable kv cache to speed up the inference
+ offload_kv_cache (`bool`, *optional*, defaults to True):
+ offload the cached key and value to cpu, which can save memory but slow down the generation silightly
+ offload_transformer_block (`bool`, *optional*, defaults to False):
+ offload the transformer layers to cpu, which can save memory but slow down the generation
+ use_input_image_size_as_output (bool, defaults to False):
+ whether to use the input image size as the output image size, which can be used for single-image input, e.g., image editing task
num_images_per_prompt (`int`, *optional*, defaults to 1):
The number of images to generate per prompt.
generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
@@ -588,18 +453,12 @@ def __call__(
Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
tensor will ge generated by sampling using the supplied random `generator`.
- prompt_embeds (`torch.FloatTensor`, *optional*):
- Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
- provided, text embeddings will be generated from `prompt` input argument.
- pooled_prompt_embeds (`torch.FloatTensor`, *optional*):
- Pre-generated pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting.
- If not provided, pooled text embeddings will be generated from `prompt` input argument.
output_type (`str`, *optional*, defaults to `"pil"`):
The output format of the generate image. Choose between
[PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
return_dict (`bool`, *optional*, defaults to `True`):
Whether or not to return a [`~pipelines.flux.FluxPipelineOutput`] instead of a plain tuple.
- joint_attention_kwargs (`dict`, *optional*):
+ attention_kwargs (`dict`, *optional*):
A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
`self.processor` in
[diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
@@ -616,123 +475,116 @@ def __call__(
Examples:
- Returns:
- [`~pipelines.flux.FluxPipelineOutput`] or `tuple`: [`~pipelines.flux.FluxPipelineOutput`] if `return_dict`
- is True, otherwise a `tuple`. When returning a tuple, the first element is a list with the generated
- images.
+ Returns: [`~pipelines.ImagePipelineOutput`] or `tuple`:
+ If `return_dict` is `True`, [`~pipelines.ImagePipelineOutput`] is returned, otherwise a `tuple` is returned
+ where the first element is a list with the generated images.
"""
height = height or self.default_sample_size * self.vae_scale_factor
width = width or self.default_sample_size * self.vae_scale_factor
+ if isinstance(prompt, str):
+ prompt = [prompt]
+ input_images = [input_images]
# 1. Check inputs. Raise error if not correct
self.check_inputs(
prompt,
- prompt_2,
+ input_images,
height,
width,
- prompt_embeds=prompt_embeds,
- pooled_prompt_embeds=pooled_prompt_embeds,
callback_on_step_end_tensor_inputs=callback_on_step_end_tensor_inputs,
max_sequence_length=max_sequence_length,
)
self._guidance_scale = guidance_scale
- self._joint_attention_kwargs = joint_attention_kwargs
+ self._attention_kwargs = attention_kwargs
self._interrupt = False
# 2. Define call parameters
- if prompt is not None and isinstance(prompt, str):
- batch_size = 1
- elif prompt is not None and isinstance(prompt, list):
- batch_size = len(prompt)
- else:
- batch_size = prompt_embeds.shape[0]
-
+ batch_size = len(prompt)
device = self._execution_device
- lora_scale = (
- self.joint_attention_kwargs.get("scale", None) if self.joint_attention_kwargs is not None else None
- )
- (
- prompt_embeds,
- pooled_prompt_embeds,
- text_ids,
- ) = self.encode_prompt(
- prompt=prompt,
- prompt_2=prompt_2,
- prompt_embeds=prompt_embeds,
- pooled_prompt_embeds=pooled_prompt_embeds,
- device=device,
- num_images_per_prompt=num_images_per_prompt,
- max_sequence_length=max_sequence_length,
- lora_scale=lora_scale,
+ # 3. process multi-modal instructions
+ if max_input_image_size != self.multimodal_processor.max_image_size:
+ self.processor = OmniGenMultiModalProcessor(self.text_tokenizer, max_image_size=max_input_image_size)
+ processed_data = self.processor(prompt,
+ input_images,
+ height=height,
+ width=width,
+ use_input_image_size_as_output=use_input_image_size_as_output)
+
+ # 4. Encode input images and obtain multi-modal conditional embeddings
+ input_img_latents = self.encod_input_iamges(processed_data['input_pixel_values'], device=device)
+ condition_tokens = self.get_multimodal_embeddings(input_ids=processed_data['input_ids'],
+ input_img_latents=input_img_latents,
+ input_image_sizes=processed_data['input_image_sizes'],
+ device=device,
+ )
+
+ # 5. Prepare timesteps
+ timesteps, num_inference_steps = retrieve_timesteps(
+ self.scheduler, num_inference_steps, device, timesteps,
)
- # 4. Prepare latent variables
- num_channels_latents = self.transformer.config.in_channels // 4
- latents, latent_image_ids = self.prepare_latents(
+ # 6. Prepare latents.
+ if use_input_image_size_as_output:
+ height, width = processed_data['input_pixel_values'][0].shape[-2:]
+ num_cfg = 2 if input_images is not None else 1
+ latent_channels = self.transformer.config.in_channels
+ latents = self.prepare_latents(
batch_size * num_images_per_prompt,
- num_channels_latents,
+ latent_channels,
height,
width,
- prompt_embeds.dtype,
+ condition_tokens.dtype,
device,
generator,
latents,
)
- # 5. Prepare timesteps
- sigmas = np.linspace(1.0, 1 / num_inference_steps, num_inference_steps)
- image_seq_len = latents.shape[1]
- mu = calculate_shift(
- image_seq_len,
- self.scheduler.config.base_image_seq_len,
- self.scheduler.config.max_image_seq_len,
- self.scheduler.config.base_shift,
- self.scheduler.config.max_shift,
- )
- timesteps, num_inference_steps = retrieve_timesteps(
- self.scheduler,
- num_inference_steps,
- device,
- timesteps,
- sigmas,
- mu=mu,
- )
- num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0)
- self._num_timesteps = len(timesteps)
-
- # handle guidance
- if self.transformer.config.guidance_embeds:
- guidance = torch.full([1], guidance_scale, device=device, dtype=torch.float32)
- guidance = guidance.expand(latents.shape[0])
- else:
- guidance = None
+ # 7. Prepare OmniGenCache
+ num_tokens_for_output_img = latents.size(-1) * latents.size(-2) // (self.patch_size ** 2)
+ cache = OmniGenCache(num_tokens_for_output_img, offload_kv_cache) if use_kv_cache else None
+ self.transformer.llm.use_cache = use_kv_cache
- # 6. Denoising loop
+ # 8. Denoising loop
with self.progress_bar(total=num_inference_steps) as progress_bar:
for i, t in enumerate(timesteps):
- if self.interrupt:
- continue
+ # expand the latents if we are doing classifier free guidance
+ latent_model_input = torch.cat([latents] * (num_cfg+1))
# broadcast to batch dimension in a way that's compatible with ONNX/Core ML
- timestep = t.expand(latents.shape[0]).to(latents.dtype)
-
- noise_pred = self.transformer(
- hidden_states=latents,
- timestep=timestep / 1000,
- guidance=guidance,
- pooled_projections=pooled_prompt_embeds,
- encoder_hidden_states=prompt_embeds,
- txt_ids=text_ids,
- img_ids=latent_image_ids,
- joint_attention_kwargs=self.joint_attention_kwargs,
+ timestep = t.expand(latent_model_input.shape[0])
+
+ noise_pred, cache = self.transformer(
+ hidden_states=latent_model_input,
+ timestep=timestep,
+ condition_tokens=condition_tokens,
+ attention_mask=processed_data['attention_mask'],
+ position_ids=processed_data['position_ids'],
+ attention_kwargs=attention_kwargs,
+ past_key_values=cache,
+ offload_transformer_block=offload_transformer_block,
+ return_past_key_values=True,
return_dict=False,
- )[0]
+ )
+
+ if use_kv_cache:
+ if condition_tokens is not None:
+ condition_tokens = None
+ processed_data['attention_mask'] = processed_data['attention_mask'][..., -(num_tokens_for_output_img+1):, :]
+ processed_data['position_ids'] = processed_data['position_ids'][:, -(num_tokens_for_output_img + 1):]
+
+ if num_cfg == 2:
+ cond, uncond, img_cond = torch.split(noise_pred, len(model_out) // 3, dim=0)
+ noise_pred = uncond + img_guidance_scale * (img_cond - uncond) + guidance_scale * (cond - img_cond)
+ else:
+ cond, uncond = torch.split(noise_pred, len(noise_pred) // 2, dim=0)
+ noise_pred = uncond + guidance_scale * (cond - uncond)
# compute the previous noisy sample x_t -> x_t-1
latents_dtype = latents.dtype
+ noise_pred = -noise_pred # OmniGen uses standard rectified flow instead of denoise, different from FLUX and SD3
latents = self.scheduler.step(noise_pred, t, latents, return_dict=False)[0]
if latents.dtype != latents_dtype:
@@ -740,30 +592,14 @@ def __call__(
# some platforms (eg. apple mps) misbehave due to a pytorch bug: https://github.com/pytorch/pytorch/pull/99272
latents = latents.to(latents_dtype)
- if callback_on_step_end is not None:
- callback_kwargs = {}
- for k in callback_on_step_end_tensor_inputs:
- callback_kwargs[k] = locals()[k]
- callback_outputs = callback_on_step_end(self, i, t, callback_kwargs)
-
- latents = callback_outputs.pop("latents", latents)
- prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds)
+ progress_bar.update()
- # call the callback, if provided
- if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
- progress_bar.update()
-
- if XLA_AVAILABLE:
- xm.mark_step()
-
- if output_type == "latent":
- image = latents
-
- else:
- latents = self._unpack_latents(latents, height, width, self.vae_scale_factor)
- latents = (latents / self.vae.config.scaling_factor) + self.vae.config.shift_factor
+ if not output_type == "latent":
+ latents = latents / self.vae.config.scaling_factor
image = self.vae.decode(latents, return_dict=False)[0]
image = self.image_processor.postprocess(image, output_type=output_type)
+ else:
+ image = latents
# Offload all models
self.maybe_free_model_hooks()
@@ -771,4 +607,5 @@ def __call__(
if not return_dict:
return (image,)
- return FluxPipelineOutput(images=image)
+ return ImagePipelineOutput(images=image)
+
diff --git a/src/diffusers/pipelines/omnigen/pipeline_output.py b/src/diffusers/pipelines/omnigen/pipeline_output.py
deleted file mode 100644
index 40ff24199900..000000000000
--- a/src/diffusers/pipelines/omnigen/pipeline_output.py
+++ /dev/null
@@ -1,24 +0,0 @@
-from dataclasses import dataclass
-from typing import List, Union
-
-import numpy as np
-import PIL.Image
-
-from ...utils import BaseOutput
-
-
-@dataclass
-class OmniGenPipelineOutput(BaseOutput):
- """
- Output class for Stable Diffusion pipelines.
-
- Args:
- images (`List[PIL.Image.Image]` or `np.ndarray`)
- List of denoised PIL images of length `batch_size` or numpy array of shape `(batch_size, height, width,
- num_channels)`. PIL images or numpy array present the denoised images of the diffusion pipeline.
- """
-
- images: Union[List[PIL.Image.Image], np.ndarray]
-
-
-
diff --git a/src/diffusers/pipelines/omnigen/processor_omnigen.py b/src/diffusers/pipelines/omnigen/processor_omnigen.py
new file mode 100644
index 000000000000..e6a36bcd8df7
--- /dev/null
+++ b/src/diffusers/pipelines/omnigen/processor_omnigen.py
@@ -0,0 +1,295 @@
+import re
+from typing import Dict, List
+
+import numpy as np
+import torch
+from PIL import Image
+from torchvision import transforms
+
+
+
+def crop_image(pil_image, max_image_size):
+ """
+ Crop the image so that its height and width does not exceed `max_image_size`,
+ while ensuring both the height and width are multiples of 16.
+ """
+ while min(*pil_image.size) >= 2 * max_image_size:
+ pil_image = pil_image.resize(
+ tuple(x // 2 for x in pil_image.size), resample=Image.BOX
+ )
+
+ if max(*pil_image.size) > max_image_size:
+ scale = max_image_size / max(*pil_image.size)
+ pil_image = pil_image.resize(
+ tuple(round(x * scale) for x in pil_image.size), resample=Image.BICUBIC
+ )
+
+ if min(*pil_image.size) < 16:
+ scale = 16 / min(*pil_image.size)
+ pil_image = pil_image.resize(
+ tuple(round(x * scale) for x in pil_image.size), resample=Image.BICUBIC
+ )
+
+ arr = np.array(pil_image)
+ crop_y1 = (arr.shape[0] % 16) // 2
+ crop_y2 = arr.shape[0] % 16 - crop_y1
+
+ crop_x1 = (arr.shape[1] % 16) // 2
+ crop_x2 = arr.shape[1] % 16 - crop_x1
+
+ arr = arr[crop_y1:arr.shape[0] - crop_y2, crop_x1:arr.shape[1] - crop_x2]
+ return Image.fromarray(arr)
+
+
+class OmniGenMultiModalProcessor:
+ def __init__(self,
+ text_tokenizer,
+ max_image_size: int = 1024):
+ self.text_tokenizer = text_tokenizer
+ self.max_image_size = max_image_size
+
+ self.image_transform = transforms.Compose([
+ transforms.Lambda(lambda pil_image: crop_image(pil_image, max_image_size)),
+ transforms.ToTensor(),
+ transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5], inplace=True)
+ ])
+
+ self.collator = OmniGenCollator()
+
+ def process_image(self, image):
+ image = Image.open(image).convert('RGB')
+ return self.image_transform(image)
+
+ def process_multi_modal_prompt(self, text, input_images):
+ text = self.add_prefix_instruction(text)
+ if input_images is None or len(input_images) == 0:
+ model_inputs = self.text_tokenizer(text)
+ return {"input_ids": model_inputs.input_ids, "pixel_values": None, "image_sizes": None}
+
+ pattern = r"<\|image_\d+\|>"
+ prompt_chunks = [self.text_tokenizer(chunk).input_ids for chunk in re.split(pattern, text)]
+
+ for i in range(1, len(prompt_chunks)):
+ if prompt_chunks[i][0] == 1:
+ prompt_chunks[i] = prompt_chunks[i][1:]
+
+ image_tags = re.findall(pattern, text)
+ image_ids = [int(s.split("|")[1].split("_")[-1]) for s in image_tags]
+
+ unique_image_ids = sorted(list(set(image_ids)))
+ assert unique_image_ids == list(range(1,
+ len(unique_image_ids) + 1)), f"image_ids must start from 1, and must be continuous int, e.g. [1, 2, 3], cannot be {unique_image_ids}"
+ # total images must be the same as the number of image tags
+ assert len(unique_image_ids) == len(
+ input_images), f"total images must be the same as the number of image tags, got {len(unique_image_ids)} image tags and {len(input_images)} images"
+
+ input_images = [input_images[x - 1] for x in image_ids]
+
+ all_input_ids = []
+ img_inx = []
+ idx = 0
+ for i in range(len(prompt_chunks)):
+ all_input_ids.extend(prompt_chunks[i])
+ if i != len(prompt_chunks) - 1:
+ start_inx = len(all_input_ids)
+ size = input_images[i].size(-2) * input_images[i].size(-1) // 16 // 16
+ img_inx.append([start_inx, start_inx + size])
+ all_input_ids.extend([0] * size)
+
+ return {"input_ids": all_input_ids, "pixel_values": input_images, "image_sizes": img_inx}
+
+ def add_prefix_instruction(self, prompt):
+ user_prompt = '<|user|>\n'
+ generation_prompt = 'Generate an image according to the following instructions\n'
+ assistant_prompt = '<|assistant|>\n<|diffusion|>'
+ prompt_suffix = "<|end|>\n"
+ prompt = f"{user_prompt}{generation_prompt}{prompt}{prompt_suffix}{assistant_prompt}"
+ return prompt
+
+ def __call__(self,
+ instructions: List[str],
+ input_images: List[List[str]] = None,
+ height: int = 1024,
+ width: int = 1024,
+ negative_prompt: str = "low quality, jpeg artifacts, ugly, duplicate, morbid, mutilated, extra fingers, mutated hands, poorly drawn hands, poorly drawn face, mutation, deformed, blurry, dehydrated, bad anatomy, bad proportions, extra limbs, cloned face, disfigured, gross proportions, malformed limbs, missing arms, missing legs, extra arms, extra legs, fused fingers, too many fingers.",
+ use_img_cfg: bool = True,
+ separate_cfg_input: bool = False,
+ use_input_image_size_as_output: bool = False,
+ ) -> Dict:
+
+ if input_images is None:
+ use_img_cfg = False
+ if isinstance(instructions, str):
+ instructions = [instructions]
+ input_images = [input_images]
+
+ input_data = []
+ for i in range(len(instructions)):
+ cur_instruction = instructions[i]
+ cur_input_images = None if input_images is None else input_images[i]
+ if cur_input_images is not None and len(cur_input_images) > 0:
+ cur_input_images = [self.process_image(x) for x in cur_input_images]
+ else:
+ cur_input_images = None
+ assert "<|image_1|>" not in cur_instruction
+
+ mllm_input = self.process_multi_modal_prompt(cur_instruction, cur_input_images)
+
+ neg_mllm_input, img_cfg_mllm_input = None, None
+ neg_mllm_input = self.process_multi_modal_prompt(negative_prompt, None)
+ if use_img_cfg:
+ if cur_input_images is not None and len(cur_input_images) >= 1:
+ img_cfg_prompt = [f"<|image_{i + 1}|>" for i in range(len(cur_input_images))]
+ img_cfg_mllm_input = self.process_multi_modal_prompt(" ".join(img_cfg_prompt), cur_input_images)
+ else:
+ img_cfg_mllm_input = neg_mllm_input
+
+ if use_input_image_size_as_output:
+ input_data.append((mllm_input, neg_mllm_input, img_cfg_mllm_input,
+ [mllm_input['pixel_values'][0].size(-2), mllm_input['pixel_values'][0].size(-1)]))
+ else:
+ input_data.append((mllm_input, neg_mllm_input, img_cfg_mllm_input, [height, width]))
+
+ return self.collator(input_data)
+
+
+class OmniGenCollator:
+ def __init__(self, pad_token_id=2, hidden_size=3072):
+ self.pad_token_id = pad_token_id
+ self.hidden_size = hidden_size
+
+ def create_position(self, attention_mask, num_tokens_for_output_images):
+ position_ids = []
+ text_length = attention_mask.size(-1)
+ img_length = max(num_tokens_for_output_images)
+ for mask in attention_mask:
+ temp_l = torch.sum(mask)
+ temp_position = [0] * (text_length - temp_l) + [i for i in range(
+ temp_l + img_length + 1)] # we add a time embedding into the sequence, so add one more token
+ position_ids.append(temp_position)
+ return torch.LongTensor(position_ids)
+
+ def create_mask(self, attention_mask, num_tokens_for_output_images):
+ """
+ OmniGen applies causal attention to each element in the sequence, but applies bidirectional attention within each image sequence
+ References: [OmniGen](https://arxiv.org/pdf/2409.11340)
+ """
+ extended_mask = []
+ padding_images = []
+ text_length = attention_mask.size(-1)
+ img_length = max(num_tokens_for_output_images)
+ seq_len = text_length + img_length + 1 # we add a time embedding into the sequence, so add one more token
+ inx = 0
+ for mask in attention_mask:
+ temp_l = torch.sum(mask)
+ pad_l = text_length - temp_l
+
+ temp_mask = torch.tril(torch.ones(size=(temp_l + 1, temp_l + 1)))
+
+ image_mask = torch.zeros(size=(temp_l + 1, img_length))
+ temp_mask = torch.cat([temp_mask, image_mask], dim=-1)
+
+ image_mask = torch.ones(size=(img_length, temp_l + img_length + 1))
+ temp_mask = torch.cat([temp_mask, image_mask], dim=0)
+
+ if pad_l > 0:
+ pad_mask = torch.zeros(size=(temp_l + 1 + img_length, pad_l))
+ temp_mask = torch.cat([pad_mask, temp_mask], dim=-1)
+
+ pad_mask = torch.ones(size=(pad_l, seq_len))
+ temp_mask = torch.cat([pad_mask, temp_mask], dim=0)
+
+ true_img_length = num_tokens_for_output_images[inx]
+ pad_img_length = img_length - true_img_length
+ if pad_img_length > 0:
+ temp_mask[:, -pad_img_length:] = 0
+ temp_padding_imgs = torch.zeros(size=(1, pad_img_length, self.hidden_size))
+ else:
+ temp_padding_imgs = None
+
+ extended_mask.append(temp_mask.unsqueeze(0))
+ padding_images.append(temp_padding_imgs)
+ inx += 1
+ return torch.cat(extended_mask, dim=0), padding_images
+
+ def adjust_attention_for_input_images(self, attention_mask, image_sizes):
+ for b_inx in image_sizes.keys():
+ for start_inx, end_inx in image_sizes[b_inx]:
+ attention_mask[b_inx][start_inx:end_inx, start_inx:end_inx] = 1
+
+ return attention_mask
+
+ def pad_input_ids(self, input_ids, image_sizes):
+ max_l = max([len(x) for x in input_ids])
+ padded_ids = []
+ attention_mask = []
+
+ for i in range(len(input_ids)):
+ temp_ids = input_ids[i]
+ temp_l = len(temp_ids)
+ pad_l = max_l - temp_l
+ if pad_l == 0:
+ attention_mask.append([1] * max_l)
+ padded_ids.append(temp_ids)
+ else:
+ attention_mask.append([0] * pad_l + [1] * temp_l)
+ padded_ids.append([self.pad_token_id] * pad_l + temp_ids)
+
+ if i in image_sizes:
+ new_inx = []
+ for old_inx in image_sizes[i]:
+ new_inx.append([x + pad_l for x in old_inx])
+ image_sizes[i] = new_inx
+
+ return torch.LongTensor(padded_ids), torch.LongTensor(attention_mask), image_sizes
+
+ def process_mllm_input(self, mllm_inputs, target_img_size):
+ num_tokens_for_output_images = []
+ for img_size in target_img_size:
+ num_tokens_for_output_images.append(img_size[0] * img_size[1] // 16 // 16)
+
+ pixel_values, image_sizes = [], {}
+ b_inx = 0
+ for x in mllm_inputs:
+ if x['pixel_values'] is not None:
+ pixel_values.extend(x['pixel_values'])
+ for size in x['image_sizes']:
+ if b_inx not in image_sizes:
+ image_sizes[b_inx] = [size]
+ else:
+ image_sizes[b_inx].append(size)
+ b_inx += 1
+ pixel_values = [x.unsqueeze(0) for x in pixel_values]
+
+ input_ids = [x['input_ids'] for x in mllm_inputs]
+ padded_input_ids, attention_mask, image_sizes = self.pad_input_ids(input_ids, image_sizes)
+ position_ids = self.create_position(attention_mask, num_tokens_for_output_images)
+ attention_mask, padding_images = self.create_mask(attention_mask, num_tokens_for_output_images)
+ attention_mask = self.adjust_attention_for_input_images(attention_mask, image_sizes)
+
+ return padded_input_ids, position_ids, attention_mask, padding_images, pixel_values, image_sizes
+
+ def __call__(self, features):
+ mllm_inputs = [f[0] for f in features]
+ cfg_mllm_inputs = [f[1] for f in features]
+ img_cfg_mllm_input = [f[2] for f in features]
+ target_img_size = [f[3] for f in features]
+
+ if img_cfg_mllm_input[0] is not None:
+ mllm_inputs = mllm_inputs + cfg_mllm_inputs + img_cfg_mllm_input
+ target_img_size = target_img_size + target_img_size + target_img_size
+ else:
+ mllm_inputs = mllm_inputs + cfg_mllm_inputs
+ target_img_size = target_img_size + target_img_size
+
+ all_padded_input_ids, all_position_ids, all_attention_mask, all_padding_images, all_pixel_values, all_image_sizes = self.process_mllm_input(
+ mllm_inputs, target_img_size)
+
+ data = {"input_ids": all_padded_input_ids,
+ "attention_mask": all_attention_mask,
+ "position_ids": all_position_ids,
+ "input_pixel_values": all_pixel_values,
+ "input_image_sizes": all_image_sizes,
+ }
+ return data
+
diff --git a/src/diffusers/utils/dummy_pt_objects.py b/src/diffusers/utils/dummy_pt_objects.py
index 5091ff318f1b..ce7da1bf8301 100644
--- a/src/diffusers/utils/dummy_pt_objects.py
+++ b/src/diffusers/utils/dummy_pt_objects.py
@@ -466,6 +466,20 @@ def from_config(cls, *args, **kwargs):
def from_pretrained(cls, *args, **kwargs):
requires_backends(cls, ["torch"])
+class OmniGenTransformer2DModel(metaclass=DummyObject):
+ _backends = ["torch"]
+
+ def __init__(self, *args, **kwargs):
+ requires_backends(self, ["torch"])
+
+ @classmethod
+ def from_config(cls, *args, **kwargs):
+ requires_backends(cls, ["torch"])
+
+ @classmethod
+ def from_pretrained(cls, *args, **kwargs):
+ requires_backends(cls, ["torch"])
+
class PixArtTransformer2DModel(metaclass=DummyObject):
_backends = ["torch"]
diff --git a/src/diffusers/utils/dummy_torch_and_transformers_objects.py b/src/diffusers/utils/dummy_torch_and_transformers_objects.py
index b76ea3824060..b1440d33a6f7 100644
--- a/src/diffusers/utils/dummy_torch_and_transformers_objects.py
+++ b/src/diffusers/utils/dummy_torch_and_transformers_objects.py
@@ -1142,6 +1142,21 @@ def from_pretrained(cls, *args, **kwargs):
requires_backends(cls, ["torch", "transformers"])
+class OmniGenPipeline(metaclass=DummyObject):
+ _backends = ["torch", "transformers"]
+
+ def __init__(self, *args, **kwargs):
+ requires_backends(self, ["torch", "transformers"])
+
+ @classmethod
+ def from_config(cls, *args, **kwargs):
+ requires_backends(cls, ["torch", "transformers"])
+
+ @classmethod
+ def from_pretrained(cls, *args, **kwargs):
+ requires_backends(cls, ["torch", "transformers"])
+
+
class PaintByExamplePipeline(metaclass=DummyObject):
_backends = ["torch", "transformers"]
diff --git a/test.py b/test.py
index e110b93a7b74..3648f7e9875e 100644
--- a/test.py
+++ b/test.py
@@ -3,7 +3,7 @@
from huggingface_hub import snapshot_download
-from diffusers.models import OmniGenTransformerModel
+from diffusers.models import OmniGenTransformer2DModel
from transformers import Phi3Model, Phi3Config
@@ -11,7 +11,7 @@
model_name = "Shitao/OmniGen-v1"
config = Phi3Config.from_pretrained("Shitao/OmniGen-v1")
-model = OmniGenTransformerModel(transformer_config=config)
+model = OmniGenTransformer2DModel(transformer_config=config)
cache_folder = os.getenv('HF_HUB_CACHE')
model_name = snapshot_download(repo_id=model_name,
cache_dir=cache_folder,
From 0d041944dcd766db468e0568f8e12c02a64964ac Mon Sep 17 00:00:00 2001
From: shitao <2906698981@qq.com>
Date: Mon, 2 Dec 2024 17:40:43 +0800
Subject: [PATCH 04/55] omnigen pipeline
---
scripts/convert_omnigen_to_diffusers.py | 22 ++++++++++++++++------
1 file changed, 16 insertions(+), 6 deletions(-)
diff --git a/scripts/convert_omnigen_to_diffusers.py b/scripts/convert_omnigen_to_diffusers.py
index 2b02f6a064dc..5d6d98bcb9a7 100644
--- a/scripts/convert_omnigen_to_diffusers.py
+++ b/scripts/convert_omnigen_to_diffusers.py
@@ -4,13 +4,25 @@
import torch
from safetensors.torch import load_file
from transformers import AutoModel, AutoTokenizer, AutoConfig
+from huggingface_hub import snapshot_download
from diffusers import AutoencoderKL, FlowMatchEulerDiscreteScheduler, OmniGenTransformer2DModel, OmniGenPipeline
def main(args):
# checkpoint from https://huggingface.co/Shitao/OmniGen-v1
- ckpt = load_file(args.origin_ckpt_path, device="cpu")
+
+ if not os.path.exists(args.origin_ckpt_path):
+ print("Model not found, downloading...")
+ cache_folder = os.getenv('HF_HUB_CACHE')
+ args.origin_ckpt_path = snapshot_download(repo_id=args.origin_ckpt_path,
+ cache_dir=cache_folder,
+ ignore_patterns=['flax_model.msgpack', 'rust_model.ot', 'tf_model.h5',
+ 'model.pt'])
+ print(f"Downloaded model to {args.origin_ckpt_path}")
+
+ ckpt = os.path.join(args.origin_ckpt_path, 'model.safetensors')
+ ckpt = load_file(ckpt, device="cpu")
mapping_dict = {
"pos_embed": "patch_embedding.pos_embed",
@@ -27,7 +39,6 @@ def main(args):
converted_state_dict = {}
for k, v in ckpt.items():
- # new_ckpt[k] = v
if k in mapping_dict:
converted_state_dict[mapping_dict[k]] = v
else:
@@ -35,7 +46,6 @@ def main(args):
transformer_config = AutoConfig.from_pretrained(args.origin_ckpt_path)
- # Lumina-Next-SFT 2B
transformer = OmniGenTransformer2DModel(
transformer_config=transformer_config,
patch_size=2,
@@ -49,7 +59,7 @@ def main(args):
scheduler = FlowMatchEulerDiscreteScheduler()
- vae = AutoencoderKL.from_pretrained(args.origin_ckpt_path, torch_dtype=torch.float32)
+ vae = AutoencoderKL.from_pretrained(os.path.join(args.origin_ckpt_path, "vae"), torch_dtype=torch.float32)
tokenizer = AutoTokenizer.from_pretrained(args.origin_ckpt_path)
@@ -64,10 +74,10 @@ def main(args):
parser = argparse.ArgumentParser()
parser.add_argument(
- "--origin_ckpt_path", default=None, type=str, required=False, help="Path to the checkpoint to convert."
+ "--origin_ckpt_path", default="Shitao/OmniGen-v1", type=str, required=False, help="Path to the checkpoint to convert."
)
- parser.add_argument("--dump_path", default=None, type=str, required=True, help="Path to the output pipeline.")
+ parser.add_argument("--dump_path", default="OmniGen-v1-diffusers", type=str, required=True, help="Path to the output pipeline.")
args = parser.parse_args()
main(args)
From 85abe5e5b077d7c49f220d0a78317cc413a6bf5c Mon Sep 17 00:00:00 2001
From: staoxiao <2906698981@qq.com>
Date: Tue, 3 Dec 2024 17:15:07 +0800
Subject: [PATCH 05/55] update omnigen_pipeline
---
scripts/convert_omnigen_to_diffusers.py | 141 ++++++++-
src/diffusers/models/embeddings.py | 4 +-
.../transformers/transformer_omnigen.py | 9 +-
src/diffusers/pipelines/__init__.py | 2 +
.../pipelines/omnigen/kvcache_omnigen.py | 45 +--
.../pipelines/omnigen/pipeline_omnigen.py | 108 +++----
.../pipelines/omnigen/processor_omnigen.py | 2 -
.../scheduling_flow_match_euler_discrete.py | 4 +-
test.py | 92 +++---
tests/pipelines/omnigen/__init__.py | 0
tests/pipelines/omnigen/test_pipeline_flux.py | 298 ++++++++++++++++++
11 files changed, 570 insertions(+), 135 deletions(-)
create mode 100644 tests/pipelines/omnigen/__init__.py
create mode 100644 tests/pipelines/omnigen/test_pipeline_flux.py
diff --git a/scripts/convert_omnigen_to_diffusers.py b/scripts/convert_omnigen_to_diffusers.py
index 5d6d98bcb9a7..8c9bc8bdb457 100644
--- a/scripts/convert_omnigen_to_diffusers.py
+++ b/scripts/convert_omnigen_to_diffusers.py
@@ -1,5 +1,6 @@
import argparse
import os
+os.environ['HF_HUB_CACHE'] = '/share/shitao/downloaded_models2'
import torch
from safetensors.torch import load_file
@@ -44,8 +45,141 @@ def main(args):
else:
converted_state_dict[k] = v
- transformer_config = AutoConfig.from_pretrained(args.origin_ckpt_path)
-
+ # transformer_config = AutoConfig.from_pretrained(args.origin_ckpt_path)
+ # print(type(transformer_config.__dict__))
+ # print(transformer_config.__dict__)
+
+ transformer_config = {
+ "_name_or_path": "Phi-3-vision-128k-instruct",
+ "architectures": [
+ "Phi3ForCausalLM"
+ ],
+ "attention_dropout": 0.0,
+ "bos_token_id": 1,
+ "eos_token_id": 2,
+ "hidden_act": "silu",
+ "hidden_size": 3072,
+ "initializer_range": 0.02,
+ "intermediate_size": 8192,
+ "max_position_embeddings": 131072,
+ "model_type": "phi3",
+ "num_attention_heads": 32,
+ "num_hidden_layers": 32,
+ "num_key_value_heads": 32,
+ "original_max_position_embeddings": 4096,
+ "rms_norm_eps": 1e-05,
+ "rope_scaling": {
+ "long_factor": [
+ 1.0299999713897705,
+ 1.0499999523162842,
+ 1.0499999523162842,
+ 1.0799999237060547,
+ 1.2299998998641968,
+ 1.2299998998641968,
+ 1.2999999523162842,
+ 1.4499999284744263,
+ 1.5999999046325684,
+ 1.6499998569488525,
+ 1.8999998569488525,
+ 2.859999895095825,
+ 3.68999981880188,
+ 5.419999599456787,
+ 5.489999771118164,
+ 5.489999771118164,
+ 9.09000015258789,
+ 11.579999923706055,
+ 15.65999984741211,
+ 15.769999504089355,
+ 15.789999961853027,
+ 18.360000610351562,
+ 21.989999771118164,
+ 23.079999923706055,
+ 30.009998321533203,
+ 32.35000228881836,
+ 32.590003967285156,
+ 35.56000518798828,
+ 39.95000457763672,
+ 53.840003967285156,
+ 56.20000457763672,
+ 57.95000457763672,
+ 59.29000473022461,
+ 59.77000427246094,
+ 59.920005798339844,
+ 61.190006256103516,
+ 61.96000671386719,
+ 62.50000762939453,
+ 63.3700065612793,
+ 63.48000717163086,
+ 63.48000717163086,
+ 63.66000747680664,
+ 63.850006103515625,
+ 64.08000946044922,
+ 64.760009765625,
+ 64.80001068115234,
+ 64.81001281738281,
+ 64.81001281738281
+ ],
+ "short_factor": [
+ 1.05,
+ 1.05,
+ 1.05,
+ 1.1,
+ 1.1,
+ 1.1,
+ 1.2500000000000002,
+ 1.2500000000000002,
+ 1.4000000000000004,
+ 1.4500000000000004,
+ 1.5500000000000005,
+ 1.8500000000000008,
+ 1.9000000000000008,
+ 2.000000000000001,
+ 2.000000000000001,
+ 2.000000000000001,
+ 2.000000000000001,
+ 2.000000000000001,
+ 2.000000000000001,
+ 2.000000000000001,
+ 2.000000000000001,
+ 2.000000000000001,
+ 2.000000000000001,
+ 2.000000000000001,
+ 2.000000000000001,
+ 2.000000000000001,
+ 2.000000000000001,
+ 2.000000000000001,
+ 2.000000000000001,
+ 2.000000000000001,
+ 2.000000000000001,
+ 2.000000000000001,
+ 2.1000000000000005,
+ 2.1000000000000005,
+ 2.2,
+ 2.3499999999999996,
+ 2.3499999999999996,
+ 2.3499999999999996,
+ 2.3499999999999996,
+ 2.3999999999999995,
+ 2.3999999999999995,
+ 2.6499999999999986,
+ 2.6999999999999984,
+ 2.8999999999999977,
+ 2.9499999999999975,
+ 3.049999999999997,
+ 3.049999999999997,
+ 3.049999999999997
+ ],
+ "type": "su"
+ },
+ "rope_theta": 10000.0,
+ "sliding_window": 131072,
+ "tie_word_embeddings": False,
+ "torch_dtype": "bfloat16",
+ "transformers_version": "4.38.1",
+ "use_cache": True,
+ "vocab_size": 32064,
+ "_attn_implementation": "sdpa"
+ }
transformer = OmniGenTransformer2DModel(
transformer_config=transformer_config,
patch_size=2,
@@ -53,6 +187,7 @@ def main(args):
pos_embed_max_size=192,
)
transformer.load_state_dict(converted_state_dict, strict=True)
+ transformer.to(torch.bfloat16)
num_model_params = sum(p.numel() for p in transformer.parameters())
print(f"Total number of transformer parameters: {num_model_params}")
@@ -77,7 +212,7 @@ def main(args):
"--origin_ckpt_path", default="Shitao/OmniGen-v1", type=str, required=False, help="Path to the checkpoint to convert."
)
- parser.add_argument("--dump_path", default="OmniGen-v1-diffusers", type=str, required=True, help="Path to the output pipeline.")
+ parser.add_argument("--dump_path", default="/share/shitao/repos/OmniGen-v1-diffusers", type=str, required=False, help="Path to the output pipeline.")
args = parser.parse_args()
main(args)
diff --git a/src/diffusers/models/embeddings.py b/src/diffusers/models/embeddings.py
index 9681a84b8878..720a48f3f747 100644
--- a/src/diffusers/models/embeddings.py
+++ b/src/diffusers/models/embeddings.py
@@ -381,9 +381,9 @@ def forward(self,
height, width = latent.shape[-2:]
pos_embed = self.cropped_pos_embed(height, width)
latent = self.patch_embeddings(latent, is_input_image)
- latent = latent + pos_embed
+ patched_latents = latent + pos_embed
- return latent
+ return patched_latents
class LuminaPatchEmbed(nn.Module):
diff --git a/src/diffusers/models/transformers/transformer_omnigen.py b/src/diffusers/models/transformers/transformer_omnigen.py
index 2fa97146f1d3..031fc014e77a 100644
--- a/src/diffusers/models/transformers/transformer_omnigen.py
+++ b/src/diffusers/models/transformers/transformer_omnigen.py
@@ -125,7 +125,7 @@ def forward(
)
use_cache = False
- # kept for BC (non `Cache` `past_key_values` inputs)
+ # kept for BC (non `Cache` `past_key_values` inputs)
return_legacy_cache = False
if use_cache and not isinstance(past_key_values, Cache):
return_legacy_cache = True
@@ -240,7 +240,7 @@ class OmniGenTransformer2DModel(ModelMixin, ConfigMixin, PeftAdapterMixin):
@register_to_config
def __init__(
self,
- transformer_config: Phi3Config,
+ transformer_config: Dict,
patch_size=2,
in_channels=4,
pos_embed_max_size: int = 192,
@@ -251,6 +251,7 @@ def __init__(
self.patch_size = patch_size
self.pos_embed_max_size = pos_embed_max_size
+ transformer_config = Phi3Config(**transformer_config)
hidden_size = transformer_config.hidden_size
self.patch_embedding = OmniGenPatchEmbed(patch_size=patch_size,
@@ -386,7 +387,7 @@ def forward(self,
"Passing `scale` via `attention_kwargs` when not using the PEFT backend is ineffective."
)
- height, width = hidden_states.size(-2)
+ height, width = hidden_states.size()[-2:]
hidden_states = self.patch_embedding(hidden_states, is_input_image=False)
num_tokens_for_output_image = hidden_states.size(1)
@@ -405,7 +406,7 @@ def forward(self,
image_embedding = output[:, -num_tokens_for_output_image:]
time_emb = self.t_embedder(timestep, dtype=hidden_states.dtype)
- x = self.final_layer(image_embedding, time_emb)
+ x = self.proj_out(self.norm_out(image_embedding, temb=time_emb))
output = self.unpatchify(x, height, width)
if not return_dict:
diff --git a/src/diffusers/pipelines/__init__.py b/src/diffusers/pipelines/__init__.py
index 5143b1114fd3..0b390444f5f2 100644
--- a/src/diffusers/pipelines/__init__.py
+++ b/src/diffusers/pipelines/__init__.py
@@ -254,6 +254,7 @@
)
_import_structure["mochi"] = ["MochiPipeline"]
_import_structure["musicldm"] = ["MusicLDMPipeline"]
+ _import_structure["omnigen"] = ["OmniGenPipeline"]
_import_structure["paint_by_example"] = ["PaintByExamplePipeline"]
_import_structure["pia"] = ["PIAPipeline"]
_import_structure["pixart_alpha"] = ["PixArtAlphaPipeline", "PixArtSigmaPipeline"]
@@ -584,6 +585,7 @@
)
from .mochi import MochiPipeline
from .musicldm import MusicLDMPipeline
+ from .omnigen import OmniGenPipeline
from .pag import (
AnimateDiffPAGPipeline,
HunyuanDiTPAGPipeline,
diff --git a/src/diffusers/pipelines/omnigen/kvcache_omnigen.py b/src/diffusers/pipelines/omnigen/kvcache_omnigen.py
index 0270292c130f..7f02588ce405 100644
--- a/src/diffusers/pipelines/omnigen/kvcache_omnigen.py
+++ b/src/diffusers/pipelines/omnigen/kvcache_omnigen.py
@@ -1,15 +1,20 @@
+from tqdm import tqdm
from typing import Optional, Dict, Any, Tuple, List
+import gc
import torch
-from transformers.cache_utils import DynamicCache
+from transformers.cache_utils import Cache, DynamicCache, OffloadedCache
+
class OmniGenCache(DynamicCache):
- def __init__(self,
- num_tokens_for_img: int, offload_kv_cache: bool = False) -> None:
+ def __init__(self,
+ num_tokens_for_img: int,
+ offload_kv_cache: bool=False) -> None:
if not torch.cuda.is_available():
- raise RuntimeError(
- "OmniGenCache can only be used with a GPU. If there is no GPU, you need to set use_kv_cache=False, which will result in longer inference time!")
+ # print("No avaliable GPU, offload_kv_cache wiil be set to False, which will result in large memory usage and time cost when input multiple images!!!")
+ # offload_kv_cache = False
+ raise RuntimeError("OffloadedCache can only be used with a GPU. If there is no GPU, you need to set use_kv_cache=False, which will result in longer inference time!")
super().__init__()
self.original_device = []
self.prefetch_stream = torch.cuda.Stream()
@@ -25,17 +30,19 @@ def prefetch_layer(self, layer_idx: int):
self.key_cache[layer_idx] = self.key_cache[layer_idx].to(device, non_blocking=True)
self.value_cache[layer_idx] = self.value_cache[layer_idx].to(device, non_blocking=True)
+
def evict_previous_layer(self, layer_idx: int):
"Moves the previous layer cache to the CPU"
if len(self) > 2:
# We do it on the default stream so it occurs after all earlier computations on these tensors are done
- if layer_idx == 0:
+ if layer_idx == 0:
prev_layer_idx = -1
else:
prev_layer_idx = (layer_idx - 1) % len(self)
self.key_cache[prev_layer_idx] = self.key_cache[prev_layer_idx].to("cpu", non_blocking=True)
self.value_cache[prev_layer_idx] = self.value_cache[prev_layer_idx].to("cpu", non_blocking=True)
+
def __getitem__(self, layer_idx: int) -> List[Tuple[torch.Tensor]]:
"Gets the cache for this layer to the device. Prefetches the next and evicts the previous layer."
if layer_idx < len(self):
@@ -44,12 +51,12 @@ def __getitem__(self, layer_idx: int) -> List[Tuple[torch.Tensor]]:
torch.cuda.current_stream().synchronize()
self.evict_previous_layer(layer_idx)
# Load current layer cache to its original device if not already there
- # original_device = self.original_device[layer_idx]
+ original_device = self.original_device[layer_idx]
# self.prefetch_stream.synchronize(original_device)
- self.prefetch_stream.synchronize()
+ torch.cuda.synchronize(self.prefetch_stream)
key_tensor = self.key_cache[layer_idx]
value_tensor = self.value_cache[layer_idx]
-
+
# Prefetch the next layer
self.prefetch_layer((layer_idx + 1) % len(self))
else:
@@ -58,13 +65,13 @@ def __getitem__(self, layer_idx: int) -> List[Tuple[torch.Tensor]]:
return (key_tensor, value_tensor)
else:
raise KeyError(f"Cache only has {len(self)} layers, attempted to access layer with index {layer_idx}")
-
+
def update(
- self,
- key_states: torch.Tensor,
- value_states: torch.Tensor,
- layer_idx: int,
- cache_kwargs: Optional[Dict[str, Any]] = None,
+ self,
+ key_states: torch.Tensor,
+ value_states: torch.Tensor,
+ layer_idx: int,
+ cache_kwargs: Optional[Dict[str, Any]] = None,
) -> Tuple[torch.Tensor, torch.Tensor]:
"""
Updates the cache with the new `key_states` and `value_states` for the layer `layer_idx`.
@@ -85,13 +92,13 @@ def update(
raise ValueError("OffloadedCache does not support model usage where layers are skipped. Use DynamicCache.")
elif len(self.key_cache) == layer_idx:
# only cache the states for condition tokens
- key_states = key_states[..., :-(self.num_tokens_for_img + 1), :]
- value_states = value_states[..., :-(self.num_tokens_for_img + 1), :]
+ key_states = key_states[..., :-(self.num_tokens_for_img+1), :]
+ value_states = value_states[..., :-(self.num_tokens_for_img+1), :]
- # Update the number of seen tokens
+ # Update the number of seen tokens
if layer_idx == 0:
self._seen_tokens += key_states.shape[-2]
-
+
self.key_cache.append(key_states)
self.value_cache.append(value_states)
self.original_device.append(key_states.device)
diff --git a/src/diffusers/pipelines/omnigen/pipeline_omnigen.py b/src/diffusers/pipelines/omnigen/pipeline_omnigen.py
index 9c9a67e53e9c..500a440f4203 100644
--- a/src/diffusers/pipelines/omnigen/pipeline_omnigen.py
+++ b/src/diffusers/pipelines/omnigen/pipeline_omnigen.py
@@ -15,6 +15,7 @@
import inspect
from typing import Any, Callable, Dict, List, Optional, Union
+import numpy as np
import torch
from transformers import LlamaTokenizer
@@ -179,6 +180,7 @@ def encod_input_iamges(
self,
input_pixel_values: List[torch.Tensor],
device: Optional[torch.device] = None,
+ dtype: Optional[torch.dtype] = None,
):
"""
get the continues embedding of input images by VAE
@@ -188,7 +190,7 @@ def encod_input_iamges(
Returns: torch.Tensor
"""
device = device or self._execution_device
- dtype = self.vae.dtype
+ dtype = dtype or self.vae.dtype
input_img_latents = []
for img in input_pixel_values:
@@ -215,13 +217,15 @@ def get_multimodal_embeddings(self,
"""
device = device or self._execution_device
+ input_img_latents = [x.to(self.transformer.dtype) for x in input_img_latents]
+
condition_tokens = None
if input_ids is not None:
- condition_tokens = self.transformer.llm.embed_tokens(input_ids.to(device))
+ condition_tokens = self.transformer.llm.embed_tokens(input_ids.to(device)).clone()
input_img_inx = 0
if input_img_latents is not None:
input_image_tokens = self.transformer.patch_embedding(input_img_latents,
- is_input_images=True)
+ is_input_image=True)
for b_inx in input_image_sizes.keys():
for start_inx, end_inx in input_image_sizes[b_inx]:
@@ -248,10 +252,11 @@ def check_inputs(
f"The number of prompts: {len(prompt)} does not match the number of input images: {len(input_images)}."
)
for i in range(len(input_images)):
- if not all(f"<|image_{k}|>" in prompt[i] for k in range(len(input_images[i]))):
- raise ValueError(
- f"prompt `{prompt[i]}` doesn't have enough placeholders for the input images `{input_images[i]}`"
- )
+ if input_images[i] is not None:
+ if not all(f"<|image_{k+1}|>" in prompt[i] for k in range(len(input_images[i]))):
+ raise ValueError(
+ f"prompt `{prompt[i]}` doesn't have enough placeholders for the input images `{input_images[i]}`"
+ )
if height % (self.vae_scale_factor * 2) != 0 or width % (self.vae_scale_factor * 2) != 0:
logger.warning(
@@ -279,30 +284,6 @@ def _prepare_latent_image_ids(batch_size, height, width, device, dtype):
return latent_image_ids.to(device=device, dtype=dtype)
- @staticmethod
- def _pack_latents(latents, batch_size, num_channels_latents, height, width):
- latents = latents.view(batch_size, num_channels_latents, height // 2, 2, width // 2, 2)
- latents = latents.permute(0, 2, 4, 1, 3, 5)
- latents = latents.reshape(batch_size, (height // 2) * (width // 2), num_channels_latents * 4)
-
- return latents
-
- @staticmethod
- def _unpack_latents(latents, height, width, vae_scale_factor):
- batch_size, num_patches, channels = latents.shape
-
- # VAE applies 8x compression on images but we must also account for packing which requires
- # latent height and width to be divisible by 2.
- height = 2 * (int(height) // (vae_scale_factor * 2))
- width = 2 * (int(width) // (vae_scale_factor * 2))
-
- latents = latents.view(batch_size, height // 2, width // 2, channels // 4, 2, 2)
- latents = latents.permute(0, 3, 1, 4, 2, 5)
-
- latents = latents.reshape(batch_size, channels // (2 * 2), height, width)
-
- return latents
-
def enable_vae_slicing(self):
r"""
Enable sliced VAE decoding. When this option is enabled, the VAE will split the input tensor in slices to
@@ -343,16 +324,15 @@ def prepare_latents(
generator,
latents=None,
):
- # VAE applies 8x compression on images but we must also account for packing which requires
- # latent height and width to be divisible by 2.
- height = 2 * (int(height) // (self.vae_scale_factor * 2))
- width = 2 * (int(width) // (self.vae_scale_factor * 2))
-
- shape = (batch_size, num_channels_latents, height, width)
-
if latents is not None:
- latent_image_ids = self._prepare_latent_image_ids(batch_size, height // 2, width // 2, device, dtype)
- return latents.to(device=device, dtype=dtype), latent_image_ids
+ return latents.to(device=device, dtype=dtype)
+
+ shape = (
+ batch_size,
+ num_channels_latents,
+ int(height) // self.vae_scale_factor,
+ int(width) // self.vae_scale_factor,
+ )
if isinstance(generator, list) and len(generator) != batch_size:
raise ValueError(
@@ -361,11 +341,8 @@ def prepare_latents(
)
latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
- latents = self._pack_latents(latents, batch_size, num_channels_latents, height, width)
-
- latent_image_ids = self._prepare_latent_image_ids(batch_size, height // 2, width // 2, device, dtype)
- return latents, latent_image_ids
+ return latents
@property
def guidance_scale(self):
@@ -482,9 +459,14 @@ def __call__(
height = height or self.default_sample_size * self.vae_scale_factor
width = width or self.default_sample_size * self.vae_scale_factor
+ num_cfg = 2 if input_images is not None else 1
+ use_img_cfg = True if input_images is not None else False
if isinstance(prompt, str):
prompt = [prompt]
input_images = [input_images]
+
+ # using Float32 for the VAE doesn't take up much memory but can prevent potential black image outputs.
+ self.vae.to(torch.float32)
# 1. Check inputs. Raise error if not correct
self.check_inputs(
@@ -506,12 +488,15 @@ def __call__(
# 3. process multi-modal instructions
if max_input_image_size != self.multimodal_processor.max_image_size:
- self.processor = OmniGenMultiModalProcessor(self.text_tokenizer, max_image_size=max_input_image_size)
- processed_data = self.processor(prompt,
- input_images,
- height=height,
- width=width,
- use_input_image_size_as_output=use_input_image_size_as_output)
+ self.multimodal_processor = OmniGenMultiModalProcessor(self.text_tokenizer, max_image_size=max_input_image_size)
+ processed_data = self.multimodal_processor(prompt,
+ input_images,
+ height=height,
+ width=width,
+ use_img_cfg=use_img_cfg,
+ use_input_image_size_as_output=use_input_image_size_as_output)
+ processed_data['attention_mask'] = processed_data['attention_mask'].to(device)
+ processed_data['position_ids'] = processed_data['position_ids'].to(device)
# 4. Encode input images and obtain multi-modal conditional embeddings
input_img_latents = self.encod_input_iamges(processed_data['input_pixel_values'], device=device)
@@ -522,14 +507,14 @@ def __call__(
)
# 5. Prepare timesteps
+ sigmas = np.linspace(1, 0, num_inference_steps+1)[:num_inference_steps]
timesteps, num_inference_steps = retrieve_timesteps(
- self.scheduler, num_inference_steps, device, timesteps,
+ self.scheduler, num_inference_steps, device, timesteps, sigmas=sigmas
)
# 6. Prepare latents.
if use_input_image_size_as_output:
height, width = processed_data['input_pixel_values'][0].shape[-2:]
- num_cfg = 2 if input_images is not None else 1
latent_channels = self.transformer.config.in_channels
latents = self.prepare_latents(
batch_size * num_images_per_prompt,
@@ -542,10 +527,15 @@ def __call__(
latents,
)
+
+ generator = torch.Generator(device=device).manual_seed(0)
+ latents = torch.randn(1, 4, height//8, width//8, device=device, generator=generator).to(self.transformer.dtype)
+ # latents = torch.cat([latents]*(1+num_cfg), 0).to(dtype)
+
# 7. Prepare OmniGenCache
- num_tokens_for_output_img = latents.size(-1) * latents.size(-2) // (self.patch_size ** 2)
+ num_tokens_for_output_img = latents.size(-1) * latents.size(-2) // (self.transformer.patch_size ** 2)
cache = OmniGenCache(num_tokens_for_output_img, offload_kv_cache) if use_kv_cache else None
- self.transformer.llm.use_cache = use_kv_cache
+ self.transformer.llm.config.use_cache = use_kv_cache
# 8. Denoising loop
with self.progress_bar(total=num_inference_steps) as progress_bar:
@@ -565,18 +555,18 @@ def __call__(
attention_kwargs=attention_kwargs,
past_key_values=cache,
offload_transformer_block=offload_transformer_block,
- return_past_key_values=True,
return_dict=False,
)
-
+
+ # if use kv cache, don't need attention mask and position ids of condition tokens for next step
if use_kv_cache:
if condition_tokens is not None:
condition_tokens = None
- processed_data['attention_mask'] = processed_data['attention_mask'][..., -(num_tokens_for_output_img+1):, :]
+ processed_data['attention_mask'] = processed_data['attention_mask'][..., -(num_tokens_for_output_img + 1):, :] # +1 is for the timestep token
processed_data['position_ids'] = processed_data['position_ids'][:, -(num_tokens_for_output_img + 1):]
if num_cfg == 2:
- cond, uncond, img_cond = torch.split(noise_pred, len(model_out) // 3, dim=0)
+ cond, uncond, img_cond = torch.split(noise_pred, len(noise_pred) // 3, dim=0)
noise_pred = uncond + img_guidance_scale * (img_cond - uncond) + guidance_scale * (cond - img_cond)
else:
cond, uncond = torch.split(noise_pred, len(noise_pred) // 2, dim=0)
@@ -584,7 +574,6 @@ def __call__(
# compute the previous noisy sample x_t -> x_t-1
latents_dtype = latents.dtype
- noise_pred = -noise_pred # OmniGen uses standard rectified flow instead of denoise, different from FLUX and SD3
latents = self.scheduler.step(noise_pred, t, latents, return_dict=False)[0]
if latents.dtype != latents_dtype:
@@ -595,6 +584,7 @@ def __call__(
progress_bar.update()
if not output_type == "latent":
+ latents = latents.to(self.vae.dtype)
latents = latents / self.vae.config.scaling_factor
image = self.vae.decode(latents, return_dict=False)[0]
image = self.image_processor.postprocess(image, output_type=output_type)
diff --git a/src/diffusers/pipelines/omnigen/processor_omnigen.py b/src/diffusers/pipelines/omnigen/processor_omnigen.py
index e6a36bcd8df7..b425d84269df 100644
--- a/src/diffusers/pipelines/omnigen/processor_omnigen.py
+++ b/src/diffusers/pipelines/omnigen/processor_omnigen.py
@@ -117,8 +117,6 @@ def __call__(self,
use_input_image_size_as_output: bool = False,
) -> Dict:
- if input_images is None:
- use_img_cfg = False
if isinstance(instructions, str):
instructions = [instructions]
input_images = [input_images]
diff --git a/src/diffusers/schedulers/scheduling_flow_match_euler_discrete.py b/src/diffusers/schedulers/scheduling_flow_match_euler_discrete.py
index c1096dbe0c29..7e01160b8626 100644
--- a/src/diffusers/schedulers/scheduling_flow_match_euler_discrete.py
+++ b/src/diffusers/schedulers/scheduling_flow_match_euler_discrete.py
@@ -212,7 +212,7 @@ def set_timesteps(
else:
sigmas = torch.cat([sigmas, torch.zeros(1, device=sigmas.device)])
- self.timesteps = timesteps.to(device=device)
+ self.timesteps = timesteps.to(device=device)
self.sigmas = sigmas
self._step_index = None
self._begin_index = None
@@ -300,7 +300,7 @@ def step(
sigma = self.sigmas[self.step_index]
sigma_next = self.sigmas[self.step_index + 1]
-
+
prev_sample = sample + (sigma_next - sigma) * model_output
# Cast sample back to model compatible dtype
diff --git a/test.py b/test.py
index 3648f7e9875e..fe1410122acd 100644
--- a/test.py
+++ b/test.py
@@ -1,50 +1,54 @@
-import os
-os.environ['HF_HUB_CACHE'] = '/share/shitao/downloaded_models2'
-
-from huggingface_hub import snapshot_download
-
-from diffusers.models import OmniGenTransformer2DModel
-from transformers import Phi3Model, Phi3Config
-
-
-from safetensors.torch import load_file
-
-model_name = "Shitao/OmniGen-v1"
-config = Phi3Config.from_pretrained("Shitao/OmniGen-v1")
-model = OmniGenTransformer2DModel(transformer_config=config)
-cache_folder = os.getenv('HF_HUB_CACHE')
-model_name = snapshot_download(repo_id=model_name,
- cache_dir=cache_folder,
- ignore_patterns=['flax_model.msgpack', 'rust_model.ot', 'tf_model.h5'])
-print(model_name)
-model_path = os.path.join(model_name, 'model.safetensors')
-ckpt = load_file(model_path, 'cpu')
-
-
-mapping_dict = {
- "pos_embed": "patch_embedding.pos_embed",
- "x_embedder.proj.weight": "patch_embedding.output_image_proj.weight",
- "x_embedder.proj.bias": "patch_embedding.output_image_proj.bias",
- "input_x_embedder.proj.weight": "patch_embedding.input_image_proj.weight",
- "input_x_embedder.proj.bias": "patch_embedding.input_image_proj.bias",
- "final_layer.adaLN_modulation.1.weight": "norm_out.linear.weight",
- "final_layer.adaLN_modulation.1.bias": "norm_out.linear.bias",
- "final_layer.linear.weight": "proj_out.weight",
- "final_layer.linear.bias": "proj_out.bias",
-
-}
-
-new_ckpt = {}
-for k, v in ckpt.items():
- # new_ckpt[k] = v
- if k in mapping_dict:
- new_ckpt[mapping_dict[k]] = v
- else:
- new_ckpt[k] = v
+# import os
+# os.environ['HF_HUB_CACHE'] = '/share/shitao/downloaded_models2'
+
+# from huggingface_hub import snapshot_download
+
+# from diffusers.models import OmniGenTransformer2DModel
+# from transformers import Phi3Model, Phi3Config
+
+
+# from safetensors.torch import load_file
+
+# model_name = "Shitao/OmniGen-v1"
+# config = Phi3Config.from_pretrained("Shitao/OmniGen-v1")
+# model = OmniGenTransformer2DModel(transformer_config=config)
+# cache_folder = os.getenv('HF_HUB_CACHE')
+# model_name = snapshot_download(repo_id=model_name,
+# cache_dir=cache_folder,
+# ignore_patterns=['flax_model.msgpack', 'rust_model.ot', 'tf_model.h5'])
+# print(model_name)
+# model_path = os.path.join(model_name, 'model.safetensors')
+# ckpt = load_file(model_path, 'cpu')
+
+
+# mapping_dict = {
+# "pos_embed": "patch_embedding.pos_embed",
+# "x_embedder.proj.weight": "patch_embedding.output_image_proj.weight",
+# "x_embedder.proj.bias": "patch_embedding.output_image_proj.bias",
+# "input_x_embedder.proj.weight": "patch_embedding.input_image_proj.weight",
+# "input_x_embedder.proj.bias": "patch_embedding.input_image_proj.bias",
+# "final_layer.adaLN_modulation.1.weight": "norm_out.linear.weight",
+# "final_layer.adaLN_modulation.1.bias": "norm_out.linear.bias",
+# "final_layer.linear.weight": "proj_out.weight",
+# "final_layer.linear.bias": "proj_out.bias",
+
+# }
+
+# new_ckpt = {}
+# for k, v in ckpt.items():
+# # new_ckpt[k] = v
+# if k in mapping_dict:
+# new_ckpt[mapping_dict[k]] = v
+# else:
+# new_ckpt[k] = v
-model.load_state_dict(new_ckpt)
+# model.load_state_dict(new_ckpt)
+
+
+
+
diff --git a/tests/pipelines/omnigen/__init__.py b/tests/pipelines/omnigen/__init__.py
new file mode 100644
index 000000000000..e69de29bb2d1
diff --git a/tests/pipelines/omnigen/test_pipeline_flux.py b/tests/pipelines/omnigen/test_pipeline_flux.py
new file mode 100644
index 000000000000..df9021ee0adb
--- /dev/null
+++ b/tests/pipelines/omnigen/test_pipeline_flux.py
@@ -0,0 +1,298 @@
+import gc
+import unittest
+
+import numpy as np
+import pytest
+import torch
+from huggingface_hub import hf_hub_download
+from transformers import AutoTokenizer, CLIPTextConfig, CLIPTextModel, CLIPTokenizer, T5EncoderModel
+
+from diffusers import AutoencoderKL, FlowMatchEulerDiscreteScheduler, FluxPipeline, FluxTransformer2DModel
+from diffusers.utils.testing_utils import (
+ numpy_cosine_similarity_distance,
+ require_big_gpu_with_torch_cuda,
+ slow,
+ torch_device,
+)
+
+from ..test_pipelines_common import (
+ PipelineTesterMixin,
+ check_qkv_fusion_matches_attn_procs_length,
+ check_qkv_fusion_processors_exist,
+)
+
+
+class FluxPipelineFastTests(unittest.TestCase, PipelineTesterMixin):
+ pipeline_class = FluxPipeline
+ params = frozenset(["prompt", "height", "width", "guidance_scale", "prompt_embeds", "pooled_prompt_embeds"])
+ batch_params = frozenset(["prompt"])
+
+ # there is no xformers processor for Flux
+ test_xformers_attention = False
+
+ def get_dummy_components(self):
+ torch.manual_seed(0)
+ transformer = FluxTransformer2DModel(
+ patch_size=1,
+ in_channels=4,
+ num_layers=1,
+ num_single_layers=1,
+ attention_head_dim=16,
+ num_attention_heads=2,
+ joint_attention_dim=32,
+ pooled_projection_dim=32,
+ axes_dims_rope=[4, 4, 8],
+ )
+ clip_text_encoder_config = CLIPTextConfig(
+ bos_token_id=0,
+ eos_token_id=2,
+ hidden_size=32,
+ intermediate_size=37,
+ layer_norm_eps=1e-05,
+ num_attention_heads=4,
+ num_hidden_layers=5,
+ pad_token_id=1,
+ vocab_size=1000,
+ hidden_act="gelu",
+ projection_dim=32,
+ )
+
+ torch.manual_seed(0)
+ text_encoder = CLIPTextModel(clip_text_encoder_config)
+
+ torch.manual_seed(0)
+ text_encoder_2 = T5EncoderModel.from_pretrained("hf-internal-testing/tiny-random-t5")
+
+ tokenizer = CLIPTokenizer.from_pretrained("hf-internal-testing/tiny-random-clip")
+ tokenizer_2 = AutoTokenizer.from_pretrained("hf-internal-testing/tiny-random-t5")
+
+ torch.manual_seed(0)
+ vae = AutoencoderKL(
+ sample_size=32,
+ in_channels=3,
+ out_channels=3,
+ block_out_channels=(4,),
+ layers_per_block=1,
+ latent_channels=1,
+ norm_num_groups=1,
+ use_quant_conv=False,
+ use_post_quant_conv=False,
+ shift_factor=0.0609,
+ scaling_factor=1.5035,
+ )
+
+ scheduler = FlowMatchEulerDiscreteScheduler()
+
+ return {
+ "scheduler": scheduler,
+ "text_encoder": text_encoder,
+ "text_encoder_2": text_encoder_2,
+ "tokenizer": tokenizer,
+ "tokenizer_2": tokenizer_2,
+ "transformer": transformer,
+ "vae": vae,
+ }
+
+ def get_dummy_inputs(self, device, seed=0):
+ if str(device).startswith("mps"):
+ generator = torch.manual_seed(seed)
+ else:
+ generator = torch.Generator(device="cpu").manual_seed(seed)
+
+ inputs = {
+ "prompt": "A painting of a squirrel eating a burger",
+ "generator": generator,
+ "num_inference_steps": 2,
+ "guidance_scale": 5.0,
+ "height": 8,
+ "width": 8,
+ "max_sequence_length": 48,
+ "output_type": "np",
+ }
+ return inputs
+
+ def test_flux_different_prompts(self):
+ pipe = self.pipeline_class(**self.get_dummy_components()).to(torch_device)
+
+ inputs = self.get_dummy_inputs(torch_device)
+ output_same_prompt = pipe(**inputs).images[0]
+
+ inputs = self.get_dummy_inputs(torch_device)
+ inputs["prompt_2"] = "a different prompt"
+ output_different_prompts = pipe(**inputs).images[0]
+
+ max_diff = np.abs(output_same_prompt - output_different_prompts).max()
+
+ # Outputs should be different here
+ # For some reasons, they don't show large differences
+ assert max_diff > 1e-6
+
+ def test_flux_prompt_embeds(self):
+ pipe = self.pipeline_class(**self.get_dummy_components()).to(torch_device)
+ inputs = self.get_dummy_inputs(torch_device)
+
+ output_with_prompt = pipe(**inputs).images[0]
+
+ inputs = self.get_dummy_inputs(torch_device)
+ prompt = inputs.pop("prompt")
+
+ (prompt_embeds, pooled_prompt_embeds, text_ids) = pipe.encode_prompt(
+ prompt,
+ prompt_2=None,
+ device=torch_device,
+ max_sequence_length=inputs["max_sequence_length"],
+ )
+ output_with_embeds = pipe(
+ prompt_embeds=prompt_embeds,
+ pooled_prompt_embeds=pooled_prompt_embeds,
+ **inputs,
+ ).images[0]
+
+ max_diff = np.abs(output_with_prompt - output_with_embeds).max()
+ assert max_diff < 1e-4
+
+ def test_fused_qkv_projections(self):
+ device = "cpu" # ensure determinism for the device-dependent torch.Generator
+ components = self.get_dummy_components()
+ pipe = self.pipeline_class(**components)
+ pipe = pipe.to(device)
+ pipe.set_progress_bar_config(disable=None)
+
+ inputs = self.get_dummy_inputs(device)
+ image = pipe(**inputs).images
+ original_image_slice = image[0, -3:, -3:, -1]
+
+ # TODO (sayakpaul): will refactor this once `fuse_qkv_projections()` has been added
+ # to the pipeline level.
+ pipe.transformer.fuse_qkv_projections()
+ assert check_qkv_fusion_processors_exist(
+ pipe.transformer
+ ), "Something wrong with the fused attention processors. Expected all the attention processors to be fused."
+ assert check_qkv_fusion_matches_attn_procs_length(
+ pipe.transformer, pipe.transformer.original_attn_processors
+ ), "Something wrong with the attention processors concerning the fused QKV projections."
+
+ inputs = self.get_dummy_inputs(device)
+ image = pipe(**inputs).images
+ image_slice_fused = image[0, -3:, -3:, -1]
+
+ pipe.transformer.unfuse_qkv_projections()
+ inputs = self.get_dummy_inputs(device)
+ image = pipe(**inputs).images
+ image_slice_disabled = image[0, -3:, -3:, -1]
+
+ assert np.allclose(
+ original_image_slice, image_slice_fused, atol=1e-3, rtol=1e-3
+ ), "Fusion of QKV projections shouldn't affect the outputs."
+ assert np.allclose(
+ image_slice_fused, image_slice_disabled, atol=1e-3, rtol=1e-3
+ ), "Outputs, with QKV projection fusion enabled, shouldn't change when fused QKV projections are disabled."
+ assert np.allclose(
+ original_image_slice, image_slice_disabled, atol=1e-2, rtol=1e-2
+ ), "Original outputs should match when fused QKV projections are disabled."
+
+ def test_flux_image_output_shape(self):
+ pipe = self.pipeline_class(**self.get_dummy_components()).to(torch_device)
+ inputs = self.get_dummy_inputs(torch_device)
+
+ height_width_pairs = [(32, 32), (72, 57)]
+ for height, width in height_width_pairs:
+ expected_height = height - height % (pipe.vae_scale_factor * 2)
+ expected_width = width - width % (pipe.vae_scale_factor * 2)
+
+ inputs.update({"height": height, "width": width})
+ image = pipe(**inputs).images[0]
+ output_height, output_width, _ = image.shape
+ assert (output_height, output_width) == (expected_height, expected_width)
+
+
+@slow
+@require_big_gpu_with_torch_cuda
+@pytest.mark.big_gpu_with_torch_cuda
+class FluxPipelineSlowTests(unittest.TestCase):
+ pipeline_class = FluxPipeline
+ repo_id = "black-forest-labs/FLUX.1-schnell"
+
+ def setUp(self):
+ super().setUp()
+ gc.collect()
+ torch.cuda.empty_cache()
+
+ def tearDown(self):
+ super().tearDown()
+ gc.collect()
+ torch.cuda.empty_cache()
+
+ def get_inputs(self, device, seed=0):
+ if str(device).startswith("mps"):
+ generator = torch.manual_seed(seed)
+ else:
+ generator = torch.Generator(device="cpu").manual_seed(seed)
+
+ prompt_embeds = torch.load(
+ hf_hub_download(repo_id="diffusers/test-slices", repo_type="dataset", filename="flux/prompt_embeds.pt")
+ )
+ pooled_prompt_embeds = torch.load(
+ hf_hub_download(
+ repo_id="diffusers/test-slices", repo_type="dataset", filename="flux/pooled_prompt_embeds.pt"
+ )
+ )
+ return {
+ "prompt_embeds": prompt_embeds,
+ "pooled_prompt_embeds": pooled_prompt_embeds,
+ "num_inference_steps": 2,
+ "guidance_scale": 0.0,
+ "max_sequence_length": 256,
+ "output_type": "np",
+ "generator": generator,
+ }
+
+ def test_flux_inference(self):
+ pipe = self.pipeline_class.from_pretrained(
+ self.repo_id, torch_dtype=torch.bfloat16, text_encoder=None, text_encoder_2=None
+ )
+ pipe.enable_model_cpu_offload()
+
+ inputs = self.get_inputs(torch_device)
+
+ image = pipe(**inputs).images[0]
+ image_slice = image[0, :10, :10]
+ expected_slice = np.array(
+ [
+ 0.3242,
+ 0.3203,
+ 0.3164,
+ 0.3164,
+ 0.3125,
+ 0.3125,
+ 0.3281,
+ 0.3242,
+ 0.3203,
+ 0.3301,
+ 0.3262,
+ 0.3242,
+ 0.3281,
+ 0.3242,
+ 0.3203,
+ 0.3262,
+ 0.3262,
+ 0.3164,
+ 0.3262,
+ 0.3281,
+ 0.3184,
+ 0.3281,
+ 0.3281,
+ 0.3203,
+ 0.3281,
+ 0.3281,
+ 0.3164,
+ 0.3320,
+ 0.3320,
+ 0.3203,
+ ],
+ dtype=np.float32,
+ )
+
+ max_diff = numpy_cosine_similarity_distance(expected_slice.flatten(), image_slice.flatten())
+
+ assert max_diff < 1e-4
From db92c6996d63417da5d8dcd4c97a5c16c7167f50 Mon Sep 17 00:00:00 2001
From: shitao <2906698981@qq.com>
Date: Tue, 3 Dec 2024 21:37:16 +0800
Subject: [PATCH 06/55] test case for omnigen
---
docs/source/en/api/pipelines/omnigen.md | 100 ++++++
.../pipelines/omnigen/pipeline_omnigen.py | 10 +-
.../pipelines/omnigen/processor_omnigen.py | 3 +-
tests/pipelines/omnigen/test_pipeline_flux.py | 298 ------------------
.../omnigen/test_pipeline_omnigen.py | 269 ++++++++++++++++
5 files changed, 374 insertions(+), 306 deletions(-)
create mode 100644 docs/source/en/api/pipelines/omnigen.md
delete mode 100644 tests/pipelines/omnigen/test_pipeline_flux.py
create mode 100644 tests/pipelines/omnigen/test_pipeline_omnigen.py
diff --git a/docs/source/en/api/pipelines/omnigen.md b/docs/source/en/api/pipelines/omnigen.md
new file mode 100644
index 000000000000..fe0bc5d496a0
--- /dev/null
+++ b/docs/source/en/api/pipelines/omnigen.md
@@ -0,0 +1,100 @@
+
+
+# OmniGen
+
+[OmniGen: Unified Image Generation](https://arxiv.org/pdf/2409.11340) from BAAI, by Shitao Xiao, Yueze Wang, Junjie Zhou, Huaying Yuan, Xingrun Xing, Ruiran Yan, Chaofan Li, Shuting Wang, Tiejun Huang, Zheng Liu.
+
+The abstract from the paper is:
+
+*The emergence of Large Language Models (LLMs) has unified language
+generation tasks and revolutionized human-machine interaction.
+However, in the realm of image generation, a unified model capable of handling various tasks
+within a single framework remains largely unexplored. In
+this work, we introduce OmniGen, a new diffusion model
+for unified image generation. OmniGen is characterized
+by the following features: 1) Unification: OmniGen not
+only demonstrates text-to-image generation capabilities but
+also inherently supports various downstream tasks, such
+as image editing, subject-driven generation, and visual conditional generation. 2) Simplicity: The architecture of
+OmniGen is highly simplified, eliminating the need for additional plugins. Moreover, compared to existing diffusion
+models, it is more user-friendly and can complete complex
+tasks end-to-end through instructions without the need for
+extra intermediate steps, greatly simplifying the image generation workflow. 3) Knowledge Transfer: Benefit from
+learning in a unified format, OmniGen effectively transfers
+knowledge across different tasks, manages unseen tasks and
+domains, and exhibits novel capabilities. We also explore
+the model’s reasoning capabilities and potential applications of the chain-of-thought mechanism. This work represents the first attempt at a general-purpose image generation model, and we will release our resources at https:
+//github.com/VectorSpaceLab/OmniGen to foster future advancements.*
+
+
+
+Make sure to check out the Schedulers [guide](../../using-diffusers/schedulers.md) to learn how to explore the tradeoff between scheduler speed and quality, and see the [reuse components across pipelines](../../using-diffusers/loading.md#reuse-a-pipeline) section to learn how to efficiently load the same components into multiple pipelines.
+
+
+
+This pipeline was contributed by [staoxiao](https://github.com/staoxiao). The original codebase can be found [here](https://github.com/VectorSpaceLab/OmniGen). The original weights can be found under [hf.co/shitao](https://huggingface.co/Shitao/OmniGen-v1).
+
+
+## Inference
+
+Use [`torch.compile`](https://huggingface.co/docs/diffusers/main/en/tutorials/fast_diffusion#torchcompile) to reduce the inference latency.
+
+First, load the pipeline:
+
+```python
+import torch
+from diffusers import CogVideoXPipeline, CogVideoXImageToVideoPipeline
+from diffusers.utils import export_to_video,load_image
+pipe = CogVideoXPipeline.from_pretrained("THUDM/CogVideoX-5b").to("cuda") # or "THUDM/CogVideoX-2b"
+```
+
+If you are using the image-to-video pipeline, load it as follows:
+
+```python
+pipe = CogVideoXImageToVideoPipeline.from_pretrained("THUDM/CogVideoX-5b-I2V").to("cuda")
+```
+
+Then change the memory layout of the pipelines `transformer` component to `torch.channels_last`:
+
+```python
+pipe.transformer.to(memory_format=torch.channels_last)
+```
+
+Compile the components and run inference:
+
+```python
+pipe.transformer = torch.compile(pipeline.transformer, mode="max-autotune", fullgraph=True)
+
+# CogVideoX works well with long and well-described prompts
+prompt = "A panda, dressed in a small, red jacket and a tiny hat, sits on a wooden stool in a serene bamboo forest. The panda's fluffy paws strum a miniature acoustic guitar, producing soft, melodic tunes. Nearby, a few other pandas gather, watching curiously and some clapping in rhythm. Sunlight filters through the tall bamboo, casting a gentle glow on the scene. The panda's face is expressive, showing concentration and joy as it plays. The background includes a small, flowing stream and vibrant green foliage, enhancing the peaceful and magical atmosphere of this unique musical performance."
+video = pipe(prompt=prompt, guidance_scale=6, num_inference_steps=50).frames[0]
+```
+
+The [T2V benchmark](https://gist.github.com/a-r-r-o-w/5183d75e452a368fd17448fcc810bd3f) results on an 80GB A100 machine are:
+
+```
+Without torch.compile(): Average inference time: 96.89 seconds.
+With torch.compile(): Average inference time: 76.27 seconds.
+```
+
+
+## CogVideoXPipeline
+
+[[autodoc]] CogVideoXPipeline
+ - all
+ - __call__
+
+
diff --git a/src/diffusers/pipelines/omnigen/pipeline_omnigen.py b/src/diffusers/pipelines/omnigen/pipeline_omnigen.py
index 500a440f4203..65156d2935dc 100644
--- a/src/diffusers/pipelines/omnigen/pipeline_omnigen.py
+++ b/src/diffusers/pipelines/omnigen/pipeline_omnigen.py
@@ -15,6 +15,7 @@
import inspect
from typing import Any, Callable, Dict, List, Optional, Union
+import PIL
import numpy as np
import torch
from transformers import LlamaTokenizer
@@ -146,7 +147,7 @@ class OmniGenPipeline(
model_cpu_offload_seq = "transformer->vae"
_optional_components = []
- _callback_tensor_inputs = ["latents", "prompt_embeds"]
+ _callback_tensor_inputs = ["latents", "condition_tokens"]
def __init__(
self,
@@ -361,7 +362,7 @@ def interrupt(self):
def __call__(
self,
prompt: Union[str, List[str]],
- input_images: Optional[Union[List[str], List[List[str]]]] = None,
+ input_images: Optional[Union[List[str], List[PIL.Image.Image], List[List[str]], List[List[PIL.Image.Image]]]] = None,
height: Optional[int] = None,
width: Optional[int] = None,
num_inference_steps: int = 50,
@@ -527,11 +528,6 @@ def __call__(
latents,
)
-
- generator = torch.Generator(device=device).manual_seed(0)
- latents = torch.randn(1, 4, height//8, width//8, device=device, generator=generator).to(self.transformer.dtype)
- # latents = torch.cat([latents]*(1+num_cfg), 0).to(dtype)
-
# 7. Prepare OmniGenCache
num_tokens_for_output_img = latents.size(-1) * latents.size(-2) // (self.transformer.patch_size ** 2)
cache = OmniGenCache(num_tokens_for_output_img, offload_kv_cache) if use_kv_cache else None
diff --git a/src/diffusers/pipelines/omnigen/processor_omnigen.py b/src/diffusers/pipelines/omnigen/processor_omnigen.py
index b425d84269df..545c9b001c7d 100644
--- a/src/diffusers/pipelines/omnigen/processor_omnigen.py
+++ b/src/diffusers/pipelines/omnigen/processor_omnigen.py
@@ -57,7 +57,8 @@ def __init__(self,
self.collator = OmniGenCollator()
def process_image(self, image):
- image = Image.open(image).convert('RGB')
+ if isinstance(image, str):
+ image = Image.open(image).convert('RGB')
return self.image_transform(image)
def process_multi_modal_prompt(self, text, input_images):
diff --git a/tests/pipelines/omnigen/test_pipeline_flux.py b/tests/pipelines/omnigen/test_pipeline_flux.py
deleted file mode 100644
index df9021ee0adb..000000000000
--- a/tests/pipelines/omnigen/test_pipeline_flux.py
+++ /dev/null
@@ -1,298 +0,0 @@
-import gc
-import unittest
-
-import numpy as np
-import pytest
-import torch
-from huggingface_hub import hf_hub_download
-from transformers import AutoTokenizer, CLIPTextConfig, CLIPTextModel, CLIPTokenizer, T5EncoderModel
-
-from diffusers import AutoencoderKL, FlowMatchEulerDiscreteScheduler, FluxPipeline, FluxTransformer2DModel
-from diffusers.utils.testing_utils import (
- numpy_cosine_similarity_distance,
- require_big_gpu_with_torch_cuda,
- slow,
- torch_device,
-)
-
-from ..test_pipelines_common import (
- PipelineTesterMixin,
- check_qkv_fusion_matches_attn_procs_length,
- check_qkv_fusion_processors_exist,
-)
-
-
-class FluxPipelineFastTests(unittest.TestCase, PipelineTesterMixin):
- pipeline_class = FluxPipeline
- params = frozenset(["prompt", "height", "width", "guidance_scale", "prompt_embeds", "pooled_prompt_embeds"])
- batch_params = frozenset(["prompt"])
-
- # there is no xformers processor for Flux
- test_xformers_attention = False
-
- def get_dummy_components(self):
- torch.manual_seed(0)
- transformer = FluxTransformer2DModel(
- patch_size=1,
- in_channels=4,
- num_layers=1,
- num_single_layers=1,
- attention_head_dim=16,
- num_attention_heads=2,
- joint_attention_dim=32,
- pooled_projection_dim=32,
- axes_dims_rope=[4, 4, 8],
- )
- clip_text_encoder_config = CLIPTextConfig(
- bos_token_id=0,
- eos_token_id=2,
- hidden_size=32,
- intermediate_size=37,
- layer_norm_eps=1e-05,
- num_attention_heads=4,
- num_hidden_layers=5,
- pad_token_id=1,
- vocab_size=1000,
- hidden_act="gelu",
- projection_dim=32,
- )
-
- torch.manual_seed(0)
- text_encoder = CLIPTextModel(clip_text_encoder_config)
-
- torch.manual_seed(0)
- text_encoder_2 = T5EncoderModel.from_pretrained("hf-internal-testing/tiny-random-t5")
-
- tokenizer = CLIPTokenizer.from_pretrained("hf-internal-testing/tiny-random-clip")
- tokenizer_2 = AutoTokenizer.from_pretrained("hf-internal-testing/tiny-random-t5")
-
- torch.manual_seed(0)
- vae = AutoencoderKL(
- sample_size=32,
- in_channels=3,
- out_channels=3,
- block_out_channels=(4,),
- layers_per_block=1,
- latent_channels=1,
- norm_num_groups=1,
- use_quant_conv=False,
- use_post_quant_conv=False,
- shift_factor=0.0609,
- scaling_factor=1.5035,
- )
-
- scheduler = FlowMatchEulerDiscreteScheduler()
-
- return {
- "scheduler": scheduler,
- "text_encoder": text_encoder,
- "text_encoder_2": text_encoder_2,
- "tokenizer": tokenizer,
- "tokenizer_2": tokenizer_2,
- "transformer": transformer,
- "vae": vae,
- }
-
- def get_dummy_inputs(self, device, seed=0):
- if str(device).startswith("mps"):
- generator = torch.manual_seed(seed)
- else:
- generator = torch.Generator(device="cpu").manual_seed(seed)
-
- inputs = {
- "prompt": "A painting of a squirrel eating a burger",
- "generator": generator,
- "num_inference_steps": 2,
- "guidance_scale": 5.0,
- "height": 8,
- "width": 8,
- "max_sequence_length": 48,
- "output_type": "np",
- }
- return inputs
-
- def test_flux_different_prompts(self):
- pipe = self.pipeline_class(**self.get_dummy_components()).to(torch_device)
-
- inputs = self.get_dummy_inputs(torch_device)
- output_same_prompt = pipe(**inputs).images[0]
-
- inputs = self.get_dummy_inputs(torch_device)
- inputs["prompt_2"] = "a different prompt"
- output_different_prompts = pipe(**inputs).images[0]
-
- max_diff = np.abs(output_same_prompt - output_different_prompts).max()
-
- # Outputs should be different here
- # For some reasons, they don't show large differences
- assert max_diff > 1e-6
-
- def test_flux_prompt_embeds(self):
- pipe = self.pipeline_class(**self.get_dummy_components()).to(torch_device)
- inputs = self.get_dummy_inputs(torch_device)
-
- output_with_prompt = pipe(**inputs).images[0]
-
- inputs = self.get_dummy_inputs(torch_device)
- prompt = inputs.pop("prompt")
-
- (prompt_embeds, pooled_prompt_embeds, text_ids) = pipe.encode_prompt(
- prompt,
- prompt_2=None,
- device=torch_device,
- max_sequence_length=inputs["max_sequence_length"],
- )
- output_with_embeds = pipe(
- prompt_embeds=prompt_embeds,
- pooled_prompt_embeds=pooled_prompt_embeds,
- **inputs,
- ).images[0]
-
- max_diff = np.abs(output_with_prompt - output_with_embeds).max()
- assert max_diff < 1e-4
-
- def test_fused_qkv_projections(self):
- device = "cpu" # ensure determinism for the device-dependent torch.Generator
- components = self.get_dummy_components()
- pipe = self.pipeline_class(**components)
- pipe = pipe.to(device)
- pipe.set_progress_bar_config(disable=None)
-
- inputs = self.get_dummy_inputs(device)
- image = pipe(**inputs).images
- original_image_slice = image[0, -3:, -3:, -1]
-
- # TODO (sayakpaul): will refactor this once `fuse_qkv_projections()` has been added
- # to the pipeline level.
- pipe.transformer.fuse_qkv_projections()
- assert check_qkv_fusion_processors_exist(
- pipe.transformer
- ), "Something wrong with the fused attention processors. Expected all the attention processors to be fused."
- assert check_qkv_fusion_matches_attn_procs_length(
- pipe.transformer, pipe.transformer.original_attn_processors
- ), "Something wrong with the attention processors concerning the fused QKV projections."
-
- inputs = self.get_dummy_inputs(device)
- image = pipe(**inputs).images
- image_slice_fused = image[0, -3:, -3:, -1]
-
- pipe.transformer.unfuse_qkv_projections()
- inputs = self.get_dummy_inputs(device)
- image = pipe(**inputs).images
- image_slice_disabled = image[0, -3:, -3:, -1]
-
- assert np.allclose(
- original_image_slice, image_slice_fused, atol=1e-3, rtol=1e-3
- ), "Fusion of QKV projections shouldn't affect the outputs."
- assert np.allclose(
- image_slice_fused, image_slice_disabled, atol=1e-3, rtol=1e-3
- ), "Outputs, with QKV projection fusion enabled, shouldn't change when fused QKV projections are disabled."
- assert np.allclose(
- original_image_slice, image_slice_disabled, atol=1e-2, rtol=1e-2
- ), "Original outputs should match when fused QKV projections are disabled."
-
- def test_flux_image_output_shape(self):
- pipe = self.pipeline_class(**self.get_dummy_components()).to(torch_device)
- inputs = self.get_dummy_inputs(torch_device)
-
- height_width_pairs = [(32, 32), (72, 57)]
- for height, width in height_width_pairs:
- expected_height = height - height % (pipe.vae_scale_factor * 2)
- expected_width = width - width % (pipe.vae_scale_factor * 2)
-
- inputs.update({"height": height, "width": width})
- image = pipe(**inputs).images[0]
- output_height, output_width, _ = image.shape
- assert (output_height, output_width) == (expected_height, expected_width)
-
-
-@slow
-@require_big_gpu_with_torch_cuda
-@pytest.mark.big_gpu_with_torch_cuda
-class FluxPipelineSlowTests(unittest.TestCase):
- pipeline_class = FluxPipeline
- repo_id = "black-forest-labs/FLUX.1-schnell"
-
- def setUp(self):
- super().setUp()
- gc.collect()
- torch.cuda.empty_cache()
-
- def tearDown(self):
- super().tearDown()
- gc.collect()
- torch.cuda.empty_cache()
-
- def get_inputs(self, device, seed=0):
- if str(device).startswith("mps"):
- generator = torch.manual_seed(seed)
- else:
- generator = torch.Generator(device="cpu").manual_seed(seed)
-
- prompt_embeds = torch.load(
- hf_hub_download(repo_id="diffusers/test-slices", repo_type="dataset", filename="flux/prompt_embeds.pt")
- )
- pooled_prompt_embeds = torch.load(
- hf_hub_download(
- repo_id="diffusers/test-slices", repo_type="dataset", filename="flux/pooled_prompt_embeds.pt"
- )
- )
- return {
- "prompt_embeds": prompt_embeds,
- "pooled_prompt_embeds": pooled_prompt_embeds,
- "num_inference_steps": 2,
- "guidance_scale": 0.0,
- "max_sequence_length": 256,
- "output_type": "np",
- "generator": generator,
- }
-
- def test_flux_inference(self):
- pipe = self.pipeline_class.from_pretrained(
- self.repo_id, torch_dtype=torch.bfloat16, text_encoder=None, text_encoder_2=None
- )
- pipe.enable_model_cpu_offload()
-
- inputs = self.get_inputs(torch_device)
-
- image = pipe(**inputs).images[0]
- image_slice = image[0, :10, :10]
- expected_slice = np.array(
- [
- 0.3242,
- 0.3203,
- 0.3164,
- 0.3164,
- 0.3125,
- 0.3125,
- 0.3281,
- 0.3242,
- 0.3203,
- 0.3301,
- 0.3262,
- 0.3242,
- 0.3281,
- 0.3242,
- 0.3203,
- 0.3262,
- 0.3262,
- 0.3164,
- 0.3262,
- 0.3281,
- 0.3184,
- 0.3281,
- 0.3281,
- 0.3203,
- 0.3281,
- 0.3281,
- 0.3164,
- 0.3320,
- 0.3320,
- 0.3203,
- ],
- dtype=np.float32,
- )
-
- max_diff = numpy_cosine_similarity_distance(expected_slice.flatten(), image_slice.flatten())
-
- assert max_diff < 1e-4
diff --git a/tests/pipelines/omnigen/test_pipeline_omnigen.py b/tests/pipelines/omnigen/test_pipeline_omnigen.py
new file mode 100644
index 000000000000..73124576ee04
--- /dev/null
+++ b/tests/pipelines/omnigen/test_pipeline_omnigen.py
@@ -0,0 +1,269 @@
+import gc
+import unittest
+
+import numpy as np
+import torch
+from transformers import AutoTokenizer, GemmaConfig, GemmaForCausalLM
+
+from diffusers import AutoencoderKL, FlowMatchEulerDiscreteScheduler, OmniGenTransformer2DModel, OmniGenPipeline
+from diffusers.utils.testing_utils import (
+ numpy_cosine_similarity_distance,
+ require_torch_gpu,
+ slow,
+ torch_device,
+)
+
+from ..test_pipelines_common import PipelineTesterMixin
+
+
+class OmniGenPipelineFastTests(unittest.TestCase, PipelineTesterMixin):
+ pipeline_class = OmniGenPipeline
+ params = frozenset(
+ [
+ "prompt",
+ "guidance_scale",
+ ]
+ )
+ batch_params = frozenset(["prompt", ])
+
+ def get_dummy_components(self):
+ torch.manual_seed(0)
+
+ transformer_config = {
+ "_name_or_path": "Phi-3-vision-128k-instruct",
+ "architectures": [
+ "Phi3ForCausalLM"
+ ],
+ "attention_dropout": 0.0,
+ "bos_token_id": 1,
+ "eos_token_id": 2,
+ "hidden_act": "silu",
+ "hidden_size": 3072,
+ "initializer_range": 0.02,
+ "intermediate_size": 8192,
+ "max_position_embeddings": 131072,
+ "model_type": "phi3",
+ "num_attention_heads": 32,
+ "num_hidden_layers": 32,
+ "num_key_value_heads": 32,
+ "original_max_position_embeddings": 4096,
+ "rms_norm_eps": 1e-05,
+ "rope_scaling": {
+ "long_factor": [
+ 1.0299999713897705,
+ 1.0499999523162842,
+ 1.0499999523162842,
+ 1.0799999237060547,
+ 1.2299998998641968,
+ 1.2299998998641968,
+ 1.2999999523162842,
+ 1.4499999284744263,
+ 1.5999999046325684,
+ 1.6499998569488525,
+ 1.8999998569488525,
+ 2.859999895095825,
+ 3.68999981880188,
+ 5.419999599456787,
+ 5.489999771118164,
+ 5.489999771118164,
+ 9.09000015258789,
+ 11.579999923706055,
+ 15.65999984741211,
+ 15.769999504089355,
+ 15.789999961853027,
+ 18.360000610351562,
+ 21.989999771118164,
+ 23.079999923706055,
+ 30.009998321533203,
+ 32.35000228881836,
+ 32.590003967285156,
+ 35.56000518798828,
+ 39.95000457763672,
+ 53.840003967285156,
+ 56.20000457763672,
+ 57.95000457763672,
+ 59.29000473022461,
+ 59.77000427246094,
+ 59.920005798339844,
+ 61.190006256103516,
+ 61.96000671386719,
+ 62.50000762939453,
+ 63.3700065612793,
+ 63.48000717163086,
+ 63.48000717163086,
+ 63.66000747680664,
+ 63.850006103515625,
+ 64.08000946044922,
+ 64.760009765625,
+ 64.80001068115234,
+ 64.81001281738281,
+ 64.81001281738281
+ ],
+ "short_factor": [
+ 1.05,
+ 1.05,
+ 1.05,
+ 1.1,
+ 1.1,
+ 1.1,
+ 1.2500000000000002,
+ 1.2500000000000002,
+ 1.4000000000000004,
+ 1.4500000000000004,
+ 1.5500000000000005,
+ 1.8500000000000008,
+ 1.9000000000000008,
+ 2.000000000000001,
+ 2.000000000000001,
+ 2.000000000000001,
+ 2.000000000000001,
+ 2.000000000000001,
+ 2.000000000000001,
+ 2.000000000000001,
+ 2.000000000000001,
+ 2.000000000000001,
+ 2.000000000000001,
+ 2.000000000000001,
+ 2.000000000000001,
+ 2.000000000000001,
+ 2.000000000000001,
+ 2.000000000000001,
+ 2.000000000000001,
+ 2.000000000000001,
+ 2.000000000000001,
+ 2.000000000000001,
+ 2.1000000000000005,
+ 2.1000000000000005,
+ 2.2,
+ 2.3499999999999996,
+ 2.3499999999999996,
+ 2.3499999999999996,
+ 2.3499999999999996,
+ 2.3999999999999995,
+ 2.3999999999999995,
+ 2.6499999999999986,
+ 2.6999999999999984,
+ 2.8999999999999977,
+ 2.9499999999999975,
+ 3.049999999999997,
+ 3.049999999999997,
+ 3.049999999999997
+ ],
+ "type": "su"
+ },
+ "rope_theta": 10000.0,
+ "sliding_window": 131072,
+ "tie_word_embeddings": False,
+ "torch_dtype": "bfloat16",
+ "transformers_version": "4.38.1",
+ "use_cache": True,
+ "vocab_size": 32064,
+ "_attn_implementation": "sdpa"
+ }
+ transformer = OmniGenTransformer2DModel(
+ transformer_config=transformer_config,
+ patch_size=2,
+ in_channels=4,
+ pos_embed_max_size=192,
+ )
+
+
+ torch.manual_seed(0)
+ vae = AutoencoderKL()
+
+ scheduler = FlowMatchEulerDiscreteScheduler()
+ tokenizer = AutoTokenizer.from_pretrained("hf-internal-testing/llama-tokenizer")
+
+ components = {
+ "transformer": transformer.eval(),
+ "vae": vae.eval(),
+ "scheduler": scheduler,
+ "tokenizer": tokenizer,
+ }
+ return components
+
+ def get_dummy_inputs(self, device, seed=0):
+ if str(device).startswith("mps"):
+ generator = torch.manual_seed(seed)
+ else:
+ generator = torch.Generator(device="cpu").manual_seed(seed)
+
+ inputs = {
+ "prompt": "A painting of a squirrel eating a burger",
+ "generator": generator,
+ "num_inference_steps": 2,
+ "guidance_scale": 5.0,
+ "output_type": "np",
+ "height": 16,
+ "width": 16,
+ }
+ return inputs
+
+ def test_inference(self):
+ pipe = self.pipeline_class(**self.get_dummy_components()).to(torch_device)
+
+ inputs = self.get_dummy_inputs(torch_device)
+ generated_image = pipe(**inputs).images[0]
+
+ self.assertEqual(generated_image.shape, (1, 3, 16, 16))
+
+
+
+
+@slow
+@require_torch_gpu
+class OmniGenPipelineSlowTests(unittest.TestCase):
+ pipeline_class = OmniGenPipeline
+ repo_id = "Shitao/OmniGen-v1-diffusers"
+
+ def setUp(self):
+ super().setUp()
+ gc.collect()
+ torch.cuda.empty_cache()
+
+ def tearDown(self):
+ super().tearDown()
+ gc.collect()
+ torch.cuda.empty_cache()
+
+ def get_inputs(self, device, seed=0):
+ if str(device).startswith("mps"):
+ generator = torch.manual_seed(seed)
+ else:
+ generator = torch.Generator(device="cpu").manual_seed(seed)
+
+ return {
+ "prompt": "A photo of a cat",
+ "num_inference_steps": 2,
+ "guidance_scale": 5.0,
+ "output_type": "np",
+ "generator": generator,
+ }
+
+ def test_omnigen_inference(self):
+ pipe = self.pipeline_class.from_pretrained(self.repo_id, torch_dtype=torch.bfloat16)
+ pipe.enable_model_cpu_offload()
+
+ inputs = self.get_inputs(torch_device)
+
+ image = pipe(**inputs).images[0]
+ image_slice = image[0, :10, :10]
+ expected_slice = np.array(
+ [
+ [0.17773438, 0.18554688, 0.22070312],
+ [0.046875, 0.06640625, 0.10351562],
+ [0.0, 0.0, 0.02148438],
+ [0.0, 0.0, 0.0],
+ [0.0, 0.0, 0.0],
+ [0.0, 0.0, 0.0],
+ [0.0, 0.0, 0.0],
+ [0.0, 0.0, 0.0],
+ [0.0, 0.0, 0.0],
+ [0.0, 0.0, 0.0],
+ ],
+ dtype=np.float32,
+ )
+
+ max_diff = numpy_cosine_similarity_distance(expected_slice.flatten(), image_slice.flatten())
+
+ assert max_diff < 1e-4
From 308766cf99cdc85464fba593b97631e07f54adc2 Mon Sep 17 00:00:00 2001
From: staoxiao <2906698981@qq.com>
Date: Wed, 4 Dec 2024 15:24:03 +0800
Subject: [PATCH 07/55] update omnigenpipeline
---
.../transformers/transformer_omnigen.py | 56 +++++++++++-----
.../pipelines/omnigen/pipeline_omnigen.py | 67 ++++++-------------
test.py | 11 +--
.../omnigen/test_pipeline_omnigen.py | 43 +++++++-----
4 files changed, 93 insertions(+), 84 deletions(-)
diff --git a/src/diffusers/models/transformers/transformer_omnigen.py b/src/diffusers/models/transformers/transformer_omnigen.py
index 031fc014e77a..64790a2024e5 100644
--- a/src/diffusers/models/transformers/transformer_omnigen.py
+++ b/src/diffusers/models/transformers/transformer_omnigen.py
@@ -281,20 +281,6 @@ def unpatchify(self, x, h, w):
imgs = x.reshape(shape=(x.shape[0], c, h, w))
return imgs
- def prepare_condition_embeddings(self, input_ids, input_img_latents, input_image_sizes):
- condition_embeds = None
- if input_img_latents is not None:
- input_latents = self.patch_embedding(input_img_latents, is_input_images=True)
- if input_ids is not None:
- condition_embeds = self.llm.embed_tokens(input_ids).clone()
- input_img_inx = 0
- for b_inx in input_image_sizes.keys():
- for start_inx, end_inx in input_image_sizes[b_inx]:
- condition_embeds[b_inx, start_inx: end_inx] = input_latents[input_img_inx]
- input_img_inx += 1
- if input_img_latents is not None:
- assert input_img_inx == len(input_latents)
- return condition_embeds
@property
# Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.attn_processors
@@ -359,11 +345,46 @@ def fn_recursive_attn_processor(name: str, module: torch.nn.Module, processor):
def _set_gradient_checkpointing(self, module, value=False):
if hasattr(module, "gradient_checkpointing"):
module.gradient_checkpointing = value
+
+ def get_multimodal_embeddings(self,
+ input_ids: torch.Tensor,
+ input_img_latents: List[torch.Tensor],
+ input_image_sizes: Dict,
+ ):
+ """
+ get the multi-modal conditional embeddings
+ Args:
+ input_ids: a sequence of text id
+ input_img_latents: continues embedding of input images
+ input_image_sizes: the index of the input image in the input_ids sequence.
+
+ Returns: torch.Tensor
+
+ """
+ input_img_latents = [x.to(self.dtype) for x in input_img_latents]
+ condition_tokens = None
+ if input_ids is not None:
+ condition_tokens = self.llm.embed_tokens(input_ids)
+ input_img_inx = 0
+ if input_img_latents is not None:
+ input_image_tokens = self.patch_embedding(input_img_latents,
+ is_input_image=True)
+
+ for b_inx in input_image_sizes.keys():
+ for start_inx, end_inx in input_image_sizes[b_inx]:
+ # replace the placeholder in text tokens with the image embedding.
+ condition_tokens[b_inx, start_inx: end_inx] = input_image_tokens[input_img_inx].to(
+ condition_tokens.dtype)
+ input_img_inx += 1
+
+ return condition_tokens
def forward(self,
hidden_states: torch.Tensor,
timestep: Union[int, float, torch.LongTensor],
- condition_tokens: torch.Tensor,
+ input_ids: torch.Tensor,
+ input_img_latents: List[torch.Tensor],
+ input_image_sizes: Dict[int, List[int]],
attention_mask: torch.Tensor,
position_ids: torch.Tensor,
past_key_values: DynamicCache = None,
@@ -386,13 +407,16 @@ def forward(self,
logger.warning(
"Passing `scale` via `attention_kwargs` when not using the PEFT backend is ineffective."
)
-
height, width = hidden_states.size()[-2:]
hidden_states = self.patch_embedding(hidden_states, is_input_image=False)
num_tokens_for_output_image = hidden_states.size(1)
time_token = self.time_token(timestep, dtype=hidden_states.dtype).unsqueeze(1)
+ condition_tokens = self.get_multimodal_embeddings(input_ids=input_ids,
+ input_img_latents=input_img_latents,
+ input_image_sizes=input_image_sizes,
+ )
if condition_tokens is not None:
input_emb = torch.cat([condition_tokens, time_token, hidden_states], dim=1)
else:
diff --git a/src/diffusers/pipelines/omnigen/pipeline_omnigen.py b/src/diffusers/pipelines/omnigen/pipeline_omnigen.py
index 65156d2935dc..ff10cade50f3 100644
--- a/src/diffusers/pipelines/omnigen/pipeline_omnigen.py
+++ b/src/diffusers/pipelines/omnigen/pipeline_omnigen.py
@@ -147,7 +147,7 @@ class OmniGenPipeline(
model_cpu_offload_seq = "transformer->vae"
_optional_components = []
- _callback_tensor_inputs = ["latents", "condition_tokens"]
+ _callback_tensor_inputs = ["latents", "input_images_latents"]
def __init__(
self,
@@ -199,43 +199,6 @@ def encod_input_iamges(
input_img_latents.append(img)
return input_img_latents
- def get_multimodal_embeddings(self,
- input_ids: torch.Tensor,
- input_img_latents: List[torch.Tensor],
- input_image_sizes: Dict,
- device: Optional[torch.device] = None,
- ):
- """
- get the multi-modal conditional embeddings
- Args:
- input_ids: a sequence of text id
- input_img_latents: continues embedding of input images
- input_image_sizes: the index of the input image in the input_ids sequence.
- device:
-
- Returns: torch.Tensor
-
- """
- device = device or self._execution_device
-
- input_img_latents = [x.to(self.transformer.dtype) for x in input_img_latents]
-
- condition_tokens = None
- if input_ids is not None:
- condition_tokens = self.transformer.llm.embed_tokens(input_ids.to(device)).clone()
- input_img_inx = 0
- if input_img_latents is not None:
- input_image_tokens = self.transformer.patch_embedding(input_img_latents,
- is_input_image=True)
-
- for b_inx in input_image_sizes.keys():
- for start_inx, end_inx in input_image_sizes[b_inx]:
- # replace the placeholder in text tokens with the image embedding.
- condition_tokens[b_inx, start_inx: end_inx] = input_image_tokens[input_img_inx].to(
- condition_tokens.dtype)
- input_img_inx += 1
-
- return condition_tokens
def check_inputs(
self,
@@ -243,6 +206,8 @@ def check_inputs(
input_images,
height,
width,
+ use_kv_cache,
+ offload_kv_cache,
callback_on_step_end_tensor_inputs=None,
max_sequence_length=None,
):
@@ -263,6 +228,12 @@ def check_inputs(
logger.warning(
f"`height` and `width` have to be divisible by {self.vae_scale_factor * 2} but are {height} and {width}. Dimensions will be resized accordingly"
)
+
+ if use_kv_cache and offload_kv_cache:
+ if not torch.cuda.is_available():
+ raise ValueError(
+ f"Don't fine avaliable GPUs. `offload_kv_cache` can't be used when there is no GPU. please set it to False: `use_kv_cache=False, offload_kv_cache=False`"
+ )
if callback_on_step_end_tensor_inputs is not None and not all(
k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs
@@ -475,6 +446,8 @@ def __call__(
input_images,
height,
width,
+ use_kv_cache,
+ offload_kv_cache,
callback_on_step_end_tensor_inputs=callback_on_step_end_tensor_inputs,
max_sequence_length=max_sequence_length,
)
@@ -496,16 +469,12 @@ def __call__(
width=width,
use_img_cfg=use_img_cfg,
use_input_image_size_as_output=use_input_image_size_as_output)
+ processed_data['input_ids'] = processed_data['input_ids'].to(device)
processed_data['attention_mask'] = processed_data['attention_mask'].to(device)
processed_data['position_ids'] = processed_data['position_ids'].to(device)
- # 4. Encode input images and obtain multi-modal conditional embeddings
+ # 4. Encode input images
input_img_latents = self.encod_input_iamges(processed_data['input_pixel_values'], device=device)
- condition_tokens = self.get_multimodal_embeddings(input_ids=processed_data['input_ids'],
- input_img_latents=input_img_latents,
- input_image_sizes=processed_data['input_image_sizes'],
- device=device,
- )
# 5. Prepare timesteps
sigmas = np.linspace(1, 0, num_inference_steps+1)[:num_inference_steps]
@@ -522,7 +491,7 @@ def __call__(
latent_channels,
height,
width,
- condition_tokens.dtype,
+ self.transformer.dtype,
device,
generator,
latents,
@@ -545,7 +514,9 @@ def __call__(
noise_pred, cache = self.transformer(
hidden_states=latent_model_input,
timestep=timestep,
- condition_tokens=condition_tokens,
+ input_ids=processed_data['input_ids'],
+ input_img_latents=input_img_latents,
+ input_image_sizes=processed_data['input_image_sizes'],
attention_mask=processed_data['attention_mask'],
position_ids=processed_data['position_ids'],
attention_kwargs=attention_kwargs,
@@ -556,8 +527,8 @@ def __call__(
# if use kv cache, don't need attention mask and position ids of condition tokens for next step
if use_kv_cache:
- if condition_tokens is not None:
- condition_tokens = None
+ if processed_data['input_ids'] is not None:
+ processed_data['input_ids'] = None
processed_data['attention_mask'] = processed_data['attention_mask'][..., -(num_tokens_for_output_img + 1):, :] # +1 is for the timestep token
processed_data['position_ids'] = processed_data['position_ids'][:, -(num_tokens_for_output_img + 1):]
diff --git a/test.py b/test.py
index fe1410122acd..b27d99a1066b 100644
--- a/test.py
+++ b/test.py
@@ -1,5 +1,5 @@
-# import os
-# os.environ['HF_HUB_CACHE'] = '/share/shitao/downloaded_models2'
+import os
+os.environ['HF_HUB_CACHE'] = '/share/shitao/downloaded_models2'
# from huggingface_hub import snapshot_download
@@ -47,10 +47,13 @@
# model.load_state_dict(new_ckpt)
+from tests.pipelines.omnigen.test_pipeline_omnigen import OmniGenPipelineFastTests, OmniGenPipelineSlowTests
+test1 = OmniGenPipelineFastTests()
+test1.test_inference()
-
-
+test2 = OmniGenPipelineSlowTests()
+test2.test_omnigen_inference()
diff --git a/tests/pipelines/omnigen/test_pipeline_omnigen.py b/tests/pipelines/omnigen/test_pipeline_omnigen.py
index 73124576ee04..93870f9da31d 100644
--- a/tests/pipelines/omnigen/test_pipeline_omnigen.py
+++ b/tests/pipelines/omnigen/test_pipeline_omnigen.py
@@ -169,10 +169,20 @@ def get_dummy_components(self):
torch.manual_seed(0)
- vae = AutoencoderKL()
+ vae = AutoencoderKL(
+ sample_size=32,
+ in_channels=3,
+ out_channels=3,
+ block_out_channels=(4, 4, 4, 4),
+ layers_per_block=1,
+ latent_channels=4,
+ norm_num_groups=1,
+ up_block_types = ["UpDecoderBlock2D", "UpDecoderBlock2D", "UpDecoderBlock2D", "UpDecoderBlock2D"],
+ )
scheduler = FlowMatchEulerDiscreteScheduler()
tokenizer = AutoTokenizer.from_pretrained("hf-internal-testing/llama-tokenizer")
+ # tokenizer = AutoTokenizer.from_pretrained("Shitao/OmniGen-v1")
components = {
"transformer": transformer.eval(),
@@ -192,10 +202,12 @@ def get_dummy_inputs(self, device, seed=0):
"prompt": "A painting of a squirrel eating a burger",
"generator": generator,
"num_inference_steps": 2,
- "guidance_scale": 5.0,
+ "guidance_scale": 3.0,
"output_type": "np",
"height": 16,
"width": 16,
+ "use_kv_cache": False,
+ "offload_kv_cache": False,
}
return inputs
@@ -205,7 +217,7 @@ def test_inference(self):
inputs = self.get_dummy_inputs(torch_device)
generated_image = pipe(**inputs).images[0]
- self.assertEqual(generated_image.shape, (1, 3, 16, 16))
+ self.assertEqual(generated_image.shape, (16, 16, 3))
@@ -235,7 +247,7 @@ def get_inputs(self, device, seed=0):
return {
"prompt": "A photo of a cat",
"num_inference_steps": 2,
- "guidance_scale": 5.0,
+ "guidance_scale": 2.5,
"output_type": "np",
"generator": generator,
}
@@ -248,19 +260,18 @@ def test_omnigen_inference(self):
image = pipe(**inputs).images[0]
image_slice = image[0, :10, :10]
+
expected_slice = np.array(
- [
- [0.17773438, 0.18554688, 0.22070312],
- [0.046875, 0.06640625, 0.10351562],
- [0.0, 0.0, 0.02148438],
- [0.0, 0.0, 0.0],
- [0.0, 0.0, 0.0],
- [0.0, 0.0, 0.0],
- [0.0, 0.0, 0.0],
- [0.0, 0.0, 0.0],
- [0.0, 0.0, 0.0],
- [0.0, 0.0, 0.0],
- ],
+ [[0.25806782, 0.28012177, 0.27807158],
+ [0.25740036, 0.2677201, 0.26857468],
+ [0.258638, 0.27035138, 0.26633185],
+ [0.2541029, 0.2636156, 0.26373306],
+ [0.24975497, 0.2608987, 0.2617477 ],
+ [0.25102, 0.26192215, 0.262023 ],
+ [0.24452701, 0.25664824, 0.259144 ],
+ [0.2419573, 0.2574909, 0.25996095],
+ [0.23953134, 0.25292695, 0.25652167],
+ [0.23523712, 0.24710432, 0.25460982]],
dtype=np.float32,
)
From 4c5e8c572118ede4bf402648c16616deea35856e Mon Sep 17 00:00:00 2001
From: shitao <2906698981@qq.com>
Date: Thu, 5 Dec 2024 15:27:59 +0800
Subject: [PATCH 08/55] update docs
---
.../en/using-diffusers/multimodal2img.md | 605 ++++++++++++++++++
docs/source/en/using-diffusers/omnigen.md | 70 ++
2 files changed, 675 insertions(+)
create mode 100644 docs/source/en/using-diffusers/multimodal2img.md
create mode 100644 docs/source/en/using-diffusers/omnigen.md
diff --git a/docs/source/en/using-diffusers/multimodal2img.md b/docs/source/en/using-diffusers/multimodal2img.md
new file mode 100644
index 000000000000..4618731830df
--- /dev/null
+++ b/docs/source/en/using-diffusers/multimodal2img.md
@@ -0,0 +1,605 @@
+
+
+# Image-to-image
+
+[[open-in-colab]]
+
+Image-to-image is similar to [text-to-image](conditional_image_generation), but in addition to a prompt, you can also pass an initial image as a starting point for the diffusion process. The initial image is encoded to latent space and noise is added to it. Then the latent diffusion model takes a prompt and the noisy latent image, predicts the added noise, and removes the predicted noise from the initial latent image to get the new latent image. Lastly, a decoder decodes the new latent image back into an image.
+
+With 🤗 Diffusers, this is as easy as 1-2-3:
+
+1. Load a checkpoint into the [`AutoPipelineForImage2Image`] class; this pipeline automatically handles loading the correct pipeline class based on the checkpoint:
+
+```py
+import torch
+from diffusers import AutoPipelineForImage2Image
+from diffusers.utils import load_image, make_image_grid
+
+pipeline = AutoPipelineForImage2Image.from_pretrained(
+ "kandinsky-community/kandinsky-2-2-decoder", torch_dtype=torch.float16, use_safetensors=True
+)
+pipeline.enable_model_cpu_offload()
+# remove following line if xFormers is not installed or you have PyTorch 2.0 or higher installed
+pipeline.enable_xformers_memory_efficient_attention()
+```
+
+
+
+You'll notice throughout the guide, we use [`~DiffusionPipeline.enable_model_cpu_offload`] and [`~DiffusionPipeline.enable_xformers_memory_efficient_attention`], to save memory and increase inference speed. If you're using PyTorch 2.0, then you don't need to call [`~DiffusionPipeline.enable_xformers_memory_efficient_attention`] on your pipeline because it'll already be using PyTorch 2.0's native [scaled-dot product attention](../optimization/torch2.0#scaled-dot-product-attention).
+
+
+
+2. Load an image to pass to the pipeline:
+
+```py
+init_image = load_image("https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/cat.png")
+```
+
+3. Pass a prompt and image to the pipeline to generate an image:
+
+```py
+prompt = "cat wizard, gandalf, lord of the rings, detailed, fantasy, cute, adorable, Pixar, Disney, 8k"
+image = pipeline(prompt, image=init_image).images[0]
+make_image_grid([init_image, image], rows=1, cols=2)
+```
+
+
+
+
+ initial image
+
+
+
+ generated image
+
+
+
+## Popular models
+
+The most popular image-to-image models are [Stable Diffusion v1.5](https://huggingface.co/stable-diffusion-v1-5/stable-diffusion-v1-5), [Stable Diffusion XL (SDXL)](https://huggingface.co/stabilityai/stable-diffusion-xl-base-1.0), and [Kandinsky 2.2](https://huggingface.co/kandinsky-community/kandinsky-2-2-decoder). The results from the Stable Diffusion and Kandinsky models vary due to their architecture differences and training process; you can generally expect SDXL to produce higher quality images than Stable Diffusion v1.5. Let's take a quick look at how to use each of these models and compare their results.
+
+### Stable Diffusion v1.5
+
+Stable Diffusion v1.5 is a latent diffusion model initialized from an earlier checkpoint, and further finetuned for 595K steps on 512x512 images. To use this pipeline for image-to-image, you'll need to prepare an initial image to pass to the pipeline. Then you can pass a prompt and the image to the pipeline to generate a new image:
+
+```py
+import torch
+from diffusers import AutoPipelineForImage2Image
+from diffusers.utils import make_image_grid, load_image
+
+pipeline = AutoPipelineForImage2Image.from_pretrained(
+ "stable-diffusion-v1-5/stable-diffusion-v1-5", torch_dtype=torch.float16, variant="fp16", use_safetensors=True
+)
+pipeline.enable_model_cpu_offload()
+# remove following line if xFormers is not installed or you have PyTorch 2.0 or higher installed
+pipeline.enable_xformers_memory_efficient_attention()
+
+# prepare image
+url = "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/img2img-init.png"
+init_image = load_image(url)
+
+prompt = "Astronaut in a jungle, cold color palette, muted colors, detailed, 8k"
+
+# pass prompt and image to pipeline
+image = pipeline(prompt, image=init_image).images[0]
+make_image_grid([init_image, image], rows=1, cols=2)
+```
+
+
+
+
+ initial image
+
+
+
+ generated image
+
+
+
+### Stable Diffusion XL (SDXL)
+
+SDXL is a more powerful version of the Stable Diffusion model. It uses a larger base model, and an additional refiner model to increase the quality of the base model's output. Read the [SDXL](sdxl) guide for a more detailed walkthrough of how to use this model, and other techniques it uses to produce high quality images.
+
+```py
+import torch
+from diffusers import AutoPipelineForImage2Image
+from diffusers.utils import make_image_grid, load_image
+
+pipeline = AutoPipelineForImage2Image.from_pretrained(
+ "stabilityai/stable-diffusion-xl-refiner-1.0", torch_dtype=torch.float16, variant="fp16", use_safetensors=True
+)
+pipeline.enable_model_cpu_offload()
+# remove following line if xFormers is not installed or you have PyTorch 2.0 or higher installed
+pipeline.enable_xformers_memory_efficient_attention()
+
+# prepare image
+url = "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/img2img-sdxl-init.png"
+init_image = load_image(url)
+
+prompt = "Astronaut in a jungle, cold color palette, muted colors, detailed, 8k"
+
+# pass prompt and image to pipeline
+image = pipeline(prompt, image=init_image, strength=0.5).images[0]
+make_image_grid([init_image, image], rows=1, cols=2)
+```
+
+
+
+
+ initial image
+
+
+
+ generated image
+
+
+
+### Kandinsky 2.2
+
+The Kandinsky model is different from the Stable Diffusion models because it uses an image prior model to create image embeddings. The embeddings help create a better alignment between text and images, allowing the latent diffusion model to generate better images.
+
+The simplest way to use Kandinsky 2.2 is:
+
+```py
+import torch
+from diffusers import AutoPipelineForImage2Image
+from diffusers.utils import make_image_grid, load_image
+
+pipeline = AutoPipelineForImage2Image.from_pretrained(
+ "kandinsky-community/kandinsky-2-2-decoder", torch_dtype=torch.float16, use_safetensors=True
+)
+pipeline.enable_model_cpu_offload()
+# remove following line if xFormers is not installed or you have PyTorch 2.0 or higher installed
+pipeline.enable_xformers_memory_efficient_attention()
+
+# prepare image
+url = "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/img2img-init.png"
+init_image = load_image(url)
+
+prompt = "Astronaut in a jungle, cold color palette, muted colors, detailed, 8k"
+
+# pass prompt and image to pipeline
+image = pipeline(prompt, image=init_image).images[0]
+make_image_grid([init_image, image], rows=1, cols=2)
+```
+
+
+
+
+ initial image
+
+
+
+ generated image
+
+
+
+## Configure pipeline parameters
+
+There are several important parameters you can configure in the pipeline that'll affect the image generation process and image quality. Let's take a closer look at what these parameters do and how changing them affects the output.
+
+### Strength
+
+`strength` is one of the most important parameters to consider and it'll have a huge impact on your generated image. It determines how much the generated image resembles the initial image. In other words:
+
+- 📈 a higher `strength` value gives the model more "creativity" to generate an image that's different from the initial image; a `strength` value of 1.0 means the initial image is more or less ignored
+- 📉 a lower `strength` value means the generated image is more similar to the initial image
+
+The `strength` and `num_inference_steps` parameters are related because `strength` determines the number of noise steps to add. For example, if the `num_inference_steps` is 50 and `strength` is 0.8, then this means adding 40 (50 * 0.8) steps of noise to the initial image and then denoising for 40 steps to get the newly generated image.
+
+```py
+import torch
+from diffusers import AutoPipelineForImage2Image
+from diffusers.utils import make_image_grid, load_image
+
+pipeline = AutoPipelineForImage2Image.from_pretrained(
+ "stable-diffusion-v1-5/stable-diffusion-v1-5", torch_dtype=torch.float16, variant="fp16", use_safetensors=True
+)
+pipeline.enable_model_cpu_offload()
+# remove following line if xFormers is not installed or you have PyTorch 2.0 or higher installed
+pipeline.enable_xformers_memory_efficient_attention()
+
+# prepare image
+url = "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/img2img-init.png"
+init_image = load_image(url)
+
+prompt = "Astronaut in a jungle, cold color palette, muted colors, detailed, 8k"
+
+# pass prompt and image to pipeline
+image = pipeline(prompt, image=init_image, strength=0.8).images[0]
+make_image_grid([init_image, image], rows=1, cols=2)
+```
+
+
+
+
+ strength = 0.4
+
+
+
+ strength = 0.6
+
+
+
+ strength = 1.0
+
+
+
+### Guidance scale
+
+The `guidance_scale` parameter is used to control how closely aligned the generated image and text prompt are. A higher `guidance_scale` value means your generated image is more aligned with the prompt, while a lower `guidance_scale` value means your generated image has more space to deviate from the prompt.
+
+You can combine `guidance_scale` with `strength` for even more precise control over how expressive the model is. For example, combine a high `strength + guidance_scale` for maximum creativity or use a combination of low `strength` and low `guidance_scale` to generate an image that resembles the initial image but is not as strictly bound to the prompt.
+
+```py
+import torch
+from diffusers import AutoPipelineForImage2Image
+from diffusers.utils import make_image_grid, load_image
+
+pipeline = AutoPipelineForImage2Image.from_pretrained(
+ "stable-diffusion-v1-5/stable-diffusion-v1-5", torch_dtype=torch.float16, variant="fp16", use_safetensors=True
+)
+pipeline.enable_model_cpu_offload()
+# remove following line if xFormers is not installed or you have PyTorch 2.0 or higher installed
+pipeline.enable_xformers_memory_efficient_attention()
+
+# prepare image
+url = "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/img2img-init.png"
+init_image = load_image(url)
+
+prompt = "Astronaut in a jungle, cold color palette, muted colors, detailed, 8k"
+
+# pass prompt and image to pipeline
+image = pipeline(prompt, image=init_image, guidance_scale=8.0).images[0]
+make_image_grid([init_image, image], rows=1, cols=2)
+```
+
+
+
+
+ guidance_scale = 0.1
+
+
+
+ guidance_scale = 5.0
+
+
+
+ guidance_scale = 10.0
+
+
+
+### Negative prompt
+
+A negative prompt conditions the model to *not* include things in an image, and it can be used to improve image quality or modify an image. For example, you can improve image quality by including negative prompts like "poor details" or "blurry" to encourage the model to generate a higher quality image. Or you can modify an image by specifying things to exclude from an image.
+
+```py
+import torch
+from diffusers import AutoPipelineForImage2Image
+from diffusers.utils import make_image_grid, load_image
+
+pipeline = AutoPipelineForImage2Image.from_pretrained(
+ "stabilityai/stable-diffusion-xl-refiner-1.0", torch_dtype=torch.float16, variant="fp16", use_safetensors=True
+)
+pipeline.enable_model_cpu_offload()
+# remove following line if xFormers is not installed or you have PyTorch 2.0 or higher installed
+pipeline.enable_xformers_memory_efficient_attention()
+
+# prepare image
+url = "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/img2img-init.png"
+init_image = load_image(url)
+
+prompt = "Astronaut in a jungle, cold color palette, muted colors, detailed, 8k"
+negative_prompt = "ugly, deformed, disfigured, poor details, bad anatomy"
+
+# pass prompt and image to pipeline
+image = pipeline(prompt, negative_prompt=negative_prompt, image=init_image).images[0]
+make_image_grid([init_image, image], rows=1, cols=2)
+```
+
+
+
+## Chained image-to-image pipelines
+
+There are some other interesting ways you can use an image-to-image pipeline aside from just generating an image (although that is pretty cool too). You can take it a step further and chain it with other pipelines.
+
+### Text-to-image-to-image
+
+Chaining a text-to-image and image-to-image pipeline allows you to generate an image from text and use the generated image as the initial image for the image-to-image pipeline. This is useful if you want to generate an image entirely from scratch. For example, let's chain a Stable Diffusion and a Kandinsky model.
+
+Start by generating an image with the text-to-image pipeline:
+
+```py
+from diffusers import AutoPipelineForText2Image, AutoPipelineForImage2Image
+import torch
+from diffusers.utils import make_image_grid
+
+pipeline = AutoPipelineForText2Image.from_pretrained(
+ "stable-diffusion-v1-5/stable-diffusion-v1-5", torch_dtype=torch.float16, variant="fp16", use_safetensors=True
+)
+pipeline.enable_model_cpu_offload()
+# remove following line if xFormers is not installed or you have PyTorch 2.0 or higher installed
+pipeline.enable_xformers_memory_efficient_attention()
+
+text2image = pipeline("Astronaut in a jungle, cold color palette, muted colors, detailed, 8k").images[0]
+text2image
+```
+
+Now you can pass this generated image to the image-to-image pipeline:
+
+```py
+pipeline = AutoPipelineForImage2Image.from_pretrained(
+ "kandinsky-community/kandinsky-2-2-decoder", torch_dtype=torch.float16, use_safetensors=True
+)
+pipeline.enable_model_cpu_offload()
+# remove following line if xFormers is not installed or you have PyTorch 2.0 or higher installed
+pipeline.enable_xformers_memory_efficient_attention()
+
+image2image = pipeline("Astronaut in a jungle, cold color palette, muted colors, detailed, 8k", image=text2image).images[0]
+make_image_grid([text2image, image2image], rows=1, cols=2)
+```
+
+### Image-to-image-to-image
+
+You can also chain multiple image-to-image pipelines together to create more interesting images. This can be useful for iteratively performing style transfer on an image, generating short GIFs, restoring color to an image, or restoring missing areas of an image.
+
+Start by generating an image:
+
+```py
+import torch
+from diffusers import AutoPipelineForImage2Image
+from diffusers.utils import make_image_grid, load_image
+
+pipeline = AutoPipelineForImage2Image.from_pretrained(
+ "stable-diffusion-v1-5/stable-diffusion-v1-5", torch_dtype=torch.float16, variant="fp16", use_safetensors=True
+)
+pipeline.enable_model_cpu_offload()
+# remove following line if xFormers is not installed or you have PyTorch 2.0 or higher installed
+pipeline.enable_xformers_memory_efficient_attention()
+
+# prepare image
+url = "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/img2img-init.png"
+init_image = load_image(url)
+
+prompt = "Astronaut in a jungle, cold color palette, muted colors, detailed, 8k"
+
+# pass prompt and image to pipeline
+image = pipeline(prompt, image=init_image, output_type="latent").images[0]
+```
+
+
+
+It is important to specify `output_type="latent"` in the pipeline to keep all the outputs in latent space to avoid an unnecessary decode-encode step. This only works if the chained pipelines are using the same VAE.
+
+
+
+Pass the latent output from this pipeline to the next pipeline to generate an image in a [comic book art style](https://huggingface.co/ogkalu/Comic-Diffusion):
+
+```py
+pipeline = AutoPipelineForImage2Image.from_pretrained(
+ "ogkalu/Comic-Diffusion", torch_dtype=torch.float16
+)
+pipeline.enable_model_cpu_offload()
+# remove following line if xFormers is not installed or you have PyTorch 2.0 or higher installed
+pipeline.enable_xformers_memory_efficient_attention()
+
+# need to include the token "charliebo artstyle" in the prompt to use this checkpoint
+image = pipeline("Astronaut in a jungle, charliebo artstyle", image=image, output_type="latent").images[0]
+```
+
+Repeat one more time to generate the final image in a [pixel art style](https://huggingface.co/kohbanye/pixel-art-style):
+
+```py
+pipeline = AutoPipelineForImage2Image.from_pretrained(
+ "kohbanye/pixel-art-style", torch_dtype=torch.float16
+)
+pipeline.enable_model_cpu_offload()
+# remove following line if xFormers is not installed or you have PyTorch 2.0 or higher installed
+pipeline.enable_xformers_memory_efficient_attention()
+
+# need to include the token "pixelartstyle" in the prompt to use this checkpoint
+image = pipeline("Astronaut in a jungle, pixelartstyle", image=image).images[0]
+make_image_grid([init_image, image], rows=1, cols=2)
+```
+
+### Image-to-upscaler-to-super-resolution
+
+Another way you can chain your image-to-image pipeline is with an upscaler and super-resolution pipeline to really increase the level of details in an image.
+
+Start with an image-to-image pipeline:
+
+```py
+import torch
+from diffusers import AutoPipelineForImage2Image
+from diffusers.utils import make_image_grid, load_image
+
+pipeline = AutoPipelineForImage2Image.from_pretrained(
+ "stable-diffusion-v1-5/stable-diffusion-v1-5", torch_dtype=torch.float16, variant="fp16", use_safetensors=True
+)
+pipeline.enable_model_cpu_offload()
+# remove following line if xFormers is not installed or you have PyTorch 2.0 or higher installed
+pipeline.enable_xformers_memory_efficient_attention()
+
+# prepare image
+url = "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/img2img-init.png"
+init_image = load_image(url)
+
+prompt = "Astronaut in a jungle, cold color palette, muted colors, detailed, 8k"
+
+# pass prompt and image to pipeline
+image_1 = pipeline(prompt, image=init_image, output_type="latent").images[0]
+```
+
+
+
+It is important to specify `output_type="latent"` in the pipeline to keep all the outputs in *latent* space to avoid an unnecessary decode-encode step. This only works if the chained pipelines are using the same VAE.
+
+
+
+Chain it to an upscaler pipeline to increase the image resolution:
+
+```py
+from diffusers import StableDiffusionLatentUpscalePipeline
+
+upscaler = StableDiffusionLatentUpscalePipeline.from_pretrained(
+ "stabilityai/sd-x2-latent-upscaler", torch_dtype=torch.float16, variant="fp16", use_safetensors=True
+)
+upscaler.enable_model_cpu_offload()
+upscaler.enable_xformers_memory_efficient_attention()
+
+image_2 = upscaler(prompt, image=image_1, output_type="latent").images[0]
+```
+
+Finally, chain it to a super-resolution pipeline to further enhance the resolution:
+
+```py
+from diffusers import StableDiffusionUpscalePipeline
+
+super_res = StableDiffusionUpscalePipeline.from_pretrained(
+ "stabilityai/stable-diffusion-x4-upscaler", torch_dtype=torch.float16, variant="fp16", use_safetensors=True
+)
+super_res.enable_model_cpu_offload()
+super_res.enable_xformers_memory_efficient_attention()
+
+image_3 = super_res(prompt, image=image_2).images[0]
+make_image_grid([init_image, image_3.resize((512, 512))], rows=1, cols=2)
+```
+
+## Control image generation
+
+Trying to generate an image that looks exactly the way you want can be difficult, which is why controlled generation techniques and models are so useful. While you can use the `negative_prompt` to partially control image generation, there are more robust methods like prompt weighting and ControlNets.
+
+### Prompt weighting
+
+Prompt weighting allows you to scale the representation of each concept in a prompt. For example, in a prompt like "Astronaut in a jungle, cold color palette, muted colors, detailed, 8k", you can choose to increase or decrease the embeddings of "astronaut" and "jungle". The [Compel](https://github.com/damian0815/compel) library provides a simple syntax for adjusting prompt weights and generating the embeddings. You can learn how to create the embeddings in the [Prompt weighting](weighted_prompts) guide.
+
+[`AutoPipelineForImage2Image`] has a `prompt_embeds` (and `negative_prompt_embeds` if you're using a negative prompt) parameter where you can pass the embeddings which replaces the `prompt` parameter.
+
+```py
+from diffusers import AutoPipelineForImage2Image
+import torch
+
+pipeline = AutoPipelineForImage2Image.from_pretrained(
+ "stable-diffusion-v1-5/stable-diffusion-v1-5", torch_dtype=torch.float16, variant="fp16", use_safetensors=True
+)
+pipeline.enable_model_cpu_offload()
+# remove following line if xFormers is not installed or you have PyTorch 2.0 or higher installed
+pipeline.enable_xformers_memory_efficient_attention()
+
+image = pipeline(prompt_embeds=prompt_embeds, # generated from Compel
+ negative_prompt_embeds=negative_prompt_embeds, # generated from Compel
+ image=init_image,
+).images[0]
+```
+
+### ControlNet
+
+ControlNets provide a more flexible and accurate way to control image generation because you can use an additional conditioning image. The conditioning image can be a canny image, depth map, image segmentation, and even scribbles! Whatever type of conditioning image you choose, the ControlNet generates an image that preserves the information in it.
+
+For example, let's condition an image with a depth map to keep the spatial information in the image.
+
+```py
+from diffusers.utils import load_image, make_image_grid
+
+# prepare image
+url = "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/img2img-init.png"
+init_image = load_image(url)
+init_image = init_image.resize((958, 960)) # resize to depth image dimensions
+depth_image = load_image("https://huggingface.co/lllyasviel/control_v11f1p_sd15_depth/resolve/main/images/control.png")
+make_image_grid([init_image, depth_image], rows=1, cols=2)
+```
+
+Load a ControlNet model conditioned on depth maps and the [`AutoPipelineForImage2Image`]:
+
+```py
+from diffusers import ControlNetModel, AutoPipelineForImage2Image
+import torch
+
+controlnet = ControlNetModel.from_pretrained("lllyasviel/control_v11f1p_sd15_depth", torch_dtype=torch.float16, variant="fp16", use_safetensors=True)
+pipeline = AutoPipelineForImage2Image.from_pretrained(
+ "stable-diffusion-v1-5/stable-diffusion-v1-5", controlnet=controlnet, torch_dtype=torch.float16, variant="fp16", use_safetensors=True
+)
+pipeline.enable_model_cpu_offload()
+# remove following line if xFormers is not installed or you have PyTorch 2.0 or higher installed
+pipeline.enable_xformers_memory_efficient_attention()
+```
+
+Now generate a new image conditioned on the depth map, initial image, and prompt:
+
+```py
+prompt = "Astronaut in a jungle, cold color palette, muted colors, detailed, 8k"
+image_control_net = pipeline(prompt, image=init_image, control_image=depth_image).images[0]
+make_image_grid([init_image, depth_image, image_control_net], rows=1, cols=3)
+```
+
+
+
+
+ initial image
+
+
+
+ depth image
+
+
+
+ ControlNet image
+
+
+
+Let's apply a new [style](https://huggingface.co/nitrosocke/elden-ring-diffusion) to the image generated from the ControlNet by chaining it with an image-to-image pipeline:
+
+```py
+pipeline = AutoPipelineForImage2Image.from_pretrained(
+ "nitrosocke/elden-ring-diffusion", torch_dtype=torch.float16,
+)
+pipeline.enable_model_cpu_offload()
+# remove following line if xFormers is not installed or you have PyTorch 2.0 or higher installed
+pipeline.enable_xformers_memory_efficient_attention()
+
+prompt = "elden ring style astronaut in a jungle" # include the token "elden ring style" in the prompt
+negative_prompt = "ugly, deformed, disfigured, poor details, bad anatomy"
+
+image_elden_ring = pipeline(prompt, negative_prompt=negative_prompt, image=image_control_net, strength=0.45, guidance_scale=10.5).images[0]
+make_image_grid([init_image, depth_image, image_control_net, image_elden_ring], rows=2, cols=2)
+```
+
+
+
+
+
+## Optimize
+
+Running diffusion models is computationally expensive and intensive, but with a few optimization tricks, it is entirely possible to run them on consumer and free-tier GPUs. For example, you can use a more memory-efficient form of attention such as PyTorch 2.0's [scaled-dot product attention](../optimization/torch2.0#scaled-dot-product-attention) or [xFormers](../optimization/xformers) (you can use one or the other, but there's no need to use both). You can also offload the model to the GPU while the other pipeline components wait on the CPU.
+
+```diff
++ pipeline.enable_model_cpu_offload()
++ pipeline.enable_xformers_memory_efficient_attention()
+```
+
+With [`torch.compile`](../optimization/torch2.0#torchcompile), you can boost your inference speed even more by wrapping your UNet with it:
+
+```py
+pipeline.unet = torch.compile(pipeline.unet, mode="reduce-overhead", fullgraph=True)
+```
+
+To learn more, take a look at the [Reduce memory usage](../optimization/memory) and [Torch 2.0](../optimization/torch2.0) guides.
diff --git a/docs/source/en/using-diffusers/omnigen.md b/docs/source/en/using-diffusers/omnigen.md
new file mode 100644
index 000000000000..5b5e7119a94c
--- /dev/null
+++ b/docs/source/en/using-diffusers/omnigen.md
@@ -0,0 +1,70 @@
+
+# OmniGen
+
+OmniGen is an image generation model. Unlike existing text-to-image models, OmniGen is designed to handle a variety of tasks (e.g., text-to-image, image editing, controllable generation) within a single model. It has the following features:
+- Minimalist model architecture, consisting of only a VAE and a transformer module, for joint modeling of text and images.
+- Support for multimodal inputs. It can process any text-image mixed data as instructions for image generation, rather than relying solely on text.
+
+This guide will walk you through using OmniGen for various tasks and use cases.
+
+## Load model checkpoints
+Model weights may be stored in separate subfolders on the Hub or locally, in which case, you should use the [`~DiffusionPipeline.from_pretrained`] method.
+
+```py
+import torch
+from diffusers import OmniGenPipeline
+pipe = OmniGenPipeline.from_pretrained(
+ "Shitao/OmniGen-v1-diffusers",
+ torch_dtype=torch.bfloat16
+)
+```
+
+
+## Text-to-Image
+
+
+## Text-to-image
+
+For text-to-image, pass a text prompt. By default, OmniGen generates a 1024x1024 image.
+You can try setting the `height` and `width` parameters to generate images with different size.
+
+```py
+import torch
+from diffusers import OmniGenPipeline
+pipe = OmniGenPipeline.from_pretrained(
+ "Shitao/OmniGen-v1-diffusers",
+ torch_dtype=torch.bfloat16
+)
+
+prompt = "An elderly gentleman, with a serene expression, sits at the water's edge, a steaming cup of tea by his side. He is engrossed in his artwork, brush in hand, as he renders an oil painting on a canvas that's propped up against a small, weathered table. The sea breeze whispers through his silver hair, gently billowing his loose-fitting white shirt, while the salty air adds an intangible element to his masterpiece in progress. The scene is one of tranquility and inspiration, with the artist's canvas capturing the vibrant hues of the setting sun reflecting off the tranquil sea."
+pipe.enable_model_cpu_offload()
+
+image = pipe(
+ prompt=prompt,
+ generator=torch.Generator(device="cuda").manual_seed(42),
+).images[0]
+
+image
+```
+
+
+
+For text-to-image, pass a text prompt. By default, CogVideoX generates a 720x480 video for the best results.
+
+
+
+## Optimization
+
+### inference speed
+
+### Memory
\ No newline at end of file
From d9f80fcdbdda912d21873b7117d4abaffaa3cb7c Mon Sep 17 00:00:00 2001
From: shitao <2906698981@qq.com>
Date: Thu, 5 Dec 2024 22:41:14 +0800
Subject: [PATCH 09/55] update docs
---
docs/source/en/_toctree.yml | 8 ++++++++
.../en/api/models/omnigen_transformer.md | 19 +++++++++++++++++++
docs/source/en/using-diffusers/omnigen.md | 2 +-
src/diffusers/models/transformers/__init__.py | 2 +-
.../transformers/transformer_omnigen.py | 9 ++-------
.../pipelines/lumina/pipeline_lumina.py | 2 +-
6 files changed, 32 insertions(+), 10 deletions(-)
create mode 100644 docs/source/en/api/models/omnigen_transformer.md
diff --git a/docs/source/en/_toctree.yml b/docs/source/en/_toctree.yml
index 2faabfec30ce..c181f0bd82a0 100644
--- a/docs/source/en/_toctree.yml
+++ b/docs/source/en/_toctree.yml
@@ -51,6 +51,8 @@
title: Text or image-to-video
- local: using-diffusers/depth2img
title: Depth-to-image
+ - local: using-diffusers/multimodal2img
+ title: Multimodal-to-image
title: Generative tasks
- sections:
- local: using-diffusers/overview_techniques
@@ -87,6 +89,8 @@
title: Kandinsky
- local: using-diffusers/ip_adapter
title: IP-Adapter
+ - local: using-diffusers/omnigen
+ title: OmniGen
- local: using-diffusers/pag
title: PAG
- local: using-diffusers/controlnet
@@ -274,6 +278,8 @@
title: LuminaNextDiT2DModel
- local: api/models/mochi_transformer3d
title: MochiTransformer3DModel
+ - local: api/models/omnigen_transformer
+ title: OmniGenTransformer2DModel
- local: api/models/pixart_transformer2d
title: PixArtTransformer2DModel
- local: api/models/prior_transformer
@@ -412,6 +418,8 @@
title: MultiDiffusion
- local: api/pipelines/musicldm
title: MusicLDM
+ - local: api/pipelines/omnigen
+ title: OmniGen
- local: api/pipelines/pag
title: PAG
- local: api/pipelines/paint_by_example
diff --git a/docs/source/en/api/models/omnigen_transformer.md b/docs/source/en/api/models/omnigen_transformer.md
new file mode 100644
index 000000000000..d2df6c55e68b
--- /dev/null
+++ b/docs/source/en/api/models/omnigen_transformer.md
@@ -0,0 +1,19 @@
+
+
+# OmniGenTransformer2DModel
+
+A Transformer model accept multi-modal instruction to generate image from [OmniGen](https://github.com/VectorSpaceLab/OmniGen/).
+
+## OmniGenTransformer2DModel
+
+[[autodoc]] OmniGenTransformer2DModel
diff --git a/docs/source/en/using-diffusers/omnigen.md b/docs/source/en/using-diffusers/omnigen.md
index 5b5e7119a94c..ff398eb4f88e 100644
--- a/docs/source/en/using-diffusers/omnigen.md
+++ b/docs/source/en/using-diffusers/omnigen.md
@@ -46,7 +46,7 @@ pipe = OmniGenPipeline.from_pretrained(
torch_dtype=torch.bfloat16
)
-prompt = "An elderly gentleman, with a serene expression, sits at the water's edge, a steaming cup of tea by his side. He is engrossed in his artwork, brush in hand, as he renders an oil painting on a canvas that's propped up against a small, weathered table. The sea breeze whispers through his silver hair, gently billowing his loose-fitting white shirt, while the salty air adds an intangible element to his masterpiece in progress. The scene is one of tranquility and inspiration, with the artist's canvas capturing the vibrant hues of the setting sun reflecting off the tranquil sea."
+prompt = "A young woman sits on a sofa, holding a book and facing the camera. She wears delicate silver hoop earrings adorned with tiny, sparkling diamonds that catch the light, with her long chestnut hair cascading over her shoulders. Her eyes are focused and gentle, framed by long, dark lashes. She is dressed in a cozy cream sweater, which complements her warm, inviting smile. Behind her, there is a table with a cup of water in a sleek, minimalist blue mug. The background is a serene indoor setting with soft natural light filtering through a window, adorned with tasteful art and flowers, creating a cozy and peaceful ambiance. 4K, HD."
pipe.enable_model_cpu_offload()
image = pipe(
diff --git a/src/diffusers/models/transformers/__init__.py b/src/diffusers/models/transformers/__init__.py
index 9770ded5e31e..1f16146fce22 100644
--- a/src/diffusers/models/transformers/__init__.py
+++ b/src/diffusers/models/transformers/__init__.py
@@ -18,6 +18,6 @@
from .transformer_cogview3plus import CogView3PlusTransformer2DModel
from .transformer_flux import FluxTransformer2DModel
from .transformer_mochi import MochiTransformer3DModel
+ from .transformer_omnigen import OmniGenTransformer2DModel
from .transformer_sd3 import SD3Transformer2DModel
from .transformer_temporal import TransformerTemporalModel
- from .transformer_omnigen import OmniGenTransformer2DModel
diff --git a/src/diffusers/models/transformers/transformer_omnigen.py b/src/diffusers/models/transformers/transformer_omnigen.py
index 64790a2024e5..5095516e57ab 100644
--- a/src/diffusers/models/transformers/transformer_omnigen.py
+++ b/src/diffusers/models/transformers/transformer_omnigen.py
@@ -225,15 +225,10 @@ class OmniGenTransformer2DModel(ModelMixin, ConfigMixin, PeftAdapterMixin):
Reference: https://arxiv.org/pdf/2409.11340
Parameters:
+ transformer_config (`dict`): config for transformer layers. OmniGen-v1 use Phi3 as transformer backbone
patch_size (`int`, defaults to 2): Patch size to turn the input data into small patches.
in_channels (`int`, *optional*, defaults to 4): The number of channels in the input.
- num_layers (`int`, *optional*, defaults to 18): The number of layers of MMDiT blocks to use.
- num_single_layers (`int`, *optional*, defaults to 18): The number of layers of single DiT blocks to use.
- attention_head_dim (`int`, *optional*, defaults to 64): The number of channels in each head.
- num_attention_heads (`int`, *optional*, defaults to 18): The number of heads to use for multi-head attention.
- joint_attention_dim (`int`, *optional*): The number of `encoder_hidden_states` dimensions to use.
- pooled_projection_dim (`int`): Number of dimensions to use when projecting the `pooled_projections`.
- guidance_embeds (`bool`, defaults to False): Whether to use guidance embeddings.
+ pos_embed_max_size (`int`, *optional*, defaults to 192): The max size of pos emb.
"""
_supports_gradient_checkpointing = True
diff --git a/src/diffusers/pipelines/lumina/pipeline_lumina.py b/src/diffusers/pipelines/lumina/pipeline_lumina.py
index 296ae5303b20..1db3d58808d2 100644
--- a/src/diffusers/pipelines/lumina/pipeline_lumina.py
+++ b/src/diffusers/pipelines/lumina/pipeline_lumina.py
@@ -25,7 +25,7 @@
from ...image_processor import VaeImageProcessor
from ...models import AutoencoderKL
from ...models.embeddings import get_2d_rotary_pos_embed_lumina
-from ...models.transformers.lumina_nextdit2d import LuminaextDiT2DModel
+from ...models.transformers.lumina_nextdit2d import LuminaNextDiT2DModel
from ...schedulers import FlowMatchEulerDiscreteScheduler
from ...utils import (
BACKENDS_MAPPING,
From c78d1f4b490950bcf940df722b76e5c541f3a587 Mon Sep 17 00:00:00 2001
From: shitao <2906698981@qq.com>
Date: Fri, 6 Dec 2024 18:24:58 +0800
Subject: [PATCH 10/55] offload_transformer
---
docs/source/en/using-diffusers/omnigen.md | 278 +++++++++++++++++-
.../transformers/transformer_omnigen.py | 4 +-
2 files changed, 270 insertions(+), 12 deletions(-)
diff --git a/docs/source/en/using-diffusers/omnigen.md b/docs/source/en/using-diffusers/omnigen.md
index ff398eb4f88e..22a14bd95efe 100644
--- a/docs/source/en/using-diffusers/omnigen.md
+++ b/docs/source/en/using-diffusers/omnigen.md
@@ -15,6 +15,7 @@ OmniGen is an image generation model. Unlike existing text-to-image models, Omni
- Minimalist model architecture, consisting of only a VAE and a transformer module, for joint modeling of text and images.
- Support for multimodal inputs. It can process any text-image mixed data as instructions for image generation, rather than relying solely on text.
+For more information, please refer to the [paper](https://arxiv.org/pdf/2409.11340).
This guide will walk you through using OmniGen for various tasks and use cases.
## Load model checkpoints
@@ -30,8 +31,6 @@ pipe = OmniGenPipeline.from_pretrained(
```
-## Text-to-Image
-
## Text-to-image
@@ -41,30 +40,289 @@ You can try setting the `height` and `width` parameters to generate images with
```py
import torch
from diffusers import OmniGenPipeline
+
pipe = OmniGenPipeline.from_pretrained(
"Shitao/OmniGen-v1-diffusers",
torch_dtype=torch.bfloat16
)
+pipe.to("cuda")
-prompt = "A young woman sits on a sofa, holding a book and facing the camera. She wears delicate silver hoop earrings adorned with tiny, sparkling diamonds that catch the light, with her long chestnut hair cascading over her shoulders. Her eyes are focused and gentle, framed by long, dark lashes. She is dressed in a cozy cream sweater, which complements her warm, inviting smile. Behind her, there is a table with a cup of water in a sleek, minimalist blue mug. The background is a serene indoor setting with soft natural light filtering through a window, adorned with tasteful art and flowers, creating a cozy and peaceful ambiance. 4K, HD."
-pipe.enable_model_cpu_offload()
-
+prompt = "Realistic photo. A young woman sits on a sofa, holding a book and facing the camera. She wears delicate silver hoop earrings adorned with tiny, sparkling diamonds that catch the light, with her long chestnut hair cascading over her shoulders. Her eyes are focused and gentle, framed by long, dark lashes. She is dressed in a cozy cream sweater, which complements her warm, inviting smile. Behind her, there is a table with a cup of water in a sleek, minimalist blue mug. The background is a serene indoor setting with soft natural light filtering through a window, adorned with tasteful art and flowers, creating a cozy and peaceful ambiance. 4K, HD."
image = pipe(
prompt=prompt,
- generator=torch.Generator(device="cuda").manual_seed(42),
+ height=1024,
+ width=1024,
+ guidance_scale=3,
+ generator=torch.Generator(device="cpu").manual_seed(111),
).images[0]
+image
+```
+
+
+
+
+## Image edit
+
+OmniGen supports for multimodal inputs.
+When the input includes an image, you need to add a placeholder `<|image_1|>` in the text prompt to represent the image.
+It is recommended to enable 'use_input_image_size_as_output' to keep the edited image the same size as the original image.
+
+```py
+import torch
+from diffusers import OmniGenPipeline
+from diffusers.utils import load_image
+pipe = OmniGenPipeline.from_pretrained(
+ "Shitao/OmniGen-v1-diffusers",
+ torch_dtype=torch.bfloat16
+)
+pipe.to("cuda")
+
+prompt="<|image_1|> Remove the woman's earrings. Replace the mug with a clear glass filled with sparkling iced cola."
+input_images=[load_image("https://raw.githubusercontent.com/VectorSpaceLab/OmniGen/main/imgs/docs_img/t2i_woman_with_book.png")]
+image = pipe(
+ prompt=prompt,
+ input_images=input_images,
+ guidance_scale=2,
+ img_guidance_scale=1.6,
+ use_input_image_size_as_output=True,
+ generator=torch.Generator(device="cpu").manual_seed(222)).images[0]
+image
+```
+
+
+
+ original image
+
+
+
+ edited image
+
+
+
+OmniGen has some interesting features, such as the ability to infer user needs, as shown in the example below.
+```py
+prompt="If the woman is thirsty, what should she take? Find it in the image and highlight it in blue. <|image_1|>"
+input_images=[load_image("https://raw.githubusercontent.com/VectorSpaceLab/OmniGen/main/imgs/docs_img/edit.png")]
+image = pipe(
+ prompt=prompt,
+ input_images=input_images,
+ guidance_scale=2,
+ img_guidance_scale=1.6,
+ use_input_image_size_as_output=True,
+ generator=torch.Generator(device="cpu").manual_seed(0)).images[0]
image
```
-
+
-For text-to-image, pass a text prompt. By default, CogVideoX generates a 720x480 video for the best results.
+## Controllable generation
+
+ OmniGen can handle several classic computer vision tasks.
+ As shown below, OmniGen can detect human skeletons in input images, which can be used as control conditions to generate new images.
+
+```py
+import torch
+from diffusers import OmniGenPipeline
+from diffusers.utils import load_image
+
+pipe = OmniGenPipeline.from_pretrained(
+ "Shitao/OmniGen-v1-diffusers",
+ torch_dtype=torch.bfloat16
+)
+pipe.to("cuda")
+
+prompt="Detect the skeleton of human in this image: <|image_1|>"
+input_images=[load_image("https://raw.githubusercontent.com/VectorSpaceLab/OmniGen/main/imgs/docs_img/edit.png")]
+image1 = pipe(
+ prompt=prompt,
+ input_images=input_images,
+ guidance_scale=2,
+ img_guidance_scale=1.6,
+ use_input_image_size_as_output=True,
+ generator=torch.Generator(device="cpu").manual_seed(333)).images[0]
+image1
+
+prompt="Generate a new photo using the following picture and text as conditions: <|image_1|>\n A young boy is sitting on a sofa in the library, holding a book. His hair is neatly combed, and a faint smile plays on his lips, with a few freckles scattered across his cheeks. The library is quiet, with rows of shelves filled with books stretching out behind him."
+input_images=[load_image("https://raw.githubusercontent.com/VectorSpaceLab/OmniGen/main/imgs/docs_img/skeletal.png")]
+image2 = pipe(
+ prompt=prompt,
+ input_images=input_images,
+ guidance_scale=2,
+ img_guidance_scale=1.6,
+ use_input_image_size_as_output=True,
+ generator=torch.Generator(device="cpu").manual_seed(333)).images[0]
+image2
+```
+
+
+
+
+ original image
+
+
+
+ detected skeleton
+
+
+
+ skeleton to image
+
+
+
+
+OmniGen can also directly use relevant information from input images to generate new images.
+```py
+import torch
+from diffusers import OmniGenPipeline
+from diffusers.utils import load_image
+
+pipe = OmniGenPipeline.from_pretrained(
+ "Shitao/OmniGen-v1-diffusers",
+ torch_dtype=torch.bfloat16
+)
+pipe.to("cuda")
+
+prompt="Following the pose of this image <|image_1|>, generate a new photo: A young boy is sitting on a sofa in the library, holding a book. His hair is neatly combed, and a faint smile plays on his lips, with a few freckles scattered across his cheeks. The library is quiet, with rows of shelves filled with books stretching out behind him."
+input_images=[load_image("https://raw.githubusercontent.com/VectorSpaceLab/OmniGen/main/imgs/docs_img/edit.png")]
+image = pipe(
+ prompt=prompt,
+ input_images=input_images,
+ guidance_scale=2,
+ img_guidance_scale=1.6,
+ use_input_image_size_as_output=True,
+ generator=torch.Generator(device="cpu").manual_seed(0)).images[0]
+image
+```
+
+
+
+ generated image
+
+
+
+
+## ID and object preserving
+
+OmniGen can generate multiple images based on the people and objects in the input image and supports inputting multiple images simultaneously.
+Additionally, OmniGen can extract desired objects from an image containing multiple objects based on instructions.
+
+```py
+import torch
+from diffusers import OmniGenPipeline
+from diffusers.utils import load_image
+
+pipe = OmniGenPipeline.from_pretrained(
+ "Shitao/OmniGen-v1-diffusers",
+ torch_dtype=torch.bfloat16
+)
+pipe.to("cuda")
+
+prompt="A man and a woman are sitting at a classroom desk. The man is the man with yellow hair in <|image_1|>. The woman is the woman on the left of <|image_2|>"
+input_image_1 = load_image("https://raw.githubusercontent.com/VectorSpaceLab/OmniGen/main/imgs/docs_img/3.jpg")
+input_image_2 = load_image("https://raw.githubusercontent.com/VectorSpaceLab/OmniGen/main/imgs/docs_img/4.jpg")
+input_images=[input_image_1, input_image_2]
+image = pipe(
+ prompt=prompt,
+ input_images=input_images,
+ height=1024,
+ width=1024,
+ guidance_scale=2.5,
+ img_guidance_scale=1.6,
+ generator=torch.Generator(device="cpu").manual_seed(666)).images[0]
+image
+```
+
+
+
+ input_image_1
+
+
+
+ input_image_2
+
+
+
+ generated image
+
+
+
+
+```py
+import torch
+from diffusers import OmniGenPipeline
+from diffusers.utils import load_image
+
+pipe = OmniGenPipeline.from_pretrained(
+ "Shitao/OmniGen-v1-diffusers",
+ torch_dtype=torch.bfloat16
+)
+pipe.to("cuda")
+
+
+prompt="A woman is walking down the street, wearing a white long-sleeve blouse with lace details on the sleeves, paired with a blue pleated skirt. The woman is <|image_1|>. The long-sleeve blouse and a pleated skirt are <|image_2|>."
+input_image_1 = load_image("/share/junjie/code/VISTA2/produce_data/laion_net/diffgpt/OmniGen/docs_img/emma.jpeg")
+input_image_2 = load_image("/share/junjie/code/VISTA2/produce_data/laion_net/diffgpt/OmniGen/docs_img/dress.jpg")
+input_images=[input_image_1, input_image_2]
+image = pipe(
+ prompt=prompt,
+ input_images=input_images,
+ height=1024,
+ width=1024,
+ guidance_scale=2.5,
+ img_guidance_scale=1.6,
+ generator=torch.Generator(device="cpu").manual_seed(666)).images[0]
+```
+
+
+
+
+ person image
+
+
+
+ clothe image
+
+
+
+ generated image
+
+
+
+
+## Optimization when inputting multiple images
+
+For text-to-image task, OmniGen requires minimal memory and time costs (9G memory and 31s for a 1024*1024 image on A800 GPU).
+However, when using input images, the computational cost increases.
+
+Here are some guidelines to help you reduce computational costs when input multiple images. The experiments are conducted on A800 GPU and input two images to OmniGen.
-## Optimization
### inference speed
-### Memory
\ No newline at end of file
+- `use_kv_cache=True`:
+ `use_kv_cache` will store key and value states of the input conditions to compute attention without redundant computations.
+ The default value is True, and OmniGen will offload the kv cache to cpu default.
+ - `use_kv_cache=False`: the inference time is 3m21s.
+ - `use_kv_cache=True`: the inference time is 1m30s.
+
+- `max_input_image_size`:
+ the maximum size of input image, which will be used to crop the input image
+ - `max_input_image_size=1024`: the inference time is 1m30s.
+ - `max_input_image_size=512`: the inference time is 58s.
+
+### Memory
+
+- `pipe.enable_model_cpu_offload()`:
+ - Without enabling cpu offloading, memory usage is `31 GB`
+ - With enabling cpu offloading, memory usage is `28 GB`
+
+- `offload_transformer_block=True`:
+ - 17G
+
+- `pipe.enable_sequential_cpu_offload()`:
+ - 11G
+
+
diff --git a/src/diffusers/models/transformers/transformer_omnigen.py b/src/diffusers/models/transformers/transformer_omnigen.py
index 5095516e57ab..eb61dfc3e592 100644
--- a/src/diffusers/models/transformers/transformer_omnigen.py
+++ b/src/diffusers/models/transformers/transformer_omnigen.py
@@ -174,7 +174,7 @@ def forward(
)
else:
if offload_transformer_block and not self.training:
- if not not torch.cuda.is_available():
+ if not torch.cuda.is_available():
logger.warning_once(
"We don't detecte any available GPU, so diable `offload_transformer_block`"
)
@@ -363,7 +363,7 @@ def get_multimodal_embeddings(self,
input_img_inx = 0
if input_img_latents is not None:
input_image_tokens = self.patch_embedding(input_img_latents,
- is_input_image=True)
+ is_input_image=True)
for b_inx in input_image_sizes.keys():
for start_inx, end_inx in input_image_sizes[b_inx]:
From 236f14b703d933092c088095ec03f60e94e8d4b1 Mon Sep 17 00:00:00 2001
From: staoxiao <2906698981@qq.com>
Date: Sun, 8 Dec 2024 15:00:28 +0800
Subject: [PATCH 11/55] enable_transformer_block_cpu_offload
---
.../transformers/transformer_omnigen.py | 13 +++++-------
.../pipelines/omnigen/pipeline_omnigen.py | 20 +++++++++++++++++--
2 files changed, 23 insertions(+), 10 deletions(-)
diff --git a/src/diffusers/models/transformers/transformer_omnigen.py b/src/diffusers/models/transformers/transformer_omnigen.py
index eb61dfc3e592..882065faae3d 100644
--- a/src/diffusers/models/transformers/transformer_omnigen.py
+++ b/src/diffusers/models/transformers/transformer_omnigen.py
@@ -74,21 +74,18 @@ def evict_previous_layer(self, layer_idx: int):
prev_layer_idx = layer_idx - 1
for name, param in self.layers[prev_layer_idx].named_parameters():
param.data = param.data.to("cpu", non_blocking=True)
-
+
def get_offload_layer(self, layer_idx: int, device: torch.device):
# init stream
if not hasattr(self, "prefetch_stream"):
self.prefetch_stream = torch.cuda.Stream()
# delete previous layer
- # main stream sync shouldn't be necessary since all computation on iter i-1 is finished by iter i
- # torch.cuda.current_stream().synchronize()
- # avoid extra eviction of last layer
- if layer_idx > 0:
- self.evict_previous_layer(layer_idx)
-
+ torch.cuda.current_stream().synchronize()
+ self.evict_previous_layer(layer_idx)
+
# make sure the current layer is ready
- self.prefetch_stream.synchronize()
+ torch.cuda.synchronize(self.prefetch_stream)
# load next layer
self.prefetch_layer((layer_idx + 1) % len(self.layers), device)
diff --git a/src/diffusers/pipelines/omnigen/pipeline_omnigen.py b/src/diffusers/pipelines/omnigen/pipeline_omnigen.py
index ff10cade50f3..902453aae0f6 100644
--- a/src/diffusers/pipelines/omnigen/pipeline_omnigen.py
+++ b/src/diffusers/pipelines/omnigen/pipeline_omnigen.py
@@ -327,6 +327,18 @@ def num_timesteps(self):
@property
def interrupt(self):
return self._interrupt
+
+ def enable_transformer_block_cpu_offload(self, device: Union[torch.device, str] = "cuda"):
+ torch_device = torch.device(device)
+ for name, param in self.transformer.named_parameters():
+ if 'layers' in name and 'layers.0' not in name:
+ param.data = param.data.cpu()
+ else:
+ param.data = param.data.to(torch_device)
+ for buffer_name, buffer in self.transformer.patch_embedding.named_buffers():
+ setattr(self.transformer.patch_embedding, buffer_name, buffer.to(torch_device))
+ self.vae.to(torch_device)
+ self.offload_transformer_block = True
@torch.no_grad()
@replace_example_docstring(EXAMPLE_DOC_STRING)
@@ -440,6 +452,9 @@ def __call__(
# using Float32 for the VAE doesn't take up much memory but can prevent potential black image outputs.
self.vae.to(torch.float32)
+ if offload_transformer_block:
+ self.enable_transformer_block_cpu_offload()
+
# 1. Check inputs. Raise error if not correct
self.check_inputs(
prompt,
@@ -460,9 +475,10 @@ def __call__(
batch_size = len(prompt)
device = self._execution_device
+
# 3. process multi-modal instructions
if max_input_image_size != self.multimodal_processor.max_image_size:
- self.multimodal_processor = OmniGenMultiModalProcessor(self.text_tokenizer, max_image_size=max_input_image_size)
+ self.multimodal_processor = OmniGenMultiModalProcessor(self.tokenizer, max_image_size=max_input_image_size)
processed_data = self.multimodal_processor(prompt,
input_images,
height=height,
@@ -521,7 +537,7 @@ def __call__(
position_ids=processed_data['position_ids'],
attention_kwargs=attention_kwargs,
past_key_values=cache,
- offload_transformer_block=offload_transformer_block,
+ offload_transformer_block=self.offload_transformer_block if hasattr(self, 'offload_transformer_block') else offload_transformer_block,
return_dict=False,
)
From 6b52547ebcf83ea564fd59801c80affe5b53d944 Mon Sep 17 00:00:00 2001
From: shitao <2906698981@qq.com>
Date: Sun, 8 Dec 2024 16:06:38 +0800
Subject: [PATCH 12/55] update docs
---
docs/source/en/api/pipelines/omnigen.md | 66 +-
.../en/using-diffusers/multimodal2img.md | 612 ++----------------
docs/source/en/using-diffusers/omnigen.md | 7 +-
.../transformers/transformer_omnigen.py | 58 +-
.../pipelines/omnigen/kvcache_omnigen.py | 54 +-
.../pipelines/omnigen/pipeline_omnigen.py | 155 +++--
.../pipelines/omnigen/processor_omnigen.py | 16 +-
7 files changed, 268 insertions(+), 700 deletions(-)
diff --git a/docs/source/en/api/pipelines/omnigen.md b/docs/source/en/api/pipelines/omnigen.md
index fe0bc5d496a0..b7ea5488156a 100644
--- a/docs/source/en/api/pipelines/omnigen.md
+++ b/docs/source/en/api/pipelines/omnigen.md
@@ -56,44 +56,50 @@ First, load the pipeline:
```python
import torch
-from diffusers import CogVideoXPipeline, CogVideoXImageToVideoPipeline
-from diffusers.utils import export_to_video,load_image
-pipe = CogVideoXPipeline.from_pretrained("THUDM/CogVideoX-5b").to("cuda") # or "THUDM/CogVideoX-2b"
+from diffusers import OmniGenPipeline
+pipe = OmniGenPipeline.from_pretrained(
+ "Shitao/OmniGen-v1-diffusers",
+ torch_dtype=torch.bfloat16
+)
+pipe.to("cuda")
```
-If you are using the image-to-video pipeline, load it as follows:
-
-```python
-pipe = CogVideoXImageToVideoPipeline.from_pretrained("THUDM/CogVideoX-5b-I2V").to("cuda")
+For text-to-image, pass a text prompt. By default, OmniGen generates a 1024x1024 image.
+You can try setting the `height` and `width` parameters to generate images with different size.
+
+```py
+prompt = "Realistic photo. A young woman sits on a sofa, holding a book and facing the camera. She wears delicate silver hoop earrings adorned with tiny, sparkling diamonds that catch the light, with her long chestnut hair cascading over her shoulders. Her eyes are focused and gentle, framed by long, dark lashes. She is dressed in a cozy cream sweater, which complements her warm, inviting smile. Behind her, there is a table with a cup of water in a sleek, minimalist blue mug. The background is a serene indoor setting with soft natural light filtering through a window, adorned with tasteful art and flowers, creating a cozy and peaceful ambiance. 4K, HD."
+image = pipe(
+ prompt=prompt,
+ height=1024,
+ width=1024,
+ guidance_scale=3,
+ generator=torch.Generator(device="cpu").manual_seed(111),
+).images[0]
+image
```
-Then change the memory layout of the pipelines `transformer` component to `torch.channels_last`:
-
-```python
-pipe.transformer.to(memory_format=torch.channels_last)
-```
-
-Compile the components and run inference:
-
-```python
-pipe.transformer = torch.compile(pipeline.transformer, mode="max-autotune", fullgraph=True)
-
-# CogVideoX works well with long and well-described prompts
-prompt = "A panda, dressed in a small, red jacket and a tiny hat, sits on a wooden stool in a serene bamboo forest. The panda's fluffy paws strum a miniature acoustic guitar, producing soft, melodic tunes. Nearby, a few other pandas gather, watching curiously and some clapping in rhythm. Sunlight filters through the tall bamboo, casting a gentle glow on the scene. The panda's face is expressive, showing concentration and joy as it plays. The background includes a small, flowing stream and vibrant green foliage, enhancing the peaceful and magical atmosphere of this unique musical performance."
-video = pipe(prompt=prompt, guidance_scale=6, num_inference_steps=50).frames[0]
-```
-
-The [T2V benchmark](https://gist.github.com/a-r-r-o-w/5183d75e452a368fd17448fcc810bd3f) results on an 80GB A100 machine are:
-
-```
-Without torch.compile(): Average inference time: 96.89 seconds.
-With torch.compile(): Average inference time: 76.27 seconds.
+OmniGen supports for multimodal inputs.
+When the input includes an image, you need to add a placeholder `<|image_1|>` in the text prompt to represent the image.
+It is recommended to enable 'use_input_image_size_as_output' to keep the edited image the same size as the original image.
+
+```py
+prompt="<|image_1|> Remove the woman's earrings. Replace the mug with a clear glass filled with sparkling iced cola."
+input_images=[load_image("https://raw.githubusercontent.com/VectorSpaceLab/OmniGen/main/imgs/docs_img/t2i_woman_with_book.png")]
+image = pipe(
+ prompt=prompt,
+ input_images=input_images,
+ guidance_scale=2,
+ img_guidance_scale=1.6,
+ use_input_image_size_as_output=True,
+ generator=torch.Generator(device="cpu").manual_seed(222)).images[0]
+image
```
-## CogVideoXPipeline
+## OmniGenPipeline
-[[autodoc]] CogVideoXPipeline
+[[autodoc]] OmniGenPipeline
- all
- __call__
diff --git a/docs/source/en/using-diffusers/multimodal2img.md b/docs/source/en/using-diffusers/multimodal2img.md
index 4618731830df..10567d49ae19 100644
--- a/docs/source/en/using-diffusers/multimodal2img.md
+++ b/docs/source/en/using-diffusers/multimodal2img.md
@@ -10,596 +10,104 @@ an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express o
specific language governing permissions and limitations under the License.
-->
-# Image-to-image
+# Multi-modal instruction to image
[[open-in-colab]]
-Image-to-image is similar to [text-to-image](conditional_image_generation), but in addition to a prompt, you can also pass an initial image as a starting point for the diffusion process. The initial image is encoded to latent space and noise is added to it. Then the latent diffusion model takes a prompt and the noisy latent image, predicts the added noise, and removes the predicted noise from the initial latent image to get the new latent image. Lastly, a decoder decodes the new latent image back into an image.
-With 🤗 Diffusers, this is as easy as 1-2-3:
-1. Load a checkpoint into the [`AutoPipelineForImage2Image`] class; this pipeline automatically handles loading the correct pipeline class based on the checkpoint:
+Multimodal instructions mean you can input any sequence of mixed text and images to guide image generation. You can input multiple images and use prompts to describe the desired output. This approach is more flexible than using only text or images.
-```py
-import torch
-from diffusers import AutoPipelineForImage2Image
-from diffusers.utils import load_image, make_image_grid
-
-pipeline = AutoPipelineForImage2Image.from_pretrained(
- "kandinsky-community/kandinsky-2-2-decoder", torch_dtype=torch.float16, use_safetensors=True
-)
-pipeline.enable_model_cpu_offload()
-# remove following line if xFormers is not installed or you have PyTorch 2.0 or higher installed
-pipeline.enable_xformers_memory_efficient_attention()
-```
-
-
-
-You'll notice throughout the guide, we use [`~DiffusionPipeline.enable_model_cpu_offload`] and [`~DiffusionPipeline.enable_xformers_memory_efficient_attention`], to save memory and increase inference speed. If you're using PyTorch 2.0, then you don't need to call [`~DiffusionPipeline.enable_xformers_memory_efficient_attention`] on your pipeline because it'll already be using PyTorch 2.0's native [scaled-dot product attention](../optimization/torch2.0#scaled-dot-product-attention).
-
-
-
-2. Load an image to pass to the pipeline:
-
-```py
-init_image = load_image("https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/cat.png")
-```
-
-3. Pass a prompt and image to the pipeline to generate an image:
-
-```py
-prompt = "cat wizard, gandalf, lord of the rings, detailed, fantasy, cute, adorable, Pixar, Disney, 8k"
-image = pipeline(prompt, image=init_image).images[0]
-make_image_grid([init_image, image], rows=1, cols=2)
-```
-
-
-
-
- initial image
-
-
-
- generated image
-
-
-
-## Popular models
-
-The most popular image-to-image models are [Stable Diffusion v1.5](https://huggingface.co/stable-diffusion-v1-5/stable-diffusion-v1-5), [Stable Diffusion XL (SDXL)](https://huggingface.co/stabilityai/stable-diffusion-xl-base-1.0), and [Kandinsky 2.2](https://huggingface.co/kandinsky-community/kandinsky-2-2-decoder). The results from the Stable Diffusion and Kandinsky models vary due to their architecture differences and training process; you can generally expect SDXL to produce higher quality images than Stable Diffusion v1.5. Let's take a quick look at how to use each of these models and compare their results.
+## Examples
-### Stable Diffusion v1.5
-Stable Diffusion v1.5 is a latent diffusion model initialized from an earlier checkpoint, and further finetuned for 595K steps on 512x512 images. To use this pipeline for image-to-image, you'll need to prepare an initial image to pass to the pipeline. Then you can pass a prompt and the image to the pipeline to generate a new image:
+Take `OmniGenPipeline` as an example: the input can be a text-image sequence to create new images, he input can be a text-image sequence, with images inserted into the text prompt via special placeholder `<|image_i|>`.
-```py
+```py
import torch
-from diffusers import AutoPipelineForImage2Image
-from diffusers.utils import make_image_grid, load_image
-
-pipeline = AutoPipelineForImage2Image.from_pretrained(
- "stable-diffusion-v1-5/stable-diffusion-v1-5", torch_dtype=torch.float16, variant="fp16", use_safetensors=True
-)
-pipeline.enable_model_cpu_offload()
-# remove following line if xFormers is not installed or you have PyTorch 2.0 or higher installed
-pipeline.enable_xformers_memory_efficient_attention()
-
-# prepare image
-url = "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/img2img-init.png"
-init_image = load_image(url)
-
-prompt = "Astronaut in a jungle, cold color palette, muted colors, detailed, 8k"
-
-# pass prompt and image to pipeline
-image = pipeline(prompt, image=init_image).images[0]
-make_image_grid([init_image, image], rows=1, cols=2)
+from diffusers import OmniGenPipeline
+from diffusers.utils import load_image
+
+pipe = OmniGenPipeline.from_pretrained(
+ "Shitao/OmniGen-v1-diffusers",
+ torch_dtype=torch.bfloat16
+)
+pipe.to("cuda")
+
+prompt="A man and a woman are sitting at a classroom desk. The man is the man with yellow hair in <|image_1|>. The woman is the woman on the left of <|image_2|>"
+input_image_1 = load_image("https://raw.githubusercontent.com/VectorSpaceLab/OmniGen/main/imgs/docs_img/3.jpg")
+input_image_2 = load_image("https://raw.githubusercontent.com/VectorSpaceLab/OmniGen/main/imgs/docs_img/4.jpg")
+input_images=[input_image_1, input_image_2]
+image = pipe(
+ prompt=prompt,
+ input_images=input_images,
+ height=1024,
+ width=1024,
+ guidance_scale=2.5,
+ img_guidance_scale=1.6,
+ generator=torch.Generator(device="cpu").manual_seed(666)).images[0]
+image
```
-
-
-
-
- initial image
-
-
-
- generated image
-
-
-
-### Stable Diffusion XL (SDXL)
-
-SDXL is a more powerful version of the Stable Diffusion model. It uses a larger base model, and an additional refiner model to increase the quality of the base model's output. Read the [SDXL](sdxl) guide for a more detailed walkthrough of how to use this model, and other techniques it uses to produce high quality images.
-
-```py
-import torch
-from diffusers import AutoPipelineForImage2Image
-from diffusers.utils import make_image_grid, load_image
-
-pipeline = AutoPipelineForImage2Image.from_pretrained(
- "stabilityai/stable-diffusion-xl-refiner-1.0", torch_dtype=torch.float16, variant="fp16", use_safetensors=True
-)
-pipeline.enable_model_cpu_offload()
-# remove following line if xFormers is not installed or you have PyTorch 2.0 or higher installed
-pipeline.enable_xformers_memory_efficient_attention()
-
-# prepare image
-url = "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/img2img-sdxl-init.png"
-init_image = load_image(url)
-
-prompt = "Astronaut in a jungle, cold color palette, muted colors, detailed, 8k"
-
-# pass prompt and image to pipeline
-image = pipeline(prompt, image=init_image, strength=0.5).images[0]
-make_image_grid([init_image, image], rows=1, cols=2)
-```
-
-
-
-
- initial image
-
-
-
- generated image
-
-
-
-### Kandinsky 2.2
-
-The Kandinsky model is different from the Stable Diffusion models because it uses an image prior model to create image embeddings. The embeddings help create a better alignment between text and images, allowing the latent diffusion model to generate better images.
-
-The simplest way to use Kandinsky 2.2 is:
-
-```py
-import torch
-from diffusers import AutoPipelineForImage2Image
-from diffusers.utils import make_image_grid, load_image
-
-pipeline = AutoPipelineForImage2Image.from_pretrained(
- "kandinsky-community/kandinsky-2-2-decoder", torch_dtype=torch.float16, use_safetensors=True
-)
-pipeline.enable_model_cpu_offload()
-# remove following line if xFormers is not installed or you have PyTorch 2.0 or higher installed
-pipeline.enable_xformers_memory_efficient_attention()
-
-# prepare image
-url = "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/img2img-init.png"
-init_image = load_image(url)
-
-prompt = "Astronaut in a jungle, cold color palette, muted colors, detailed, 8k"
-
-# pass prompt and image to pipeline
-image = pipeline(prompt, image=init_image).images[0]
-make_image_grid([init_image, image], rows=1, cols=2)
-```
-
-
-
-
- initial image
-
-
-
- generated image
-
-
-
-## Configure pipeline parameters
-
-There are several important parameters you can configure in the pipeline that'll affect the image generation process and image quality. Let's take a closer look at what these parameters do and how changing them affects the output.
-
-### Strength
-
-`strength` is one of the most important parameters to consider and it'll have a huge impact on your generated image. It determines how much the generated image resembles the initial image. In other words:
-
-- 📈 a higher `strength` value gives the model more "creativity" to generate an image that's different from the initial image; a `strength` value of 1.0 means the initial image is more or less ignored
-- 📉 a lower `strength` value means the generated image is more similar to the initial image
-
-The `strength` and `num_inference_steps` parameters are related because `strength` determines the number of noise steps to add. For example, if the `num_inference_steps` is 50 and `strength` is 0.8, then this means adding 40 (50 * 0.8) steps of noise to the initial image and then denoising for 40 steps to get the newly generated image.
-
-```py
-import torch
-from diffusers import AutoPipelineForImage2Image
-from diffusers.utils import make_image_grid, load_image
-
-pipeline = AutoPipelineForImage2Image.from_pretrained(
- "stable-diffusion-v1-5/stable-diffusion-v1-5", torch_dtype=torch.float16, variant="fp16", use_safetensors=True
-)
-pipeline.enable_model_cpu_offload()
-# remove following line if xFormers is not installed or you have PyTorch 2.0 or higher installed
-pipeline.enable_xformers_memory_efficient_attention()
-
-# prepare image
-url = "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/img2img-init.png"
-init_image = load_image(url)
-
-prompt = "Astronaut in a jungle, cold color palette, muted colors, detailed, 8k"
-
-# pass prompt and image to pipeline
-image = pipeline(prompt, image=init_image, strength=0.8).images[0]
-make_image_grid([init_image, image], rows=1, cols=2)
-```
-
-
-
-
- strength = 0.4
-
-
-
- strength = 0.6
-
-
-
- strength = 1.0
-
-
-
-### Guidance scale
-
-The `guidance_scale` parameter is used to control how closely aligned the generated image and text prompt are. A higher `guidance_scale` value means your generated image is more aligned with the prompt, while a lower `guidance_scale` value means your generated image has more space to deviate from the prompt.
-
-You can combine `guidance_scale` with `strength` for even more precise control over how expressive the model is. For example, combine a high `strength + guidance_scale` for maximum creativity or use a combination of low `strength` and low `guidance_scale` to generate an image that resembles the initial image but is not as strictly bound to the prompt.
-
-```py
-import torch
-from diffusers import AutoPipelineForImage2Image
-from diffusers.utils import make_image_grid, load_image
-
-pipeline = AutoPipelineForImage2Image.from_pretrained(
- "stable-diffusion-v1-5/stable-diffusion-v1-5", torch_dtype=torch.float16, variant="fp16", use_safetensors=True
-)
-pipeline.enable_model_cpu_offload()
-# remove following line if xFormers is not installed or you have PyTorch 2.0 or higher installed
-pipeline.enable_xformers_memory_efficient_attention()
-
-# prepare image
-url = "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/img2img-init.png"
-init_image = load_image(url)
-
-prompt = "Astronaut in a jungle, cold color palette, muted colors, detailed, 8k"
-
-# pass prompt and image to pipeline
-image = pipeline(prompt, image=init_image, guidance_scale=8.0).images[0]
-make_image_grid([init_image, image], rows=1, cols=2)
-```
-
-
- guidance_scale = 0.1
+
+ input_image_1
-
- guidance_scale = 5.0
+
+ input_image_2
-
- guidance_scale = 10.0
-
-
-
-### Negative prompt
-
-A negative prompt conditions the model to *not* include things in an image, and it can be used to improve image quality or modify an image. For example, you can improve image quality by including negative prompts like "poor details" or "blurry" to encourage the model to generate a higher quality image. Or you can modify an image by specifying things to exclude from an image.
-
-```py
-import torch
-from diffusers import AutoPipelineForImage2Image
-from diffusers.utils import make_image_grid, load_image
-
-pipeline = AutoPipelineForImage2Image.from_pretrained(
- "stabilityai/stable-diffusion-xl-refiner-1.0", torch_dtype=torch.float16, variant="fp16", use_safetensors=True
-)
-pipeline.enable_model_cpu_offload()
-# remove following line if xFormers is not installed or you have PyTorch 2.0 or higher installed
-pipeline.enable_xformers_memory_efficient_attention()
-
-# prepare image
-url = "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/img2img-init.png"
-init_image = load_image(url)
-
-prompt = "Astronaut in a jungle, cold color palette, muted colors, detailed, 8k"
-negative_prompt = "ugly, deformed, disfigured, poor details, bad anatomy"
-
-# pass prompt and image to pipeline
-image = pipeline(prompt, negative_prompt=negative_prompt, image=init_image).images[0]
-make_image_grid([init_image, image], rows=1, cols=2)
-```
-
-
-## Chained image-to-image pipelines
-
-There are some other interesting ways you can use an image-to-image pipeline aside from just generating an image (although that is pretty cool too). You can take it a step further and chain it with other pipelines.
-
-### Text-to-image-to-image
-
-Chaining a text-to-image and image-to-image pipeline allows you to generate an image from text and use the generated image as the initial image for the image-to-image pipeline. This is useful if you want to generate an image entirely from scratch. For example, let's chain a Stable Diffusion and a Kandinsky model.
-
-Start by generating an image with the text-to-image pipeline:
-
-```py
-from diffusers import AutoPipelineForText2Image, AutoPipelineForImage2Image
-import torch
-from diffusers.utils import make_image_grid
-
-pipeline = AutoPipelineForText2Image.from_pretrained(
- "stable-diffusion-v1-5/stable-diffusion-v1-5", torch_dtype=torch.float16, variant="fp16", use_safetensors=True
-)
-pipeline.enable_model_cpu_offload()
-# remove following line if xFormers is not installed or you have PyTorch 2.0 or higher installed
-pipeline.enable_xformers_memory_efficient_attention()
-
-text2image = pipeline("Astronaut in a jungle, cold color palette, muted colors, detailed, 8k").images[0]
-text2image
-```
-
-Now you can pass this generated image to the image-to-image pipeline:
-
-```py
-pipeline = AutoPipelineForImage2Image.from_pretrained(
- "kandinsky-community/kandinsky-2-2-decoder", torch_dtype=torch.float16, use_safetensors=True
-)
-pipeline.enable_model_cpu_offload()
-# remove following line if xFormers is not installed or you have PyTorch 2.0 or higher installed
-pipeline.enable_xformers_memory_efficient_attention()
-
-image2image = pipeline("Astronaut in a jungle, cold color palette, muted colors, detailed, 8k", image=text2image).images[0]
-make_image_grid([text2image, image2image], rows=1, cols=2)
-```
-
-### Image-to-image-to-image
-
-You can also chain multiple image-to-image pipelines together to create more interesting images. This can be useful for iteratively performing style transfer on an image, generating short GIFs, restoring color to an image, or restoring missing areas of an image.
-
-Start by generating an image:
-
-```py
-import torch
-from diffusers import AutoPipelineForImage2Image
-from diffusers.utils import make_image_grid, load_image
-
-pipeline = AutoPipelineForImage2Image.from_pretrained(
- "stable-diffusion-v1-5/stable-diffusion-v1-5", torch_dtype=torch.float16, variant="fp16", use_safetensors=True
-)
-pipeline.enable_model_cpu_offload()
-# remove following line if xFormers is not installed or you have PyTorch 2.0 or higher installed
-pipeline.enable_xformers_memory_efficient_attention()
-
-# prepare image
-url = "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/img2img-init.png"
-init_image = load_image(url)
-
-prompt = "Astronaut in a jungle, cold color palette, muted colors, detailed, 8k"
-
-# pass prompt and image to pipeline
-image = pipeline(prompt, image=init_image, output_type="latent").images[0]
-```
-
-
-
-It is important to specify `output_type="latent"` in the pipeline to keep all the outputs in latent space to avoid an unnecessary decode-encode step. This only works if the chained pipelines are using the same VAE.
-
-
-
-Pass the latent output from this pipeline to the next pipeline to generate an image in a [comic book art style](https://huggingface.co/ogkalu/Comic-Diffusion):
-
-```py
-pipeline = AutoPipelineForImage2Image.from_pretrained(
- "ogkalu/Comic-Diffusion", torch_dtype=torch.float16
-)
-pipeline.enable_model_cpu_offload()
-# remove following line if xFormers is not installed or you have PyTorch 2.0 or higher installed
-pipeline.enable_xformers_memory_efficient_attention()
-
-# need to include the token "charliebo artstyle" in the prompt to use this checkpoint
-image = pipeline("Astronaut in a jungle, charliebo artstyle", image=image, output_type="latent").images[0]
-```
-
-Repeat one more time to generate the final image in a [pixel art style](https://huggingface.co/kohbanye/pixel-art-style):
-
-```py
-pipeline = AutoPipelineForImage2Image.from_pretrained(
- "kohbanye/pixel-art-style", torch_dtype=torch.float16
-)
-pipeline.enable_model_cpu_offload()
-# remove following line if xFormers is not installed or you have PyTorch 2.0 or higher installed
-pipeline.enable_xformers_memory_efficient_attention()
-
-# need to include the token "pixelartstyle" in the prompt to use this checkpoint
-image = pipeline("Astronaut in a jungle, pixelartstyle", image=image).images[0]
-make_image_grid([init_image, image], rows=1, cols=2)
-```
-
-### Image-to-upscaler-to-super-resolution
-
-Another way you can chain your image-to-image pipeline is with an upscaler and super-resolution pipeline to really increase the level of details in an image.
-
-Start with an image-to-image pipeline:
-
-```py
-import torch
-from diffusers import AutoPipelineForImage2Image
-from diffusers.utils import make_image_grid, load_image
-
-pipeline = AutoPipelineForImage2Image.from_pretrained(
- "stable-diffusion-v1-5/stable-diffusion-v1-5", torch_dtype=torch.float16, variant="fp16", use_safetensors=True
-)
-pipeline.enable_model_cpu_offload()
-# remove following line if xFormers is not installed or you have PyTorch 2.0 or higher installed
-pipeline.enable_xformers_memory_efficient_attention()
-
-# prepare image
-url = "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/img2img-init.png"
-init_image = load_image(url)
-
-prompt = "Astronaut in a jungle, cold color palette, muted colors, detailed, 8k"
-
-# pass prompt and image to pipeline
-image_1 = pipeline(prompt, image=init_image, output_type="latent").images[0]
-```
-
-
-
-It is important to specify `output_type="latent"` in the pipeline to keep all the outputs in *latent* space to avoid an unnecessary decode-encode step. This only works if the chained pipelines are using the same VAE.
-
-
-
-Chain it to an upscaler pipeline to increase the image resolution:
-
-```py
-from diffusers import StableDiffusionLatentUpscalePipeline
-
-upscaler = StableDiffusionLatentUpscalePipeline.from_pretrained(
- "stabilityai/sd-x2-latent-upscaler", torch_dtype=torch.float16, variant="fp16", use_safetensors=True
-)
-upscaler.enable_model_cpu_offload()
-upscaler.enable_xformers_memory_efficient_attention()
-
-image_2 = upscaler(prompt, image=image_1, output_type="latent").images[0]
-```
-
-Finally, chain it to a super-resolution pipeline to further enhance the resolution:
-
-```py
-from diffusers import StableDiffusionUpscalePipeline
-
-super_res = StableDiffusionUpscalePipeline.from_pretrained(
- "stabilityai/stable-diffusion-x4-upscaler", torch_dtype=torch.float16, variant="fp16", use_safetensors=True
-)
-super_res.enable_model_cpu_offload()
-super_res.enable_xformers_memory_efficient_attention()
-
-image_3 = super_res(prompt, image=image_2).images[0]
-make_image_grid([init_image, image_3.resize((512, 512))], rows=1, cols=2)
-```
-
-## Control image generation
-
-Trying to generate an image that looks exactly the way you want can be difficult, which is why controlled generation techniques and models are so useful. While you can use the `negative_prompt` to partially control image generation, there are more robust methods like prompt weighting and ControlNets.
-
-### Prompt weighting
-
-Prompt weighting allows you to scale the representation of each concept in a prompt. For example, in a prompt like "Astronaut in a jungle, cold color palette, muted colors, detailed, 8k", you can choose to increase or decrease the embeddings of "astronaut" and "jungle". The [Compel](https://github.com/damian0815/compel) library provides a simple syntax for adjusting prompt weights and generating the embeddings. You can learn how to create the embeddings in the [Prompt weighting](weighted_prompts) guide.
-
-[`AutoPipelineForImage2Image`] has a `prompt_embeds` (and `negative_prompt_embeds` if you're using a negative prompt) parameter where you can pass the embeddings which replaces the `prompt` parameter.
-
-```py
-from diffusers import AutoPipelineForImage2Image
-import torch
-
-pipeline = AutoPipelineForImage2Image.from_pretrained(
- "stable-diffusion-v1-5/stable-diffusion-v1-5", torch_dtype=torch.float16, variant="fp16", use_safetensors=True
-)
-pipeline.enable_model_cpu_offload()
-# remove following line if xFormers is not installed or you have PyTorch 2.0 or higher installed
-pipeline.enable_xformers_memory_efficient_attention()
-
-image = pipeline(prompt_embeds=prompt_embeds, # generated from Compel
- negative_prompt_embeds=negative_prompt_embeds, # generated from Compel
- image=init_image,
-).images[0]
-```
-
-### ControlNet
-
-ControlNets provide a more flexible and accurate way to control image generation because you can use an additional conditioning image. The conditioning image can be a canny image, depth map, image segmentation, and even scribbles! Whatever type of conditioning image you choose, the ControlNet generates an image that preserves the information in it.
-
-For example, let's condition an image with a depth map to keep the spatial information in the image.
```py
-from diffusers.utils import load_image, make_image_grid
-
-# prepare image
-url = "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/img2img-init.png"
-init_image = load_image(url)
-init_image = init_image.resize((958, 960)) # resize to depth image dimensions
-depth_image = load_image("https://huggingface.co/lllyasviel/control_v11f1p_sd15_depth/resolve/main/images/control.png")
-make_image_grid([init_image, depth_image], rows=1, cols=2)
-```
-
-Load a ControlNet model conditioned on depth maps and the [`AutoPipelineForImage2Image`]:
-
-```py
-from diffusers import ControlNetModel, AutoPipelineForImage2Image
import torch
+from diffusers import OmniGenPipeline
+from diffusers.utils import load_image
-controlnet = ControlNetModel.from_pretrained("lllyasviel/control_v11f1p_sd15_depth", torch_dtype=torch.float16, variant="fp16", use_safetensors=True)
-pipeline = AutoPipelineForImage2Image.from_pretrained(
- "stable-diffusion-v1-5/stable-diffusion-v1-5", controlnet=controlnet, torch_dtype=torch.float16, variant="fp16", use_safetensors=True
+pipe = OmniGenPipeline.from_pretrained(
+ "Shitao/OmniGen-v1-diffusers",
+ torch_dtype=torch.bfloat16
)
-pipeline.enable_model_cpu_offload()
-# remove following line if xFormers is not installed or you have PyTorch 2.0 or higher installed
-pipeline.enable_xformers_memory_efficient_attention()
-```
+pipe.to("cuda")
-Now generate a new image conditioned on the depth map, initial image, and prompt:
-```py
-prompt = "Astronaut in a jungle, cold color palette, muted colors, detailed, 8k"
-image_control_net = pipeline(prompt, image=init_image, control_image=depth_image).images[0]
-make_image_grid([init_image, depth_image, image_control_net], rows=1, cols=3)
+prompt="A woman is walking down the street, wearing a white long-sleeve blouse with lace details on the sleeves, paired with a blue pleated skirt. The woman is <|image_1|>. The long-sleeve blouse and a pleated skirt are <|image_2|>."
+input_image_1 = load_image("/share/junjie/code/VISTA2/produce_data/laion_net/diffgpt/OmniGen/docs_img/emma.jpeg")
+input_image_2 = load_image("/share/junjie/code/VISTA2/produce_data/laion_net/diffgpt/OmniGen/docs_img/dress.jpg")
+input_images=[input_image_1, input_image_2]
+image = pipe(
+ prompt=prompt,
+ input_images=input_images,
+ height=1024,
+ width=1024,
+ guidance_scale=2.5,
+ img_guidance_scale=1.6,
+ generator=torch.Generator(device="cpu").manual_seed(666)).images[0]
```
-
- initial image
+
+ person image
-
- depth image
+
+ clothe image
-
- ControlNet image
+
+ generated image
-Let's apply a new [style](https://huggingface.co/nitrosocke/elden-ring-diffusion) to the image generated from the ControlNet by chaining it with an image-to-image pipeline:
-```py
-pipeline = AutoPipelineForImage2Image.from_pretrained(
- "nitrosocke/elden-ring-diffusion", torch_dtype=torch.float16,
-)
-pipeline.enable_model_cpu_offload()
-# remove following line if xFormers is not installed or you have PyTorch 2.0 or higher installed
-pipeline.enable_xformers_memory_efficient_attention()
-
-prompt = "elden ring style astronaut in a jungle" # include the token "elden ring style" in the prompt
-negative_prompt = "ugly, deformed, disfigured, poor details, bad anatomy"
-
-image_elden_ring = pipeline(prompt, negative_prompt=negative_prompt, image=image_control_net, strength=0.45, guidance_scale=10.5).images[0]
-make_image_grid([init_image, depth_image, image_control_net, image_elden_ring], rows=2, cols=2)
-```
-
-
-
-
-
-## Optimize
-
-Running diffusion models is computationally expensive and intensive, but with a few optimization tricks, it is entirely possible to run them on consumer and free-tier GPUs. For example, you can use a more memory-efficient form of attention such as PyTorch 2.0's [scaled-dot product attention](../optimization/torch2.0#scaled-dot-product-attention) or [xFormers](../optimization/xformers) (you can use one or the other, but there's no need to use both). You can also offload the model to the GPU while the other pipeline components wait on the CPU.
-
-```diff
-+ pipeline.enable_model_cpu_offload()
-+ pipeline.enable_xformers_memory_efficient_attention()
-```
-
-With [`torch.compile`](../optimization/torch2.0#torchcompile), you can boost your inference speed even more by wrapping your UNet with it:
+The output image is a [`PIL.Image`](https://pillow.readthedocs.io/en/stable/reference/Image.html?highlight=image#the-image-class) object that can be saved:
```py
-pipeline.unet = torch.compile(pipeline.unet, mode="reduce-overhead", fullgraph=True)
-```
-
-To learn more, take a look at the [Reduce memory usage](../optimization/memory) and [Torch 2.0](../optimization/torch2.0) guides.
+image.save("generated_image.png")
+```
\ No newline at end of file
diff --git a/docs/source/en/using-diffusers/omnigen.md b/docs/source/en/using-diffusers/omnigen.md
index 22a14bd95efe..c535f4358155 100644
--- a/docs/source/en/using-diffusers/omnigen.md
+++ b/docs/source/en/using-diffusers/omnigen.md
@@ -320,9 +320,10 @@ Here are some guidelines to help you reduce computational costs when input multi
- With enabling cpu offloading, memory usage is `28 GB`
- `offload_transformer_block=True`:
- - 17G
+ - offload transformer block to reduce memory usage
+ - When enabled, memory usage is under `25 GB`
- `pipe.enable_sequential_cpu_offload()`:
- - 11G
-
+ - significantly reduce memory usage at the cost of slow inference
+ - When enabled, memory usage is under `11 GB`
diff --git a/src/diffusers/models/transformers/transformer_omnigen.py b/src/diffusers/models/transformers/transformer_omnigen.py
index 882065faae3d..3a1072a0d349 100644
--- a/src/diffusers/models/transformers/transformer_omnigen.py
+++ b/src/diffusers/models/transformers/transformer_omnigen.py
@@ -12,13 +12,12 @@
# See the License for the specific language governing permissions and
# limitations under the License.
+from dataclasses import dataclass
from typing import Any, Dict, Optional, Tuple, Union, List
-from dataclasses import dataclass
import torch
import torch.utils.checkpoint
from torch import nn
-from transformers.cache_utils import DynamicCache
from transformers import Phi3Model, Phi3Config
from transformers.cache_utils import Cache, DynamicCache
from transformers.modeling_outputs import BaseModelOutputWithPast
@@ -36,7 +35,6 @@
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
-
@dataclass
class OmniGen2DModelOutput(Transformer2DModelOutput):
"""
@@ -74,7 +72,7 @@ def evict_previous_layer(self, layer_idx: int):
prev_layer_idx = layer_idx - 1
for name, param in self.layers[prev_layer_idx].named_parameters():
param.data = param.data.to("cpu", non_blocking=True)
-
+
def get_offload_layer(self, layer_idx: int, device: torch.device):
# init stream
if not hasattr(self, "prefetch_stream"):
@@ -83,7 +81,7 @@ def get_offload_layer(self, layer_idx: int, device: torch.device):
# delete previous layer
torch.cuda.current_stream().synchronize()
self.evict_previous_layer(layer_idx)
-
+
# make sure the current layer is ready
torch.cuda.synchronize(self.prefetch_stream)
@@ -273,7 +271,6 @@ def unpatchify(self, x, h, w):
imgs = x.reshape(shape=(x.shape[0], c, h, w))
return imgs
-
@property
# Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.attn_processors
def attn_processors(self) -> Dict[str, AttentionProcessor]:
@@ -337,12 +334,12 @@ def fn_recursive_attn_processor(name: str, module: torch.nn.Module, processor):
def _set_gradient_checkpointing(self, module, value=False):
if hasattr(module, "gradient_checkpointing"):
module.gradient_checkpointing = value
-
- def get_multimodal_embeddings(self,
- input_ids: torch.Tensor,
- input_img_latents: List[torch.Tensor],
- input_image_sizes: Dict,
- ):
+
+ def get_multimodal_embeddings(self,
+ input_ids: torch.Tensor,
+ input_img_latents: List[torch.Tensor],
+ input_image_sizes: Dict,
+ ):
"""
get the multi-modal conditional embeddings
Args:
@@ -353,7 +350,7 @@ def get_multimodal_embeddings(self,
Returns: torch.Tensor
"""
- input_img_latents = [x.to(self.dtype) for x in input_img_latents]
+ input_img_latents = [x.to(self.dtype) for x in input_img_latents]
condition_tokens = None
if input_ids is not None:
condition_tokens = self.llm.embed_tokens(input_ids)
@@ -384,6 +381,41 @@ def forward(self,
attention_kwargs: Optional[Dict[str, Any]] = None,
return_dict: bool = True,
):
+ """
+ The [`OmniGenTransformer2DModel`] forward method.
+
+ Args:
+ hidden_states (`torch.FloatTensor` of shape `(batch size, channel, height, width)`):
+ Input `hidden_states`.
+ timestep (`torch.LongTensor`):
+ Used to indicate denoising step.
+ input_ids (`torch.LongTensor`):
+ token ids
+ input_img_latents (`torch.FloatTensor`):
+ encoded image latents by VAE
+ input_image_sizes (`dict`):
+ the indices of the input_img_latents in the input_ids
+ attention_mask (`torch.FloatTensor`):
+ mask for self-attention
+ position_ids (`torch.LongTensor`):
+ id to represent position
+ past_key_values (`transformers.cache_utils.Cache`):
+ previous key and value states
+ offload_transformer_block (`bool`, *optional*, defaults to `True`):
+ offload transformer block to cpu
+ attention_kwargs: (`dict`, *optional*):
+ A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
+ `self.processor` in
+ [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
+ return_dict (`bool`, *optional*, defaults to `True`):
+ Whether or not to return a [`OmniGen2DModelOutput`] instead of a plain
+ tuple.
+
+ Returns:
+ If `return_dict` is True, an [`OmniGen2DModelOutput`] is returned, otherwise a
+ `tuple` where the first element is the sample tensor.
+
+ """
if attention_kwargs is not None:
attention_kwargs = attention_kwargs.copy()
diff --git a/src/diffusers/pipelines/omnigen/kvcache_omnigen.py b/src/diffusers/pipelines/omnigen/kvcache_omnigen.py
index 7f02588ce405..4bf32ae6ae74 100644
--- a/src/diffusers/pipelines/omnigen/kvcache_omnigen.py
+++ b/src/diffusers/pipelines/omnigen/kvcache_omnigen.py
@@ -1,20 +1,32 @@
-from tqdm import tqdm
+# Copyright 2024 OmniGen team and The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
from typing import Optional, Dict, Any, Tuple, List
-import gc
import torch
-from transformers.cache_utils import Cache, DynamicCache, OffloadedCache
-
+from transformers.cache_utils import DynamicCache
class OmniGenCache(DynamicCache):
- def __init__(self,
- num_tokens_for_img: int,
- offload_kv_cache: bool=False) -> None:
+ def __init__(self,
+ num_tokens_for_img: int,
+ offload_kv_cache: bool = False) -> None:
if not torch.cuda.is_available():
# print("No avaliable GPU, offload_kv_cache wiil be set to False, which will result in large memory usage and time cost when input multiple images!!!")
# offload_kv_cache = False
- raise RuntimeError("OffloadedCache can only be used with a GPU. If there is no GPU, you need to set use_kv_cache=False, which will result in longer inference time!")
+ raise RuntimeError(
+ "OffloadedCache can only be used with a GPU. If there is no GPU, you need to set use_kv_cache=False, which will result in longer inference time!")
super().__init__()
self.original_device = []
self.prefetch_stream = torch.cuda.Stream()
@@ -30,19 +42,17 @@ def prefetch_layer(self, layer_idx: int):
self.key_cache[layer_idx] = self.key_cache[layer_idx].to(device, non_blocking=True)
self.value_cache[layer_idx] = self.value_cache[layer_idx].to(device, non_blocking=True)
-
def evict_previous_layer(self, layer_idx: int):
"Moves the previous layer cache to the CPU"
if len(self) > 2:
# We do it on the default stream so it occurs after all earlier computations on these tensors are done
- if layer_idx == 0:
+ if layer_idx == 0:
prev_layer_idx = -1
else:
prev_layer_idx = (layer_idx - 1) % len(self)
self.key_cache[prev_layer_idx] = self.key_cache[prev_layer_idx].to("cpu", non_blocking=True)
self.value_cache[prev_layer_idx] = self.value_cache[prev_layer_idx].to("cpu", non_blocking=True)
-
def __getitem__(self, layer_idx: int) -> List[Tuple[torch.Tensor]]:
"Gets the cache for this layer to the device. Prefetches the next and evicts the previous layer."
if layer_idx < len(self):
@@ -56,7 +66,7 @@ def __getitem__(self, layer_idx: int) -> List[Tuple[torch.Tensor]]:
torch.cuda.synchronize(self.prefetch_stream)
key_tensor = self.key_cache[layer_idx]
value_tensor = self.value_cache[layer_idx]
-
+
# Prefetch the next layer
self.prefetch_layer((layer_idx + 1) % len(self))
else:
@@ -65,13 +75,13 @@ def __getitem__(self, layer_idx: int) -> List[Tuple[torch.Tensor]]:
return (key_tensor, value_tensor)
else:
raise KeyError(f"Cache only has {len(self)} layers, attempted to access layer with index {layer_idx}")
-
+
def update(
- self,
- key_states: torch.Tensor,
- value_states: torch.Tensor,
- layer_idx: int,
- cache_kwargs: Optional[Dict[str, Any]] = None,
+ self,
+ key_states: torch.Tensor,
+ value_states: torch.Tensor,
+ layer_idx: int,
+ cache_kwargs: Optional[Dict[str, Any]] = None,
) -> Tuple[torch.Tensor, torch.Tensor]:
"""
Updates the cache with the new `key_states` and `value_states` for the layer `layer_idx`.
@@ -92,13 +102,13 @@ def update(
raise ValueError("OffloadedCache does not support model usage where layers are skipped. Use DynamicCache.")
elif len(self.key_cache) == layer_idx:
# only cache the states for condition tokens
- key_states = key_states[..., :-(self.num_tokens_for_img+1), :]
- value_states = value_states[..., :-(self.num_tokens_for_img+1), :]
+ key_states = key_states[..., :-(self.num_tokens_for_img + 1), :]
+ value_states = value_states[..., :-(self.num_tokens_for_img + 1), :]
- # Update the number of seen tokens
+ # Update the number of seen tokens
if layer_idx == 0:
self._seen_tokens += key_states.shape[-2]
-
+
self.key_cache.append(key_states)
self.value_cache.append(value_states)
self.original_device.append(key_states.device)
diff --git a/src/diffusers/pipelines/omnigen/pipeline_omnigen.py b/src/diffusers/pipelines/omnigen/pipeline_omnigen.py
index 902453aae0f6..603ef39476e8 100644
--- a/src/diffusers/pipelines/omnigen/pipeline_omnigen.py
+++ b/src/diffusers/pipelines/omnigen/pipeline_omnigen.py
@@ -32,17 +32,15 @@
)
from ...utils.torch_utils import randn_tensor
from ..pipeline_utils import DiffusionPipeline, ImagePipelineOutput
-from .processor_omnigen import OmniGenMultiModalProcessor
from .kvcache_omnigen import OmniGenCache
+from .processor_omnigen import OmniGenMultiModalProcessor
if is_torch_xla_available():
- import torch_xla.core.xla_model as xm
XLA_AVAILABLE = True
else:
XLA_AVAILABLE = False
-
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
EXAMPLE_DOC_STRING = """
@@ -62,15 +60,14 @@
"""
-
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps
def retrieve_timesteps(
- scheduler,
- num_inference_steps: Optional[int] = None,
- device: Optional[Union[str, torch.device]] = None,
- timesteps: Optional[List[int]] = None,
- sigmas: Optional[List[float]] = None,
- **kwargs,
+ scheduler,
+ num_inference_steps: Optional[int] = None,
+ device: Optional[Union[str, torch.device]] = None,
+ timesteps: Optional[List[int]] = None,
+ sigmas: Optional[List[float]] = None,
+ **kwargs,
):
r"""
Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles
@@ -150,11 +147,11 @@ class OmniGenPipeline(
_callback_tensor_inputs = ["latents", "input_images_latents"]
def __init__(
- self,
- transformer: OmniGenTransformer2DModel,
- scheduler: FlowMatchEulerDiscreteScheduler,
- vae: AutoencoderKL,
- tokenizer: LlamaTokenizer,
+ self,
+ transformer: OmniGenTransformer2DModel,
+ scheduler: FlowMatchEulerDiscreteScheduler,
+ vae: AutoencoderKL,
+ tokenizer: LlamaTokenizer,
):
super().__init__()
@@ -199,17 +196,16 @@ def encod_input_iamges(
input_img_latents.append(img)
return input_img_latents
-
def check_inputs(
- self,
- prompt,
- input_images,
- height,
- width,
- use_kv_cache,
- offload_kv_cache,
- callback_on_step_end_tensor_inputs=None,
- max_sequence_length=None,
+ self,
+ prompt,
+ input_images,
+ height,
+ width,
+ use_kv_cache,
+ offload_kv_cache,
+ callback_on_step_end_tensor_inputs=None,
+ max_sequence_length=None,
):
if input_images is not None:
@@ -219,7 +215,7 @@ def check_inputs(
)
for i in range(len(input_images)):
if input_images[i] is not None:
- if not all(f"<|image_{k+1}|>" in prompt[i] for k in range(len(input_images[i]))):
+ if not all(f"<|image_{k + 1}|>" in prompt[i] for k in range(len(input_images[i]))):
raise ValueError(
f"prompt `{prompt[i]}` doesn't have enough placeholders for the input images `{input_images[i]}`"
)
@@ -228,15 +224,15 @@ def check_inputs(
logger.warning(
f"`height` and `width` have to be divisible by {self.vae_scale_factor * 2} but are {height} and {width}. Dimensions will be resized accordingly"
)
-
+
if use_kv_cache and offload_kv_cache:
if not torch.cuda.is_available():
raise ValueError(
- f"Don't fine avaliable GPUs. `offload_kv_cache` can't be used when there is no GPU. please set it to False: `use_kv_cache=False, offload_kv_cache=False`"
- )
+ f"Don't fine avaliable GPUs. `offload_kv_cache` can't be used when there is no GPU. please set it to False: `use_kv_cache=False, offload_kv_cache=False`"
+ )
if callback_on_step_end_tensor_inputs is not None and not all(
- k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs
+ k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs
):
raise ValueError(
f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}"
@@ -286,15 +282,15 @@ def disable_vae_tiling(self):
self.vae.disable_tiling()
def prepare_latents(
- self,
- batch_size,
- num_channels_latents,
- height,
- width,
- dtype,
- device,
- generator,
- latents=None,
+ self,
+ batch_size,
+ num_channels_latents,
+ height,
+ width,
+ dtype,
+ device,
+ generator,
+ latents=None,
):
if latents is not None:
return latents.to(device=device, dtype=dtype)
@@ -327,7 +323,7 @@ def num_timesteps(self):
@property
def interrupt(self):
return self._interrupt
-
+
def enable_transformer_block_cpu_offload(self, device: Union[torch.device, str] = "cuda"):
torch_device = torch.device(device)
for name, param in self.transformer.named_parameters():
@@ -343,29 +339,30 @@ def enable_transformer_block_cpu_offload(self, device: Union[torch.device, str]
@torch.no_grad()
@replace_example_docstring(EXAMPLE_DOC_STRING)
def __call__(
- self,
- prompt: Union[str, List[str]],
- input_images: Optional[Union[List[str], List[PIL.Image.Image], List[List[str]], List[List[PIL.Image.Image]]]] = None,
- height: Optional[int] = None,
- width: Optional[int] = None,
- num_inference_steps: int = 50,
- max_input_image_size: int = 1024,
- timesteps: List[int] = None,
- guidance_scale: float = 2.5,
- img_guidance_scale: float = 1.6,
- use_kv_cache: bool = True,
- offload_kv_cache: bool = True,
- offload_transformer_block: bool = False,
- use_input_image_size_as_output: bool = False,
- num_images_per_prompt: Optional[int] = 1,
- generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
- latents: Optional[torch.FloatTensor] = None,
- output_type: Optional[str] = "pil",
- return_dict: bool = True,
- attention_kwargs: Optional[Dict[str, Any]] = None,
- callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None,
- callback_on_step_end_tensor_inputs: List[str] = ["latents"],
- max_sequence_length: int = 120000,
+ self,
+ prompt: Union[str, List[str]],
+ input_images: Optional[
+ Union[List[str], List[PIL.Image.Image], List[List[str]], List[List[PIL.Image.Image]]]] = None,
+ height: Optional[int] = None,
+ width: Optional[int] = None,
+ num_inference_steps: int = 50,
+ max_input_image_size: int = 1024,
+ timesteps: List[int] = None,
+ guidance_scale: float = 2.5,
+ img_guidance_scale: float = 1.6,
+ use_kv_cache: bool = True,
+ offload_kv_cache: bool = True,
+ offload_transformer_block: bool = False,
+ use_input_image_size_as_output: bool = False,
+ num_images_per_prompt: Optional[int] = 1,
+ generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
+ latents: Optional[torch.FloatTensor] = None,
+ output_type: Optional[str] = "pil",
+ return_dict: bool = True,
+ attention_kwargs: Optional[Dict[str, Any]] = None,
+ callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None,
+ callback_on_step_end_tensor_inputs: List[str] = ["latents"],
+ max_sequence_length: int = 120000,
):
r"""
Function invoked when calling the pipeline for generation.
@@ -444,11 +441,11 @@ def __call__(
height = height or self.default_sample_size * self.vae_scale_factor
width = width or self.default_sample_size * self.vae_scale_factor
num_cfg = 2 if input_images is not None else 1
- use_img_cfg = True if input_images is not None else False
+ use_img_cfg = True if input_images is not None else False
if isinstance(prompt, str):
prompt = [prompt]
input_images = [input_images]
-
+
# using Float32 for the VAE doesn't take up much memory but can prevent potential black image outputs.
self.vae.to(torch.float32)
@@ -475,16 +472,15 @@ def __call__(
batch_size = len(prompt)
device = self._execution_device
-
# 3. process multi-modal instructions
if max_input_image_size != self.multimodal_processor.max_image_size:
self.multimodal_processor = OmniGenMultiModalProcessor(self.tokenizer, max_image_size=max_input_image_size)
processed_data = self.multimodal_processor(prompt,
- input_images,
- height=height,
- width=width,
- use_img_cfg=use_img_cfg,
- use_input_image_size_as_output=use_input_image_size_as_output)
+ input_images,
+ height=height,
+ width=width,
+ use_img_cfg=use_img_cfg,
+ use_input_image_size_as_output=use_input_image_size_as_output)
processed_data['input_ids'] = processed_data['input_ids'].to(device)
processed_data['attention_mask'] = processed_data['attention_mask'].to(device)
processed_data['position_ids'] = processed_data['position_ids'].to(device)
@@ -493,7 +489,7 @@ def __call__(
input_img_latents = self.encod_input_iamges(processed_data['input_pixel_values'], device=device)
# 5. Prepare timesteps
- sigmas = np.linspace(1, 0, num_inference_steps+1)[:num_inference_steps]
+ sigmas = np.linspace(1, 0, num_inference_steps + 1)[:num_inference_steps]
timesteps, num_inference_steps = retrieve_timesteps(
self.scheduler, num_inference_steps, device, timesteps, sigmas=sigmas
)
@@ -522,7 +518,7 @@ def __call__(
with self.progress_bar(total=num_inference_steps) as progress_bar:
for i, t in enumerate(timesteps):
# expand the latents if we are doing classifier free guidance
- latent_model_input = torch.cat([latents] * (num_cfg+1))
+ latent_model_input = torch.cat([latents] * (num_cfg + 1))
# broadcast to batch dimension in a way that's compatible with ONNX/Core ML
timestep = t.expand(latent_model_input.shape[0])
@@ -537,16 +533,20 @@ def __call__(
position_ids=processed_data['position_ids'],
attention_kwargs=attention_kwargs,
past_key_values=cache,
- offload_transformer_block=self.offload_transformer_block if hasattr(self, 'offload_transformer_block') else offload_transformer_block,
+ offload_transformer_block=self.offload_transformer_block if hasattr(self,
+ 'offload_transformer_block') else offload_transformer_block,
return_dict=False,
)
-
+
# if use kv cache, don't need attention mask and position ids of condition tokens for next step
if use_kv_cache:
if processed_data['input_ids'] is not None:
processed_data['input_ids'] = None
- processed_data['attention_mask'] = processed_data['attention_mask'][..., -(num_tokens_for_output_img + 1):, :] # +1 is for the timestep token
- processed_data['position_ids'] = processed_data['position_ids'][:, -(num_tokens_for_output_img + 1):]
+ processed_data['attention_mask'] = processed_data['attention_mask'][...,
+ -(num_tokens_for_output_img + 1):,
+ :] # +1 is for the timestep token
+ processed_data['position_ids'] = processed_data['position_ids'][:,
+ -(num_tokens_for_output_img + 1):]
if num_cfg == 2:
cond, uncond, img_cond = torch.split(noise_pred, len(noise_pred) // 3, dim=0)
@@ -581,4 +581,3 @@ def __call__(
return (image,)
return ImagePipelineOutput(images=image)
-
diff --git a/src/diffusers/pipelines/omnigen/processor_omnigen.py b/src/diffusers/pipelines/omnigen/processor_omnigen.py
index 545c9b001c7d..af52b20b55db 100644
--- a/src/diffusers/pipelines/omnigen/processor_omnigen.py
+++ b/src/diffusers/pipelines/omnigen/processor_omnigen.py
@@ -1,3 +1,17 @@
+# Copyright 2024 OmniGen team and The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
import re
from typing import Dict, List
@@ -7,7 +21,6 @@
from torchvision import transforms
-
def crop_image(pil_image, max_image_size):
"""
Crop the image so that its height and width does not exceed `max_image_size`,
@@ -291,4 +304,3 @@ def __call__(self, features):
"input_image_sizes": all_image_sizes,
}
return data
-
From 4fef9c8e2f841454739e66764ce6cf2bd041b19d Mon Sep 17 00:00:00 2001
From: shitao <2906698981@qq.com>
Date: Sun, 8 Dec 2024 16:28:12 +0800
Subject: [PATCH 13/55] reformat
---
scripts/convert_omnigen_to_diffusers.py | 214 +++++++++---------
src/diffusers/models/embeddings.py | 4 +-
.../pipelines/omnigen/pipeline_omnigen.py | 6 +-
.../scheduling_flow_match_euler_discrete.py | 3 +-
test.py | 59 -----
5 files changed, 113 insertions(+), 173 deletions(-)
delete mode 100644 test.py
diff --git a/scripts/convert_omnigen_to_diffusers.py b/scripts/convert_omnigen_to_diffusers.py
index 8c9bc8bdb457..4e8b284d2fb7 100644
--- a/scripts/convert_omnigen_to_diffusers.py
+++ b/scripts/convert_omnigen_to_diffusers.py
@@ -1,11 +1,10 @@
import argparse
import os
-os.environ['HF_HUB_CACHE'] = '/share/shitao/downloaded_models2'
import torch
-from safetensors.torch import load_file
-from transformers import AutoModel, AutoTokenizer, AutoConfig
from huggingface_hub import snapshot_download
+from safetensors.torch import load_file
+from transformers import AutoTokenizer
from diffusers import AutoencoderKL, FlowMatchEulerDiscreteScheduler, OmniGenTransformer2DModel, OmniGenPipeline
@@ -17,9 +16,9 @@ def main(args):
print("Model not found, downloading...")
cache_folder = os.getenv('HF_HUB_CACHE')
args.origin_ckpt_path = snapshot_download(repo_id=args.origin_ckpt_path,
- cache_dir=cache_folder,
- ignore_patterns=['flax_model.msgpack', 'rust_model.ot', 'tf_model.h5',
- 'model.pt'])
+ cache_dir=cache_folder,
+ ignore_patterns=['flax_model.msgpack', 'rust_model.ot', 'tf_model.h5',
+ 'model.pt'])
print(f"Downloaded model to {args.origin_ckpt_path}")
ckpt = os.path.join(args.origin_ckpt_path, 'model.safetensors')
@@ -48,7 +47,7 @@ def main(args):
# transformer_config = AutoConfig.from_pretrained(args.origin_ckpt_path)
# print(type(transformer_config.__dict__))
# print(transformer_config.__dict__)
-
+
transformer_config = {
"_name_or_path": "Phi-3-vision-128k-instruct",
"architectures": [
@@ -70,104 +69,104 @@ def main(args):
"rms_norm_eps": 1e-05,
"rope_scaling": {
"long_factor": [
- 1.0299999713897705,
- 1.0499999523162842,
- 1.0499999523162842,
- 1.0799999237060547,
- 1.2299998998641968,
- 1.2299998998641968,
- 1.2999999523162842,
- 1.4499999284744263,
- 1.5999999046325684,
- 1.6499998569488525,
- 1.8999998569488525,
- 2.859999895095825,
- 3.68999981880188,
- 5.419999599456787,
- 5.489999771118164,
- 5.489999771118164,
- 9.09000015258789,
- 11.579999923706055,
- 15.65999984741211,
- 15.769999504089355,
- 15.789999961853027,
- 18.360000610351562,
- 21.989999771118164,
- 23.079999923706055,
- 30.009998321533203,
- 32.35000228881836,
- 32.590003967285156,
- 35.56000518798828,
- 39.95000457763672,
- 53.840003967285156,
- 56.20000457763672,
- 57.95000457763672,
- 59.29000473022461,
- 59.77000427246094,
- 59.920005798339844,
- 61.190006256103516,
- 61.96000671386719,
- 62.50000762939453,
- 63.3700065612793,
- 63.48000717163086,
- 63.48000717163086,
- 63.66000747680664,
- 63.850006103515625,
- 64.08000946044922,
- 64.760009765625,
- 64.80001068115234,
- 64.81001281738281,
- 64.81001281738281
+ 1.0299999713897705,
+ 1.0499999523162842,
+ 1.0499999523162842,
+ 1.0799999237060547,
+ 1.2299998998641968,
+ 1.2299998998641968,
+ 1.2999999523162842,
+ 1.4499999284744263,
+ 1.5999999046325684,
+ 1.6499998569488525,
+ 1.8999998569488525,
+ 2.859999895095825,
+ 3.68999981880188,
+ 5.419999599456787,
+ 5.489999771118164,
+ 5.489999771118164,
+ 9.09000015258789,
+ 11.579999923706055,
+ 15.65999984741211,
+ 15.769999504089355,
+ 15.789999961853027,
+ 18.360000610351562,
+ 21.989999771118164,
+ 23.079999923706055,
+ 30.009998321533203,
+ 32.35000228881836,
+ 32.590003967285156,
+ 35.56000518798828,
+ 39.95000457763672,
+ 53.840003967285156,
+ 56.20000457763672,
+ 57.95000457763672,
+ 59.29000473022461,
+ 59.77000427246094,
+ 59.920005798339844,
+ 61.190006256103516,
+ 61.96000671386719,
+ 62.50000762939453,
+ 63.3700065612793,
+ 63.48000717163086,
+ 63.48000717163086,
+ 63.66000747680664,
+ 63.850006103515625,
+ 64.08000946044922,
+ 64.760009765625,
+ 64.80001068115234,
+ 64.81001281738281,
+ 64.81001281738281
],
"short_factor": [
- 1.05,
- 1.05,
- 1.05,
- 1.1,
- 1.1,
- 1.1,
- 1.2500000000000002,
- 1.2500000000000002,
- 1.4000000000000004,
- 1.4500000000000004,
- 1.5500000000000005,
- 1.8500000000000008,
- 1.9000000000000008,
- 2.000000000000001,
- 2.000000000000001,
- 2.000000000000001,
- 2.000000000000001,
- 2.000000000000001,
- 2.000000000000001,
- 2.000000000000001,
- 2.000000000000001,
- 2.000000000000001,
- 2.000000000000001,
- 2.000000000000001,
- 2.000000000000001,
- 2.000000000000001,
- 2.000000000000001,
- 2.000000000000001,
- 2.000000000000001,
- 2.000000000000001,
- 2.000000000000001,
- 2.000000000000001,
- 2.1000000000000005,
- 2.1000000000000005,
- 2.2,
- 2.3499999999999996,
- 2.3499999999999996,
- 2.3499999999999996,
- 2.3499999999999996,
- 2.3999999999999995,
- 2.3999999999999995,
- 2.6499999999999986,
- 2.6999999999999984,
- 2.8999999999999977,
- 2.9499999999999975,
- 3.049999999999997,
- 3.049999999999997,
- 3.049999999999997
+ 1.05,
+ 1.05,
+ 1.05,
+ 1.1,
+ 1.1,
+ 1.1,
+ 1.2500000000000002,
+ 1.2500000000000002,
+ 1.4000000000000004,
+ 1.4500000000000004,
+ 1.5500000000000005,
+ 1.8500000000000008,
+ 1.9000000000000008,
+ 2.000000000000001,
+ 2.000000000000001,
+ 2.000000000000001,
+ 2.000000000000001,
+ 2.000000000000001,
+ 2.000000000000001,
+ 2.000000000000001,
+ 2.000000000000001,
+ 2.000000000000001,
+ 2.000000000000001,
+ 2.000000000000001,
+ 2.000000000000001,
+ 2.000000000000001,
+ 2.000000000000001,
+ 2.000000000000001,
+ 2.000000000000001,
+ 2.000000000000001,
+ 2.000000000000001,
+ 2.000000000000001,
+ 2.1000000000000005,
+ 2.1000000000000005,
+ 2.2,
+ 2.3499999999999996,
+ 2.3499999999999996,
+ 2.3499999999999996,
+ 2.3499999999999996,
+ 2.3999999999999995,
+ 2.3999999999999995,
+ 2.6499999999999986,
+ 2.6999999999999984,
+ 2.8999999999999977,
+ 2.9499999999999975,
+ 3.049999999999997,
+ 3.049999999999997,
+ 3.049999999999997
],
"type": "su"
},
@@ -179,7 +178,7 @@ def main(args):
"use_cache": True,
"vocab_size": 32064,
"_attn_implementation": "sdpa"
- }
+ }
transformer = OmniGenTransformer2DModel(
transformer_config=transformer_config,
patch_size=2,
@@ -198,7 +197,6 @@ def main(args):
tokenizer = AutoTokenizer.from_pretrained(args.origin_ckpt_path)
-
pipeline = OmniGenPipeline(
tokenizer=tokenizer, transformer=transformer, vae=vae, scheduler=scheduler
)
@@ -209,10 +207,12 @@ def main(args):
parser = argparse.ArgumentParser()
parser.add_argument(
- "--origin_ckpt_path", default="Shitao/OmniGen-v1", type=str, required=False, help="Path to the checkpoint to convert."
+ "--origin_ckpt_path", default="Shitao/OmniGen-v1", type=str, required=False,
+ help="Path to the checkpoint to convert."
)
- parser.add_argument("--dump_path", default="/share/shitao/repos/OmniGen-v1-diffusers", type=str, required=False, help="Path to the output pipeline.")
+ parser.add_argument("--dump_path", default="OmniGen-v1-diffusers", type=str, required=False,
+ help="Path to the output pipeline.")
args = parser.parse_args()
main(args)
diff --git a/src/diffusers/models/embeddings.py b/src/diffusers/models/embeddings.py
index 720a48f3f747..412af10b834c 100644
--- a/src/diffusers/models/embeddings.py
+++ b/src/diffusers/models/embeddings.py
@@ -358,8 +358,8 @@ def forward(self,
):
"""
Args:
- latent:
- is_input_image:
+ latent: encoded image latents
+ is_input_image: use input_image_proj or output_image_proj
padding_latent: When sizes of target images are inconsistent, use `padding_latent` to maintain consistent sequence length.
Returns: torch.Tensor
diff --git a/src/diffusers/pipelines/omnigen/pipeline_omnigen.py b/src/diffusers/pipelines/omnigen/pipeline_omnigen.py
index 603ef39476e8..44fed3490843 100644
--- a/src/diffusers/pipelines/omnigen/pipeline_omnigen.py
+++ b/src/diffusers/pipelines/omnigen/pipeline_omnigen.py
@@ -49,13 +49,13 @@
>>> import torch
>>> from diffusers import OmniGenPipeline
- >>> pipe = OmniGenPipeline.from_pretrained("****", torch_dtype=torch.bfloat16)
+ >>> pipe = OmniGenPipeline.from_pretrained("Shitao/OmniGen-v1-diffusers", torch_dtype=torch.bfloat16)
>>> pipe.to("cuda")
>>> prompt = "A cat holding a sign that says hello world"
>>> # Depending on the variant being used, the pipeline call will slightly vary.
>>> # Refer to the pipeline documentation for more details.
- >>> image = pipe(prompt, num_inference_steps=4, guidance_scale=0.0).images[0]
- >>> image.save("flux.png")
+ >>> image = pipe(prompt, num_inference_steps=50, guidance_scale=2.5).images[0]
+ >>> image.save("t2i.png")
```
"""
diff --git a/src/diffusers/schedulers/scheduling_flow_match_euler_discrete.py b/src/diffusers/schedulers/scheduling_flow_match_euler_discrete.py
index 7e01160b8626..56fe71929d13 100644
--- a/src/diffusers/schedulers/scheduling_flow_match_euler_discrete.py
+++ b/src/diffusers/schedulers/scheduling_flow_match_euler_discrete.py
@@ -212,7 +212,7 @@ def set_timesteps(
else:
sigmas = torch.cat([sigmas, torch.zeros(1, device=sigmas.device)])
- self.timesteps = timesteps.to(device=device)
+ self.timesteps = timesteps.to(device=device)
self.sigmas = sigmas
self._step_index = None
self._begin_index = None
@@ -300,7 +300,6 @@ def step(
sigma = self.sigmas[self.step_index]
sigma_next = self.sigmas[self.step_index + 1]
-
prev_sample = sample + (sigma_next - sigma) * model_output
# Cast sample back to model compatible dtype
diff --git a/test.py b/test.py
deleted file mode 100644
index b27d99a1066b..000000000000
--- a/test.py
+++ /dev/null
@@ -1,59 +0,0 @@
-import os
-os.environ['HF_HUB_CACHE'] = '/share/shitao/downloaded_models2'
-
-# from huggingface_hub import snapshot_download
-
-# from diffusers.models import OmniGenTransformer2DModel
-# from transformers import Phi3Model, Phi3Config
-
-
-# from safetensors.torch import load_file
-
-# model_name = "Shitao/OmniGen-v1"
-# config = Phi3Config.from_pretrained("Shitao/OmniGen-v1")
-# model = OmniGenTransformer2DModel(transformer_config=config)
-# cache_folder = os.getenv('HF_HUB_CACHE')
-# model_name = snapshot_download(repo_id=model_name,
-# cache_dir=cache_folder,
-# ignore_patterns=['flax_model.msgpack', 'rust_model.ot', 'tf_model.h5'])
-# print(model_name)
-# model_path = os.path.join(model_name, 'model.safetensors')
-# ckpt = load_file(model_path, 'cpu')
-
-
-# mapping_dict = {
-# "pos_embed": "patch_embedding.pos_embed",
-# "x_embedder.proj.weight": "patch_embedding.output_image_proj.weight",
-# "x_embedder.proj.bias": "patch_embedding.output_image_proj.bias",
-# "input_x_embedder.proj.weight": "patch_embedding.input_image_proj.weight",
-# "input_x_embedder.proj.bias": "patch_embedding.input_image_proj.bias",
-# "final_layer.adaLN_modulation.1.weight": "norm_out.linear.weight",
-# "final_layer.adaLN_modulation.1.bias": "norm_out.linear.bias",
-# "final_layer.linear.weight": "proj_out.weight",
-# "final_layer.linear.bias": "proj_out.bias",
-
-# }
-
-# new_ckpt = {}
-# for k, v in ckpt.items():
-# # new_ckpt[k] = v
-# if k in mapping_dict:
-# new_ckpt[mapping_dict[k]] = v
-# else:
-# new_ckpt[k] = v
-
-
-
-# model.load_state_dict(new_ckpt)
-
-
-from tests.pipelines.omnigen.test_pipeline_omnigen import OmniGenPipelineFastTests, OmniGenPipelineSlowTests
-
-test1 = OmniGenPipelineFastTests()
-test1.test_inference()
-
-test2 = OmniGenPipelineSlowTests()
-test2.test_omnigen_inference()
-
-
-
From f2fc182c89dcefc21fe8e65ea731112203867313 Mon Sep 17 00:00:00 2001
From: shitao <2906698981@qq.com>
Date: Sun, 8 Dec 2024 16:29:38 +0800
Subject: [PATCH 14/55] reformat
---
src/diffusers/schedulers/scheduling_flow_match_euler_discrete.py | 1 +
1 file changed, 1 insertion(+)
diff --git a/src/diffusers/schedulers/scheduling_flow_match_euler_discrete.py b/src/diffusers/schedulers/scheduling_flow_match_euler_discrete.py
index 56fe71929d13..d4a970720f8e 100644
--- a/src/diffusers/schedulers/scheduling_flow_match_euler_discrete.py
+++ b/src/diffusers/schedulers/scheduling_flow_match_euler_discrete.py
@@ -300,6 +300,7 @@ def step(
sigma = self.sigmas[self.step_index]
sigma_next = self.sigmas[self.step_index + 1]
+
prev_sample = sample + (sigma_next - sigma) * model_output
# Cast sample back to model compatible dtype
From 5f3148d801e6d46ce05c42a965fd78ae1c60f350 Mon Sep 17 00:00:00 2001
From: shitao <2906698981@qq.com>
Date: Sun, 8 Dec 2024 16:31:52 +0800
Subject: [PATCH 15/55] reformat
---
src/diffusers/pipelines/lumina/pipeline_lumina.py | 2 +-
.../schedulers/scheduling_flow_match_euler_discrete.py | 2 +-
2 files changed, 2 insertions(+), 2 deletions(-)
diff --git a/src/diffusers/pipelines/lumina/pipeline_lumina.py b/src/diffusers/pipelines/lumina/pipeline_lumina.py
index 1db3d58808d2..018f2e8bf1bc 100644
--- a/src/diffusers/pipelines/lumina/pipeline_lumina.py
+++ b/src/diffusers/pipelines/lumina/pipeline_lumina.py
@@ -777,7 +777,7 @@ def __call__(
# 4. Prepare timesteps
timesteps, num_inference_steps = retrieve_timesteps(
- self.scheduler, num_inference_steps, device, timesteps,
+ self.scheduler, num_inference_steps, device, timesteps, sigmas
)
# 5. Prepare latents.
diff --git a/src/diffusers/schedulers/scheduling_flow_match_euler_discrete.py b/src/diffusers/schedulers/scheduling_flow_match_euler_discrete.py
index d4a970720f8e..c1096dbe0c29 100644
--- a/src/diffusers/schedulers/scheduling_flow_match_euler_discrete.py
+++ b/src/diffusers/schedulers/scheduling_flow_match_euler_discrete.py
@@ -300,7 +300,7 @@ def step(
sigma = self.sigmas[self.step_index]
sigma_next = self.sigmas[self.step_index + 1]
-
+
prev_sample = sample + (sigma_next - sigma) * model_output
# Cast sample back to model compatible dtype
From 178d377622654de02c55d11f6cda73949ca06f8d Mon Sep 17 00:00:00 2001
From: shitao <2906698981@qq.com>
Date: Sun, 8 Dec 2024 17:09:31 +0800
Subject: [PATCH 16/55] update docs
---
docs/source/en/using-diffusers/multimodal2img.md | 4 +++-
docs/source/en/using-diffusers/omnigen.md | 1 +
2 files changed, 4 insertions(+), 1 deletion(-)
diff --git a/docs/source/en/using-diffusers/multimodal2img.md b/docs/source/en/using-diffusers/multimodal2img.md
index 10567d49ae19..1aabb99d5879 100644
--- a/docs/source/en/using-diffusers/multimodal2img.md
+++ b/docs/source/en/using-diffusers/multimodal2img.md
@@ -16,7 +16,9 @@ specific language governing permissions and limitations under the License.
-Multimodal instructions mean you can input any sequence of mixed text and images to guide image generation. You can input multiple images and use prompts to describe the desired output. This approach is more flexible than using only text or images.
+Multimodal instructions mean you can input arbitrarily interleaved text and image inputs as conditions to guide image generation.
+You can input multiple images and use prompts to describe the desired output.
+This approach is more flexible than using only text or images.
## Examples
diff --git a/docs/source/en/using-diffusers/omnigen.md b/docs/source/en/using-diffusers/omnigen.md
index c535f4358155..833199641644 100644
--- a/docs/source/en/using-diffusers/omnigen.md
+++ b/docs/source/en/using-diffusers/omnigen.md
@@ -274,6 +274,7 @@ image = pipe(
guidance_scale=2.5,
img_guidance_scale=1.6,
generator=torch.Generator(device="cpu").manual_seed(666)).images[0]
+image
```