From 36eee4022dbbf54fbe2d8c8d42dec6d5952677a8 Mon Sep 17 00:00:00 2001 From: shitao <2906698981@qq.com> Date: Sat, 30 Nov 2024 18:01:05 +0800 Subject: [PATCH 01/55] OmniGen model.py --- src/diffusers/models/embeddings.py | 127 +++ src/diffusers/models/normalization.py | 2 +- .../transformers/transformer_omnigen.py | 345 ++++++++ src/diffusers/pipelines/omnigen/__init__.py | 67 ++ .../pipelines/omnigen/pipeline_omnigen.py | 774 ++++++++++++++++++ .../pipelines/omnigen/pipeline_output.py | 24 + 6 files changed, 1338 insertions(+), 1 deletion(-) create mode 100644 src/diffusers/models/transformers/transformer_omnigen.py create mode 100644 src/diffusers/pipelines/omnigen/__init__.py create mode 100644 src/diffusers/pipelines/omnigen/pipeline_omnigen.py create mode 100644 src/diffusers/pipelines/omnigen/pipeline_output.py diff --git a/src/diffusers/models/embeddings.py b/src/diffusers/models/embeddings.py index 80775d477c0d..1eedf55933e0 100644 --- a/src/diffusers/models/embeddings.py +++ b/src/diffusers/models/embeddings.py @@ -288,6 +288,91 @@ def forward(self, latent): return (latent + pos_embed).to(latent.dtype) + +class OmniGenPatchEmbed(nn.Module): + """2D Image to Patch Embedding with support for OmniGen.""" + + def __init__( + self, + patch_size: int =2, + in_channels: int =4, + embed_dim: int =768, + bias: bool =True, + interpolation_scale: float =1, + pos_embed_max_size: int =192, + base_size: int =64, + ): + super().__init__() + + self.output_image_proj = nn.Conv2d( + in_channels, embed_dim, kernel_size=(patch_size, patch_size), stride=patch_size, bias=bias + ) + self.input_image_proj = nn.Conv2d( + in_channels, embed_dim, kernel_size=(patch_size, patch_size), stride=patch_size, bias=bias + ) + + self.patch_size = patch_size + self.interpolation_scale = interpolation_scale + self.pos_embed_max_size = pos_embed_max_size + + pos_embed = get_2d_sincos_pos_embed( + embed_dim, self.pos_embed_max_size, base_size=base_size, interpolation_scale=self.interpolation_scale + ) + self.register_buffer("pos_embed", torch.from_numpy(pos_embed).float().unsqueeze(0), persistent=True) + + def cropped_pos_embed(self, height, width): + """Crops positional embeddings for SD3 compatibility.""" + if self.pos_embed_max_size is None: + raise ValueError("`pos_embed_max_size` must be set for cropping.") + + height = height // self.patch_size + width = width // self.patch_size + if height > self.pos_embed_max_size: + raise ValueError( + f"Height ({height}) cannot be greater than `pos_embed_max_size`: {self.pos_embed_max_size}." + ) + if width > self.pos_embed_max_size: + raise ValueError( + f"Width ({width}) cannot be greater than `pos_embed_max_size`: {self.pos_embed_max_size}." 
+ ) + + top = (self.pos_embed_max_size - height) // 2 + left = (self.pos_embed_max_size - width) // 2 + spatial_pos_embed = self.pos_embed.reshape(1, self.pos_embed_max_size, self.pos_embed_max_size, -1) + spatial_pos_embed = spatial_pos_embed[:, top : top + height, left : left + width, :] + spatial_pos_embed = spatial_pos_embed.reshape(1, -1, spatial_pos_embed.shape[-1]) + return spatial_pos_embed + + def patch_embeddings(self, latent, is_input_image: bool): + if is_input_image: + latent = self.input_image_proj(latent) + else: + latent = self.output_image_proj(latent) + latent = latent.flatten(2).transpose(1, 2) + return latent + + def forward(self, latent, is_input_image: bool, padding_latent=None): + if isinstance(latent, list): + if padding_latent is None: + padding_latent = [None] * len(latent) + patched_latents, num_tokens, shapes = [], [], [] + for sub_latent, padding in zip(latent, padding_latent): + height, width = sub_latent.shape[-2:] + sub_latent = self.patch_embeddings(sub_latent, is_input_image) + pos_embed = self.cropped_pos_embed(height, width) + sub_latent = sub_latent + pos_embed + if padding is not None: + sub_latent = torch.cat([sub_latent, padding], dim=-2) + patched_latents.append(sub_latent) + else: + height, width = latent.shape[-2:] + pos_embed = self.cropped_pos_embed(height, width) + latent = self.patch_embeddings(latent, is_input_image) + latent = latent + pos_embed + + return latent + + class LuminaPatchEmbed(nn.Module): """2D Image to Patch Embedding with support for Lumina-T2X""" @@ -935,6 +1020,48 @@ def forward(self, timesteps): return t_emb +class OmniGenTimestepEmbed(nn.Module): + """ + Embeds scalar timesteps into vector representations for OmniGen + """ + + def __init__(self, hidden_size, frequency_embedding_size=256): + super().__init__() + self.mlp = nn.Sequential( + nn.Linear(frequency_embedding_size, hidden_size, bias=True), + nn.SiLU(), + nn.Linear(hidden_size, hidden_size, bias=True), + ) + self.frequency_embedding_size = frequency_embedding_size + + @staticmethod + def timestep_embedding(t, dim, max_period=10000): + """ + Create sinusoidal timestep embeddings. + :param t: a 1-D Tensor of N indices, one per batch element. + These may be fractional. + :param dim: the dimension of the output. + :param max_period: controls the minimum frequency of the embeddings. + :return: an (N, D) Tensor of positional embeddings. + """ + # https://github.com/openai/glide-text2im/blob/main/glide_text2im/nn.py + half = dim // 2 + freqs = torch.exp( + -math.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32) / half + ).to(device=t.device) + args = t[:, None].float() * freqs[None] + embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1) + if dim % 2: + embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1) + return embedding + + def forward(self, t, dtype=torch.float32): + t_freq = self.timestep_embedding(t, self.frequency_embedding_size).to(dtype) + t_emb = self.mlp(t_freq) + return t_emb + + + class GaussianFourierProjection(nn.Module): """Gaussian Fourier embeddings for noise levels.""" diff --git a/src/diffusers/models/normalization.py b/src/diffusers/models/normalization.py index 817b3fff2ea6..d19e79ad3b52 100644 --- a/src/diffusers/models/normalization.py +++ b/src/diffusers/models/normalization.py @@ -71,7 +71,7 @@ def forward( if self.chunk_dim == 1: # This is a bit weird why we have the order of "shift, scale" here and "scale, shift" in the - # other if-branch. This branch is specific to CogVideoX for now. 
+ # other if-branch. This branch is specific to CogVideoX and OmniGen for now. shift, scale = temb.chunk(2, dim=1) shift = shift[:, None, :] scale = scale[:, None, :] diff --git a/src/diffusers/models/transformers/transformer_omnigen.py b/src/diffusers/models/transformers/transformer_omnigen.py new file mode 100644 index 000000000000..a926c6dd43d3 --- /dev/null +++ b/src/diffusers/models/transformers/transformer_omnigen.py @@ -0,0 +1,345 @@ +# Copyright 2024 OmniGen team and The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from typing import Dict, Optional, Tuple, Union, List + +import torch +from torch import nn +import torch.utils.checkpoint + +from transformers.modeling_outputs import BaseModelOutputWithPast +from transformers import Phi3Model, Phi3Config +from transformers.cache_utils import Cache, DynamicCache + +from ...configuration_utils import ConfigMixin, register_to_config +from ...loaders import PeftAdapterMixin +from ...utils import logging +from ..attention_processor import AttentionProcessor +from ..normalization import AdaLayerNorm, CogVideoXLayerNormZero +from ..embeddings import OmniGenPatchEmbed, OmniGenTimestepEmbed +from ..modeling_utils import ModelMixin + + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + + + +class OmniGenBaseTransformer(Phi3Model): + """ + Transformer used in OmniGen. The transformer block is from Ph3, and only modify the attention mask. + References: [OmniGen](https://arxiv.org/pdf/2409.11340) + + Parameters: + config: Phi3Config + """ + + def forward( + self, + input_ids: torch.LongTensor = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[List[torch.FloatTensor]] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + cache_position: Optional[torch.LongTensor] = None, + offload_model: Optional[bool] = False, + ) -> Union[Tuple, BaseModelOutputWithPast]: + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + use_cache = use_cache if use_cache is not None else self.config.use_cache + + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + if (input_ids is None) ^ (inputs_embeds is not None): + raise ValueError("You must specify exactly one of input_ids or inputs_embeds") + + if self.gradient_checkpointing and self.training: + if use_cache: + logger.warning_once( + "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..." 
+ ) + use_cache = False + + # kept for BC (non `Cache` `past_key_values` inputs) + return_legacy_cache = False + if use_cache and not isinstance(past_key_values, Cache): + return_legacy_cache = True + if past_key_values is None: + past_key_values = DynamicCache() + else: + past_key_values = DynamicCache.from_legacy_cache(past_key_values) + logger.warning_once( + "We detected that you are passing `past_key_values` as a tuple of tuples. This is deprecated and " + "will be removed in v4.47. Please convert your cache or use an appropriate `Cache` class " + "(https://huggingface.co/docs/transformers/kv_cache#legacy-cache-format)" + ) + + if attention_mask is not None and attention_mask.dim() == 3: + dtype = inputs_embeds.dtype + min_dtype = torch.finfo(dtype).min + attention_mask = (1 - attention_mask) * min_dtype + attention_mask = attention_mask.unsqueeze(1).to(inputs_embeds.dtype) + else: + raise Exception("attention_mask parameter was unavailable or invalid") + + hidden_states = inputs_embeds + + # decoder layers + all_hidden_states = () if output_hidden_states else None + all_self_attns = () if output_attentions else None + next_decoder_cache = None + + layer_idx = -1 + for decoder_layer in self.layers: + layer_idx += 1 + + if output_hidden_states: + all_hidden_states += (hidden_states,) + + if self.gradient_checkpointing and self.training: + layer_outputs = self._gradient_checkpointing_func( + decoder_layer.__call__, + hidden_states, + attention_mask, + position_ids, + past_key_values, + output_attentions, + use_cache, + cache_position, + ) + else: + layer_outputs = decoder_layer( + hidden_states, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_value=past_key_values, + output_attentions=output_attentions, + use_cache=use_cache, + cache_position=cache_position, + ) + + hidden_states = layer_outputs[0] + + if use_cache: + next_decoder_cache = layer_outputs[2 if output_attentions else 1] + + if output_attentions: + all_self_attns += (layer_outputs[1],) + + hidden_states = self.norm(hidden_states) + + # add hidden states from the last decoder layer + if output_hidden_states: + all_hidden_states += (hidden_states,) + + next_cache = next_decoder_cache if use_cache else None + if return_legacy_cache: + next_cache = next_cache.to_legacy_cache() + + if not return_dict: + return tuple(v for v in [hidden_states, next_cache, all_hidden_states, all_self_attns] if v is not None) + return BaseModelOutputWithPast( + last_hidden_state=hidden_states, + past_key_values=next_cache, + hidden_states=all_hidden_states, + attentions=all_self_attns, + ) + + +class OmniGenTransformer(ModelMixin, ConfigMixin, PeftAdapterMixin): + """ + The Transformer model introduced in OmniGen. + + Reference: https://arxiv.org/pdf/2409.11340 + + Parameters: + patch_size (`int`, defaults to 2): Patch size to turn the input data into small patches. + in_channels (`int`, *optional*, defaults to 4): The number of channels in the input. + num_layers (`int`, *optional*, defaults to 18): The number of layers of MMDiT blocks to use. + num_single_layers (`int`, *optional*, defaults to 18): The number of layers of single DiT blocks to use. + attention_head_dim (`int`, *optional*, defaults to 64): The number of channels in each head. + num_attention_heads (`int`, *optional*, defaults to 18): The number of heads to use for multi-head attention. + joint_attention_dim (`int`, *optional*): The number of `encoder_hidden_states` dimensions to use. 
+ pooled_projection_dim (`int`): Number of dimensions to use when projecting the `pooled_projections`. + guidance_embeds (`bool`, defaults to False): Whether to use guidance embeddings. + """ + _supports_gradient_checkpointing = True + + @register_to_config + def __init__( + self, + transformer_config: Phi3Config, + patch_size=2, + in_channels=4, + pos_embed_max_size: int = 192, + ): + super().__init__() + self.in_channels = in_channels + self.out_channels = in_channels + self.patch_size = patch_size + self.pos_embed_max_size = pos_embed_max_size + + hidden_size = transformer_config.hidden_size + + self.patch_embedding = OmniGenPatchEmbed(patch_size=patch_size, in_channels=in_channels, embed_dim=hidden_size, pos_embed_max_size=pos_embed_max_size) + + self.time_token = OmniGenTimestepEmbed(hidden_size) + self.t_embedder = OmniGenTimestepEmbed(hidden_size) + + self.norm_out = AdaLayerNorm(hidden_size, norm_elementwise_affine=False, norm_eps=1e-6, chunk_dim=1) + self.proj_out = nn.Linear(hidden_size, patch_size * patch_size * self.out_channels, bias=True) + + self.llm = OmniGenBaseTransformer(config=transformer_config) + self.llm.config.use_cache = False + + + def unpatchify(self, x, h, w): + """ + x: (N, T, patch_size**2 * C) + imgs: (N, H, W, C) + """ + c = self.out_channels + + x = x.reshape( + shape=(x.shape[0], h // self.patch_size, w // self.patch_size, self.patch_size, self.patch_size, c)) + x = torch.einsum('nhwpqc->nchpwq', x) + imgs = x.reshape(shape=(x.shape[0], c, h, w)) + return imgs + + + def prepare_condition_embeddings(self, input_ids, input_img_latents, input_image_sizes, padding_latent): + condition_embeds = None + if input_img_latents is not None: + input_latents = self.patch_embedding(input_img_latents, is_input_images=True, padding_latent=padding_latent) + if input_ids is not None: + condition_embeds = self.llm.embed_tokens(input_ids).clone() + input_img_inx = 0 + for b_inx in input_image_sizes.keys(): + for start_inx, end_inx in input_image_sizes[b_inx]: + condition_embeds[b_inx, start_inx: end_inx] = input_latents[input_img_inx] + input_img_inx += 1 + if input_img_latents is not None: + assert input_img_inx == len(input_latents) + return condition_embeds + + @property + # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.attn_processors + def attn_processors(self) -> Dict[str, AttentionProcessor]: + r""" + Returns: + `dict` of attention processors: A dictionary containing all attention processors used in the model with + indexed by its weight name. + """ + # set recursively + processors = {} + + def fn_recursive_add_processors(name: str, module: torch.nn.Module, processors: Dict[str, AttentionProcessor]): + if hasattr(module, "get_processor"): + processors[f"{name}.processor"] = module.get_processor() + + for sub_name, child in module.named_children(): + fn_recursive_add_processors(f"{name}.{sub_name}", child, processors) + + return processors + + for name, module in self.named_children(): + fn_recursive_add_processors(name, module, processors) + + return processors + + # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.set_attn_processor + def set_attn_processor(self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]]): + r""" + Sets the attention processor to use to compute attention. 
+ + Parameters: + processor (`dict` of `AttentionProcessor` or only `AttentionProcessor`): + The instantiated processor class or a dictionary of processor classes that will be set as the processor + for **all** `Attention` layers. + + If `processor` is a dict, the key needs to define the path to the corresponding cross attention + processor. This is strongly recommended when setting trainable attention processors. + + """ + count = len(self.attn_processors.keys()) + + if isinstance(processor, dict) and len(processor) != count: + raise ValueError( + f"A dict of processors was passed, but the number of processors {len(processor)} does not match the" + f" number of attention layers: {count}. Please make sure to pass {count} processor classes." + ) + + def fn_recursive_attn_processor(name: str, module: torch.nn.Module, processor): + if hasattr(module, "set_processor"): + if not isinstance(processor, dict): + module.set_processor(processor) + else: + module.set_processor(processor.pop(f"{name}.processor")) + + for sub_name, child in module.named_children(): + fn_recursive_attn_processor(f"{name}.{sub_name}", child, processor) + + for name, module in self.named_children(): + fn_recursive_attn_processor(name, module, processor) + + def _set_gradient_checkpointing(self, module, value=False): + if hasattr(module, "gradient_checkpointing"): + module.gradient_checkpointing = value + + def forward(self, + hidden_states, + timestep, + input_ids, + input_img_latents, + input_image_sizes, + attention_mask, + position_ids, + padding_latent=None, + past_key_values=None, + return_past_key_values=True, + offload_model: bool = False): + + height, width = hidden_states.size(-2) + hidden_states = self.patch_embedding(hidden_states, is_input_image=False) + num_tokens_for_output_image = hidden_states.size(1) + + time_token = self.time_token(timestep, dtype=hidden_states.dtype).unsqueeze(1) + + condition_embeds = self.prepare_condition_embeddings(input_ids, input_img_latents, input_image_sizes, padding_latent) + if condition_embeds is not None: + input_emb = torch.cat([condition_embeds, time_token, hidden_states], dim=1) + else: + input_emb = torch.cat([time_token, hidden_states], dim=1) + output = self.llm(inputs_embeds=input_emb, attention_mask=attention_mask, position_ids=position_ids, + past_key_values=past_key_values, offload_model=offload_model) + output, past_key_values = output.last_hidden_state, output.past_key_values + + image_embedding = output[:, -num_tokens_for_output_image:] + time_emb = self.t_embedder(timestep, dtype=hidden_states.dtype) + x = self.final_layer(image_embedding, time_emb) + latents = self.unpatchify(x, height, width) + + if return_past_key_values: + return latents, past_key_values + return latents + + + + + diff --git a/src/diffusers/pipelines/omnigen/__init__.py b/src/diffusers/pipelines/omnigen/__init__.py new file mode 100644 index 000000000000..3570368a5ca1 --- /dev/null +++ b/src/diffusers/pipelines/omnigen/__init__.py @@ -0,0 +1,67 @@ +from typing import TYPE_CHECKING + +from ...utils import ( + DIFFUSERS_SLOW_IMPORT, + OptionalDependencyNotAvailable, + _LazyModule, + get_objects_from_module, + is_torch_available, + is_transformers_available, +) + + +_dummy_objects = {} +_additional_imports = {} +_import_structure = {"pipeline_output": ["FluxPipelineOutput", "FluxPriorReduxPipelineOutput"]} + +try: + if not (is_transformers_available() and is_torch_available()): + raise OptionalDependencyNotAvailable() +except OptionalDependencyNotAvailable: + from ...utils import 
dummy_torch_and_transformers_objects # noqa F403 + + _dummy_objects.update(get_objects_from_module(dummy_torch_and_transformers_objects)) +else: + _import_structure["modeling_flux"] = ["ReduxImageEncoder"] + _import_structure["pipeline_flux"] = ["FluxPipeline"] + _import_structure["pipeline_flux_control"] = ["FluxControlPipeline"] + _import_structure["pipeline_flux_control_img2img"] = ["FluxControlImg2ImgPipeline"] + _import_structure["pipeline_flux_controlnet"] = ["FluxControlNetPipeline"] + _import_structure["pipeline_flux_controlnet_image_to_image"] = ["FluxControlNetImg2ImgPipeline"] + _import_structure["pipeline_flux_controlnet_inpainting"] = ["FluxControlNetInpaintPipeline"] + _import_structure["pipeline_flux_fill"] = ["FluxFillPipeline"] + _import_structure["pipeline_flux_img2img"] = ["FluxImg2ImgPipeline"] + _import_structure["pipeline_flux_inpaint"] = ["FluxInpaintPipeline"] + _import_structure["pipeline_flux_prior_redux"] = ["FluxPriorReduxPipeline"] +if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT: + try: + if not (is_transformers_available() and is_torch_available()): + raise OptionalDependencyNotAvailable() + except OptionalDependencyNotAvailable: + from ...utils.dummy_torch_and_transformers_objects import * # noqa F403 + else: + from .modeling_flux import ReduxImageEncoder + from .pipeline_flux import FluxPipeline + from .pipeline_flux_control import FluxControlPipeline + from .pipeline_flux_control_img2img import FluxControlImg2ImgPipeline + from .pipeline_flux_controlnet import FluxControlNetPipeline + from .pipeline_flux_controlnet_image_to_image import FluxControlNetImg2ImgPipeline + from .pipeline_flux_controlnet_inpainting import FluxControlNetInpaintPipeline + from .pipeline_flux_fill import FluxFillPipeline + from .pipeline_flux_img2img import FluxImg2ImgPipeline + from .pipeline_flux_inpaint import FluxInpaintPipeline + from .pipeline_flux_prior_redux import FluxPriorReduxPipeline +else: + import sys + + sys.modules[__name__] = _LazyModule( + __name__, + globals()["__file__"], + _import_structure, + module_spec=__spec__, + ) + + for name, value in _dummy_objects.items(): + setattr(sys.modules[__name__], name, value) + for name, value in _additional_imports.items(): + setattr(sys.modules[__name__], name, value) diff --git a/src/diffusers/pipelines/omnigen/pipeline_omnigen.py b/src/diffusers/pipelines/omnigen/pipeline_omnigen.py new file mode 100644 index 000000000000..f6f3752fe97d --- /dev/null +++ b/src/diffusers/pipelines/omnigen/pipeline_omnigen.py @@ -0,0 +1,774 @@ +# Copyright 2024 OmniGen team and The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
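+
+# Illustrative usage sketch for this pipeline. OmniGen takes interleaved text/image prompts;
+# the `input_images` argument and the `<img><|image_1|></img>` placeholder syntax follow the
+# upstream OmniGen repository and are assumptions about the final diffusers API, which may
+# still change over the course of this PR. The checkpoint id comes from the conversion script
+# added later in this patch series.
+#
+#     import torch
+#     from diffusers import OmniGenPipeline
+#     from diffusers.utils import load_image
+#
+#     pipe = OmniGenPipeline.from_pretrained("Shitao/OmniGen-v1", torch_dtype=torch.bfloat16)
+#     pipe.to("cuda")
+#     image = pipe(
+#         prompt="A man in a black shirt is reading a book. The man is the person in <img><|image_1|></img>.",
+#         input_images=[load_image("person.png")],
+#     ).images[0]
+#     image.save("omnigen_out.png")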
+ +import inspect +from typing import Any, Callable, Dict, List, Optional, Union + +import numpy as np +import torch +from transformers import CLIPTextModel, CLIPTokenizer, T5EncoderModel, T5TokenizerFast + +from ...image_processor import VaeImageProcessor +from ...loaders import FluxLoraLoaderMixin, FromSingleFileMixin, TextualInversionLoaderMixin +from ...models.autoencoders import AutoencoderKL +from ...models.transformers import FluxTransformer2DModel +from ...schedulers import FlowMatchEulerDiscreteScheduler +from ...utils import ( + USE_PEFT_BACKEND, + is_torch_xla_available, + logging, + replace_example_docstring, + scale_lora_layers, + unscale_lora_layers, +) +from ...utils.torch_utils import randn_tensor +from ..pipeline_utils import DiffusionPipeline +from .pipeline_output import OmniGenPipelineOutput + + +if is_torch_xla_available(): + import torch_xla.core.xla_model as xm + + XLA_AVAILABLE = True +else: + XLA_AVAILABLE = False + + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + +EXAMPLE_DOC_STRING = """ + Examples: + ```py + >>> import torch + >>> from diffusers import OmniGenPipeline + + >>> pipe = OmniGenPipeline.from_pretrained("****", torch_dtype=torch.bfloat16) + >>> pipe.to("cuda") + >>> prompt = "A cat holding a sign that says hello world" + >>> # Depending on the variant being used, the pipeline call will slightly vary. + >>> # Refer to the pipeline documentation for more details. + >>> image = pipe(prompt, num_inference_steps=4, guidance_scale=0.0).images[0] + >>> image.save("flux.png") + ``` +""" + + + +# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps +def retrieve_timesteps( + scheduler, + num_inference_steps: Optional[int] = None, + device: Optional[Union[str, torch.device]] = None, + timesteps: Optional[List[int]] = None, + sigmas: Optional[List[float]] = None, + **kwargs, +): + r""" + Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles + custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`. + + Args: + scheduler (`SchedulerMixin`): + The scheduler to get timesteps from. + num_inference_steps (`int`): + The number of diffusion steps used when generating samples with a pre-trained model. If used, `timesteps` + must be `None`. + device (`str` or `torch.device`, *optional*): + The device to which the timesteps should be moved to. If `None`, the timesteps are not moved. + timesteps (`List[int]`, *optional*): + Custom timesteps used to override the timestep spacing strategy of the scheduler. If `timesteps` is passed, + `num_inference_steps` and `sigmas` must be `None`. + sigmas (`List[float]`, *optional*): + Custom sigmas used to override the timestep spacing strategy of the scheduler. If `sigmas` is passed, + `num_inference_steps` and `timesteps` must be `None`. + + Returns: + `Tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the + second element is the number of inference steps. + """ + if timesteps is not None and sigmas is not None: + raise ValueError("Only one of `timesteps` or `sigmas` can be passed. Please choose one to set custom values") + if timesteps is not None: + accepts_timesteps = "timesteps" in set(inspect.signature(scheduler.set_timesteps).parameters.keys()) + if not accepts_timesteps: + raise ValueError( + f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom" + f" timestep schedules. 
Please check whether you are using the correct scheduler." + ) + scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs) + timesteps = scheduler.timesteps + num_inference_steps = len(timesteps) + elif sigmas is not None: + accept_sigmas = "sigmas" in set(inspect.signature(scheduler.set_timesteps).parameters.keys()) + if not accept_sigmas: + raise ValueError( + f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom" + f" sigmas schedules. Please check whether you are using the correct scheduler." + ) + scheduler.set_timesteps(sigmas=sigmas, device=device, **kwargs) + timesteps = scheduler.timesteps + num_inference_steps = len(timesteps) + else: + scheduler.set_timesteps(num_inference_steps, device=device, **kwargs) + timesteps = scheduler.timesteps + return timesteps, num_inference_steps + + +class OmniGenPipeline( + DiffusionPipeline, + FromSingleFileMixin, + TextualInversionLoaderMixin, +): + r""" + The OmniGen pipeline for multimodal-to-image generation. + + Reference: https://arxiv.org/pdf/2409.11340 + + Args: + transformer ([`FluxTransformer2DModel`]): + Conditional Transformer (MMDiT) architecture to denoise the encoded image latents. + scheduler ([`FlowMatchEulerDiscreteScheduler`]): + A scheduler to be used in combination with `transformer` to denoise the encoded image latents. + vae ([`AutoencoderKL`]): + Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations. + text_encoder ([`CLIPTextModel`]): + [CLIP](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPTextModel), specifically + the [clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14) variant. + text_encoder_2 ([`T5EncoderModel`]): + [T5](https://huggingface.co/docs/transformers/en/model_doc/t5#transformers.T5EncoderModel), specifically + the [google/t5-v1_1-xxl](https://huggingface.co/google/t5-v1_1-xxl) variant. + tokenizer (`CLIPTokenizer`): + Tokenizer of class + [CLIPTokenizer](https://huggingface.co/docs/transformers/en/model_doc/clip#transformers.CLIPTokenizer). + tokenizer_2 (`T5TokenizerFast`): + Second Tokenizer of class + [T5TokenizerFast](https://huggingface.co/docs/transformers/en/model_doc/t5#transformers.T5TokenizerFast). + """ + + model_cpu_offload_seq = "transformer->vae" + _optional_components = [] + _callback_tensor_inputs = ["latents", "prompt_embeds"] + + def __init__( + self, + scheduler: FlowMatchEulerDiscreteScheduler, + vae: AutoencoderKL, + tokenizer: CLIPTokenizer, + transformer: FluxTransformer2DModel, + ): + super().__init__() + + self.register_modules( + vae=vae, + text_encoder=text_encoder, + text_encoder_2=text_encoder_2, + tokenizer=tokenizer, + tokenizer_2=tokenizer_2, + transformer=transformer, + scheduler=scheduler, + ) + self.vae_scale_factor = ( + 2 ** (len(self.vae.config.block_out_channels) - 1) if hasattr(self, "vae") and self.vae is not None else 8 + ) + # Flux latents are turned into 2x2 patches and packed. This means the latent width and height has to be divisible + # by the patch size. 
So the vae scale factor is multiplied by the patch size to account for this + self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor * 2) + self.tokenizer_max_length = ( + self.tokenizer.model_max_length if hasattr(self, "tokenizer") and self.tokenizer is not None else 77 + ) + self.default_sample_size = 128 + + def _get_t5_prompt_embeds( + self, + prompt: Union[str, List[str]] = None, + num_images_per_prompt: int = 1, + max_sequence_length: int = 512, + device: Optional[torch.device] = None, + dtype: Optional[torch.dtype] = None, + ): + device = device or self._execution_device + dtype = dtype or self.text_encoder.dtype + + prompt = [prompt] if isinstance(prompt, str) else prompt + batch_size = len(prompt) + + if isinstance(self, TextualInversionLoaderMixin): + prompt = self.maybe_convert_prompt(prompt, self.tokenizer_2) + + text_inputs = self.tokenizer_2( + prompt, + padding="max_length", + max_length=max_sequence_length, + truncation=True, + return_length=False, + return_overflowing_tokens=False, + return_tensors="pt", + ) + text_input_ids = text_inputs.input_ids + untruncated_ids = self.tokenizer_2(prompt, padding="longest", return_tensors="pt").input_ids + + if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(text_input_ids, untruncated_ids): + removed_text = self.tokenizer_2.batch_decode(untruncated_ids[:, self.tokenizer_max_length - 1 : -1]) + logger.warning( + "The following part of your input was truncated because `max_sequence_length` is set to " + f" {max_sequence_length} tokens: {removed_text}" + ) + + prompt_embeds = self.text_encoder_2(text_input_ids.to(device), output_hidden_states=False)[0] + + dtype = self.text_encoder_2.dtype + prompt_embeds = prompt_embeds.to(dtype=dtype, device=device) + + _, seq_len, _ = prompt_embeds.shape + + # duplicate text embeddings and attention mask for each generation per prompt, using mps friendly method + prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1) + prompt_embeds = prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1) + + return prompt_embeds + + def _get_clip_prompt_embeds( + self, + prompt: Union[str, List[str]], + num_images_per_prompt: int = 1, + device: Optional[torch.device] = None, + ): + device = device or self._execution_device + + prompt = [prompt] if isinstance(prompt, str) else prompt + batch_size = len(prompt) + + if isinstance(self, TextualInversionLoaderMixin): + prompt = self.maybe_convert_prompt(prompt, self.tokenizer) + + text_inputs = self.tokenizer( + prompt, + padding="max_length", + max_length=self.tokenizer_max_length, + truncation=True, + return_overflowing_tokens=False, + return_length=False, + return_tensors="pt", + ) + + text_input_ids = text_inputs.input_ids + untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids + if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(text_input_ids, untruncated_ids): + removed_text = self.tokenizer.batch_decode(untruncated_ids[:, self.tokenizer_max_length - 1 : -1]) + logger.warning( + "The following part of your input was truncated because CLIP can only handle sequences up to" + f" {self.tokenizer_max_length} tokens: {removed_text}" + ) + prompt_embeds = self.text_encoder(text_input_ids.to(device), output_hidden_states=False) + + # Use pooled output of CLIPTextModel + prompt_embeds = prompt_embeds.pooler_output + prompt_embeds = prompt_embeds.to(dtype=self.text_encoder.dtype, device=device) + + # duplicate text embeddings for each 
generation per prompt, using mps friendly method + prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt) + prompt_embeds = prompt_embeds.view(batch_size * num_images_per_prompt, -1) + + return prompt_embeds + + def encode_prompt( + self, + prompt: Union[str, List[str]], + prompt_2: Union[str, List[str]], + device: Optional[torch.device] = None, + num_images_per_prompt: int = 1, + prompt_embeds: Optional[torch.FloatTensor] = None, + pooled_prompt_embeds: Optional[torch.FloatTensor] = None, + max_sequence_length: int = 512, + lora_scale: Optional[float] = None, + ): + r""" + + Args: + prompt (`str` or `List[str]`, *optional*): + prompt to be encoded + prompt_2 (`str` or `List[str]`, *optional*): + The prompt or prompts to be sent to the `tokenizer_2` and `text_encoder_2`. If not defined, `prompt` is + used in all text-encoders + device: (`torch.device`): + torch device + num_images_per_prompt (`int`): + number of images that should be generated per prompt + prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not + provided, text embeddings will be generated from `prompt` input argument. + pooled_prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. + If not provided, pooled text embeddings will be generated from `prompt` input argument. + lora_scale (`float`, *optional*): + A lora scale that will be applied to all LoRA layers of the text encoder if LoRA layers are loaded. + """ + device = device or self._execution_device + + # set lora scale so that monkey patched LoRA + # function of text encoder can correctly access it + if lora_scale is not None and isinstance(self, FluxLoraLoaderMixin): + self._lora_scale = lora_scale + + # dynamically adjust the LoRA scale + if self.text_encoder is not None and USE_PEFT_BACKEND: + scale_lora_layers(self.text_encoder, lora_scale) + if self.text_encoder_2 is not None and USE_PEFT_BACKEND: + scale_lora_layers(self.text_encoder_2, lora_scale) + + prompt = [prompt] if isinstance(prompt, str) else prompt + + if prompt_embeds is None: + prompt_2 = prompt_2 or prompt + prompt_2 = [prompt_2] if isinstance(prompt_2, str) else prompt_2 + + # We only use the pooled prompt output from the CLIPTextModel + pooled_prompt_embeds = self._get_clip_prompt_embeds( + prompt=prompt, + device=device, + num_images_per_prompt=num_images_per_prompt, + ) + prompt_embeds = self._get_t5_prompt_embeds( + prompt=prompt_2, + num_images_per_prompt=num_images_per_prompt, + max_sequence_length=max_sequence_length, + device=device, + ) + + if self.text_encoder is not None: + if isinstance(self, FluxLoraLoaderMixin) and USE_PEFT_BACKEND: + # Retrieve the original scale by scaling back the LoRA layers + unscale_lora_layers(self.text_encoder, lora_scale) + + if self.text_encoder_2 is not None: + if isinstance(self, FluxLoraLoaderMixin) and USE_PEFT_BACKEND: + # Retrieve the original scale by scaling back the LoRA layers + unscale_lora_layers(self.text_encoder_2, lora_scale) + + dtype = self.text_encoder.dtype if self.text_encoder is not None else self.transformer.dtype + text_ids = torch.zeros(prompt_embeds.shape[1], 3).to(device=device, dtype=dtype) + + return prompt_embeds, pooled_prompt_embeds, text_ids + + def check_inputs( + self, + prompt, + prompt_2, + height, + width, + prompt_embeds=None, + pooled_prompt_embeds=None, + callback_on_step_end_tensor_inputs=None, + 
max_sequence_length=None, + ): + if height % (self.vae_scale_factor * 2) != 0 or width % (self.vae_scale_factor * 2) != 0: + logger.warning( + f"`height` and `width` have to be divisible by {self.vae_scale_factor * 2} but are {height} and {width}. Dimensions will be resized accordingly" + ) + + if callback_on_step_end_tensor_inputs is not None and not all( + k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs + ): + raise ValueError( + f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}" + ) + + if prompt is not None and prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to" + " only forward one of the two." + ) + elif prompt_2 is not None and prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `prompt_2`: {prompt_2} and `prompt_embeds`: {prompt_embeds}. Please make sure to" + " only forward one of the two." + ) + elif prompt is None and prompt_embeds is None: + raise ValueError( + "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined." + ) + elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)): + raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}") + elif prompt_2 is not None and (not isinstance(prompt_2, str) and not isinstance(prompt_2, list)): + raise ValueError(f"`prompt_2` has to be of type `str` or `list` but is {type(prompt_2)}") + + if prompt_embeds is not None and pooled_prompt_embeds is None: + raise ValueError( + "If `prompt_embeds` are provided, `pooled_prompt_embeds` also have to be passed. Make sure to generate `pooled_prompt_embeds` from the same text encoder that was used to generate `prompt_embeds`." + ) + + if max_sequence_length is not None and max_sequence_length > 512: + raise ValueError(f"`max_sequence_length` cannot be greater than 512 but is {max_sequence_length}") + + @staticmethod + def _prepare_latent_image_ids(batch_size, height, width, device, dtype): + latent_image_ids = torch.zeros(height, width, 3) + latent_image_ids[..., 1] = latent_image_ids[..., 1] + torch.arange(height)[:, None] + latent_image_ids[..., 2] = latent_image_ids[..., 2] + torch.arange(width)[None, :] + + latent_image_id_height, latent_image_id_width, latent_image_id_channels = latent_image_ids.shape + + latent_image_ids = latent_image_ids.reshape( + latent_image_id_height * latent_image_id_width, latent_image_id_channels + ) + + return latent_image_ids.to(device=device, dtype=dtype) + + @staticmethod + def _pack_latents(latents, batch_size, num_channels_latents, height, width): + latents = latents.view(batch_size, num_channels_latents, height // 2, 2, width // 2, 2) + latents = latents.permute(0, 2, 4, 1, 3, 5) + latents = latents.reshape(batch_size, (height // 2) * (width // 2), num_channels_latents * 4) + + return latents + + @staticmethod + def _unpack_latents(latents, height, width, vae_scale_factor): + batch_size, num_patches, channels = latents.shape + + # VAE applies 8x compression on images but we must also account for packing which requires + # latent height and width to be divisible by 2. 
+ height = 2 * (int(height) // (vae_scale_factor * 2)) + width = 2 * (int(width) // (vae_scale_factor * 2)) + + latents = latents.view(batch_size, height // 2, width // 2, channels // 4, 2, 2) + latents = latents.permute(0, 3, 1, 4, 2, 5) + + latents = latents.reshape(batch_size, channels // (2 * 2), height, width) + + return latents + + def enable_vae_slicing(self): + r""" + Enable sliced VAE decoding. When this option is enabled, the VAE will split the input tensor in slices to + compute decoding in several steps. This is useful to save some memory and allow larger batch sizes. + """ + self.vae.enable_slicing() + + def disable_vae_slicing(self): + r""" + Disable sliced VAE decoding. If `enable_vae_slicing` was previously enabled, this method will go back to + computing decoding in one step. + """ + self.vae.disable_slicing() + + def enable_vae_tiling(self): + r""" + Enable tiled VAE decoding. When this option is enabled, the VAE will split the input tensor into tiles to + compute decoding and encoding in several steps. This is useful for saving a large amount of memory and to allow + processing larger images. + """ + self.vae.enable_tiling() + + def disable_vae_tiling(self): + r""" + Disable tiled VAE decoding. If `enable_vae_tiling` was previously enabled, this method will go back to + computing decoding in one step. + """ + self.vae.disable_tiling() + + def prepare_latents( + self, + batch_size, + num_channels_latents, + height, + width, + dtype, + device, + generator, + latents=None, + ): + # VAE applies 8x compression on images but we must also account for packing which requires + # latent height and width to be divisible by 2. + height = 2 * (int(height) // (self.vae_scale_factor * 2)) + width = 2 * (int(width) // (self.vae_scale_factor * 2)) + + shape = (batch_size, num_channels_latents, height, width) + + if latents is not None: + latent_image_ids = self._prepare_latent_image_ids(batch_size, height // 2, width // 2, device, dtype) + return latents.to(device=device, dtype=dtype), latent_image_ids + + if isinstance(generator, list) and len(generator) != batch_size: + raise ValueError( + f"You have passed a list of generators of length {len(generator)}, but requested an effective batch" + f" size of {batch_size}. Make sure the batch size matches the length of the generators." 
+ ) + + latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype) + latents = self._pack_latents(latents, batch_size, num_channels_latents, height, width) + + latent_image_ids = self._prepare_latent_image_ids(batch_size, height // 2, width // 2, device, dtype) + + return latents, latent_image_ids + + @property + def guidance_scale(self): + return self._guidance_scale + + @property + def joint_attention_kwargs(self): + return self._joint_attention_kwargs + + @property + def num_timesteps(self): + return self._num_timesteps + + @property + def interrupt(self): + return self._interrupt + + @torch.no_grad() + @replace_example_docstring(EXAMPLE_DOC_STRING) + def __call__( + self, + prompt: Union[str, List[str]] = None, + prompt_2: Optional[Union[str, List[str]]] = None, + height: Optional[int] = None, + width: Optional[int] = None, + num_inference_steps: int = 28, + timesteps: List[int] = None, + guidance_scale: float = 3.5, + num_images_per_prompt: Optional[int] = 1, + generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, + latents: Optional[torch.FloatTensor] = None, + prompt_embeds: Optional[torch.FloatTensor] = None, + pooled_prompt_embeds: Optional[torch.FloatTensor] = None, + output_type: Optional[str] = "pil", + return_dict: bool = True, + joint_attention_kwargs: Optional[Dict[str, Any]] = None, + callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None, + callback_on_step_end_tensor_inputs: List[str] = ["latents"], + max_sequence_length: int = 512, + ): + r""" + Function invoked when calling the pipeline for generation. + + Args: + prompt (`str` or `List[str]`, *optional*): + The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`. + instead. + prompt_2 (`str` or `List[str]`, *optional*): + The prompt or prompts to be sent to `tokenizer_2` and `text_encoder_2`. If not defined, `prompt` is + will be used instead + height (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor): + The height in pixels of the generated image. This is set to 1024 by default for the best results. + width (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor): + The width in pixels of the generated image. This is set to 1024 by default for the best results. + num_inference_steps (`int`, *optional*, defaults to 50): + The number of denoising steps. More denoising steps usually lead to a higher quality image at the + expense of slower inference. + timesteps (`List[int]`, *optional*): + Custom timesteps to use for the denoising process with schedulers which support a `timesteps` argument + in their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is + passed will be used. Must be in descending order. + guidance_scale (`float`, *optional*, defaults to 7.0): + Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598). + `guidance_scale` is defined as `w` of equation 2. of [Imagen + Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale > + 1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`, + usually at the expense of lower image quality. + num_images_per_prompt (`int`, *optional*, defaults to 1): + The number of images to generate per prompt. 
+ generator (`torch.Generator` or `List[torch.Generator]`, *optional*): + One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html) + to make generation deterministic. + latents (`torch.FloatTensor`, *optional*): + Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image + generation. Can be used to tweak the same generation with different prompts. If not provided, a latents + tensor will ge generated by sampling using the supplied random `generator`. + prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not + provided, text embeddings will be generated from `prompt` input argument. + pooled_prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. + If not provided, pooled text embeddings will be generated from `prompt` input argument. + output_type (`str`, *optional*, defaults to `"pil"`): + The output format of the generate image. Choose between + [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`. + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`~pipelines.flux.FluxPipelineOutput`] instead of a plain tuple. + joint_attention_kwargs (`dict`, *optional*): + A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under + `self.processor` in + [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py). + callback_on_step_end (`Callable`, *optional*): + A function that calls at the end of each denoising steps during the inference. The function is called + with the following arguments: `callback_on_step_end(self: DiffusionPipeline, step: int, timestep: int, + callback_kwargs: Dict)`. `callback_kwargs` will include a list of all tensors as specified by + `callback_on_step_end_tensor_inputs`. + callback_on_step_end_tensor_inputs (`List`, *optional*): + The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list + will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the + `._callback_tensor_inputs` attribute of your pipeline class. + max_sequence_length (`int` defaults to 512): Maximum sequence length to use with the `prompt`. + + Examples: + + Returns: + [`~pipelines.flux.FluxPipelineOutput`] or `tuple`: [`~pipelines.flux.FluxPipelineOutput`] if `return_dict` + is True, otherwise a `tuple`. When returning a tuple, the first element is a list with the generated + images. + """ + + height = height or self.default_sample_size * self.vae_scale_factor + width = width or self.default_sample_size * self.vae_scale_factor + + # 1. Check inputs. Raise error if not correct + self.check_inputs( + prompt, + prompt_2, + height, + width, + prompt_embeds=prompt_embeds, + pooled_prompt_embeds=pooled_prompt_embeds, + callback_on_step_end_tensor_inputs=callback_on_step_end_tensor_inputs, + max_sequence_length=max_sequence_length, + ) + + self._guidance_scale = guidance_scale + self._joint_attention_kwargs = joint_attention_kwargs + self._interrupt = False + + # 2. 
Define call parameters + if prompt is not None and isinstance(prompt, str): + batch_size = 1 + elif prompt is not None and isinstance(prompt, list): + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + device = self._execution_device + + lora_scale = ( + self.joint_attention_kwargs.get("scale", None) if self.joint_attention_kwargs is not None else None + ) + ( + prompt_embeds, + pooled_prompt_embeds, + text_ids, + ) = self.encode_prompt( + prompt=prompt, + prompt_2=prompt_2, + prompt_embeds=prompt_embeds, + pooled_prompt_embeds=pooled_prompt_embeds, + device=device, + num_images_per_prompt=num_images_per_prompt, + max_sequence_length=max_sequence_length, + lora_scale=lora_scale, + ) + + # 4. Prepare latent variables + num_channels_latents = self.transformer.config.in_channels // 4 + latents, latent_image_ids = self.prepare_latents( + batch_size * num_images_per_prompt, + num_channels_latents, + height, + width, + prompt_embeds.dtype, + device, + generator, + latents, + ) + + # 5. Prepare timesteps + sigmas = np.linspace(1.0, 1 / num_inference_steps, num_inference_steps) + image_seq_len = latents.shape[1] + mu = calculate_shift( + image_seq_len, + self.scheduler.config.base_image_seq_len, + self.scheduler.config.max_image_seq_len, + self.scheduler.config.base_shift, + self.scheduler.config.max_shift, + ) + timesteps, num_inference_steps = retrieve_timesteps( + self.scheduler, + num_inference_steps, + device, + timesteps, + sigmas, + mu=mu, + ) + num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0) + self._num_timesteps = len(timesteps) + + # handle guidance + if self.transformer.config.guidance_embeds: + guidance = torch.full([1], guidance_scale, device=device, dtype=torch.float32) + guidance = guidance.expand(latents.shape[0]) + else: + guidance = None + + # 6. Denoising loop + with self.progress_bar(total=num_inference_steps) as progress_bar: + for i, t in enumerate(timesteps): + if self.interrupt: + continue + + # broadcast to batch dimension in a way that's compatible with ONNX/Core ML + timestep = t.expand(latents.shape[0]).to(latents.dtype) + + noise_pred = self.transformer( + hidden_states=latents, + timestep=timestep / 1000, + guidance=guidance, + pooled_projections=pooled_prompt_embeds, + encoder_hidden_states=prompt_embeds, + txt_ids=text_ids, + img_ids=latent_image_ids, + joint_attention_kwargs=self.joint_attention_kwargs, + return_dict=False, + )[0] + + # compute the previous noisy sample x_t -> x_t-1 + latents_dtype = latents.dtype + latents = self.scheduler.step(noise_pred, t, latents, return_dict=False)[0] + + if latents.dtype != latents_dtype: + if torch.backends.mps.is_available(): + # some platforms (eg. 
apple mps) misbehave due to a pytorch bug: https://github.com/pytorch/pytorch/pull/99272 + latents = latents.to(latents_dtype) + + if callback_on_step_end is not None: + callback_kwargs = {} + for k in callback_on_step_end_tensor_inputs: + callback_kwargs[k] = locals()[k] + callback_outputs = callback_on_step_end(self, i, t, callback_kwargs) + + latents = callback_outputs.pop("latents", latents) + prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds) + + # call the callback, if provided + if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0): + progress_bar.update() + + if XLA_AVAILABLE: + xm.mark_step() + + if output_type == "latent": + image = latents + + else: + latents = self._unpack_latents(latents, height, width, self.vae_scale_factor) + latents = (latents / self.vae.config.scaling_factor) + self.vae.config.shift_factor + image = self.vae.decode(latents, return_dict=False)[0] + image = self.image_processor.postprocess(image, output_type=output_type) + + # Offload all models + self.maybe_free_model_hooks() + + if not return_dict: + return (image,) + + return FluxPipelineOutput(images=image) diff --git a/src/diffusers/pipelines/omnigen/pipeline_output.py b/src/diffusers/pipelines/omnigen/pipeline_output.py new file mode 100644 index 000000000000..40ff24199900 --- /dev/null +++ b/src/diffusers/pipelines/omnigen/pipeline_output.py @@ -0,0 +1,24 @@ +from dataclasses import dataclass +from typing import List, Union + +import numpy as np +import PIL.Image + +from ...utils import BaseOutput + + +@dataclass +class OmniGenPipelineOutput(BaseOutput): + """ + Output class for Stable Diffusion pipelines. + + Args: + images (`List[PIL.Image.Image]` or `np.ndarray`) + List of denoised PIL images of length `batch_size` or numpy array of shape `(batch_size, height, width, + num_channels)`. PIL images or numpy array present the denoised images of the diffusion pipeline. 
+ """ + + images: Union[List[PIL.Image.Image], np.ndarray] + + + From bbe2b98e03f44e895b63555c22cbe5a44fca36b6 Mon Sep 17 00:00:00 2001 From: staoxiao <2906698981@qq.com> Date: Sat, 30 Nov 2024 22:14:27 +0800 Subject: [PATCH 02/55] update OmniGenTransformerModel --- src/diffusers/__init__.py | 2 + src/diffusers/models/__init__.py | 2 + src/diffusers/models/transformers/__init__.py | 1 + .../transformers/transformer_omnigen.py | 5 +- test.py | 52 +++++++++++++++++++ 5 files changed, 60 insertions(+), 2 deletions(-) create mode 100644 test.py diff --git a/src/diffusers/__init__.py b/src/diffusers/__init__.py index a4749af5f61b..2d5cd10fd66f 100644 --- a/src/diffusers/__init__.py +++ b/src/diffusers/__init__.py @@ -108,6 +108,7 @@ "MotionAdapter", "MultiAdapter", "MultiControlNetModel", + "OmniGenTransformerModel", "PixArtTransformer2DModel", "PriorTransformer", "SD3ControlNetModel", @@ -599,6 +600,7 @@ MotionAdapter, MultiAdapter, MultiControlNetModel, + OmniGenTransformerModel, PixArtTransformer2DModel, PriorTransformer, SD3ControlNetModel, diff --git a/src/diffusers/models/__init__.py b/src/diffusers/models/__init__.py index 65e2418ac794..faf9b97e827c 100644 --- a/src/diffusers/models/__init__.py +++ b/src/diffusers/models/__init__.py @@ -66,6 +66,7 @@ _import_structure["transformers.transformer_mochi"] = ["MochiTransformer3DModel"] _import_structure["transformers.transformer_sd3"] = ["SD3Transformer2DModel"] _import_structure["transformers.transformer_temporal"] = ["TransformerTemporalModel"] + _import_structure["transformers.transformer_omnigen"] = ["OmniGenTransformerModel"] _import_structure["unets.unet_1d"] = ["UNet1DModel"] _import_structure["unets.unet_2d"] = ["UNet2DModel"] _import_structure["unets.unet_2d_condition"] = ["UNet2DConditionModel"] @@ -125,6 +126,7 @@ LatteTransformer3DModel, LuminaNextDiT2DModel, MochiTransformer3DModel, + OmniGenTransformerModel, PixArtTransformer2DModel, PriorTransformer, SD3Transformer2DModel, diff --git a/src/diffusers/models/transformers/__init__.py b/src/diffusers/models/transformers/__init__.py index a2c087d708a4..7d0fd364a17d 100644 --- a/src/diffusers/models/transformers/__init__.py +++ b/src/diffusers/models/transformers/__init__.py @@ -20,3 +20,4 @@ from .transformer_mochi import MochiTransformer3DModel from .transformer_sd3 import SD3Transformer2DModel from .transformer_temporal import TransformerTemporalModel + from .transformer_omnigen import OmniGenTransformerModel diff --git a/src/diffusers/models/transformers/transformer_omnigen.py b/src/diffusers/models/transformers/transformer_omnigen.py index a926c6dd43d3..2ea18512724c 100644 --- a/src/diffusers/models/transformers/transformer_omnigen.py +++ b/src/diffusers/models/transformers/transformer_omnigen.py @@ -26,7 +26,7 @@ from ...loaders import PeftAdapterMixin from ...utils import logging from ..attention_processor import AttentionProcessor -from ..normalization import AdaLayerNorm, CogVideoXLayerNormZero +from ..normalization import AdaLayerNorm from ..embeddings import OmniGenPatchEmbed, OmniGenTimestepEmbed from ..modeling_utils import ModelMixin @@ -162,7 +162,7 @@ def forward( ) -class OmniGenTransformer(ModelMixin, ConfigMixin, PeftAdapterMixin): +class OmniGenTransformerModel(ModelMixin, ConfigMixin, PeftAdapterMixin): """ The Transformer model introduced in OmniGen. 
@@ -343,3 +343,4 @@ def forward(self, + diff --git a/test.py b/test.py new file mode 100644 index 000000000000..e110b93a7b74 --- /dev/null +++ b/test.py @@ -0,0 +1,52 @@ +import os +os.environ['HF_HUB_CACHE'] = '/share/shitao/downloaded_models2' + +from huggingface_hub import snapshot_download + +from diffusers.models import OmniGenTransformerModel +from transformers import Phi3Model, Phi3Config + + +from safetensors.torch import load_file + +model_name = "Shitao/OmniGen-v1" +config = Phi3Config.from_pretrained("Shitao/OmniGen-v1") +model = OmniGenTransformerModel(transformer_config=config) +cache_folder = os.getenv('HF_HUB_CACHE') +model_name = snapshot_download(repo_id=model_name, + cache_dir=cache_folder, + ignore_patterns=['flax_model.msgpack', 'rust_model.ot', 'tf_model.h5']) +print(model_name) +model_path = os.path.join(model_name, 'model.safetensors') +ckpt = load_file(model_path, 'cpu') + + +mapping_dict = { + "pos_embed": "patch_embedding.pos_embed", + "x_embedder.proj.weight": "patch_embedding.output_image_proj.weight", + "x_embedder.proj.bias": "patch_embedding.output_image_proj.bias", + "input_x_embedder.proj.weight": "patch_embedding.input_image_proj.weight", + "input_x_embedder.proj.bias": "patch_embedding.input_image_proj.bias", + "final_layer.adaLN_modulation.1.weight": "norm_out.linear.weight", + "final_layer.adaLN_modulation.1.bias": "norm_out.linear.bias", + "final_layer.linear.weight": "proj_out.weight", + "final_layer.linear.bias": "proj_out.bias", + +} + +new_ckpt = {} +for k, v in ckpt.items(): + # new_ckpt[k] = v + if k in mapping_dict: + new_ckpt[mapping_dict[k]] = v + else: + new_ckpt[k] = v + + + +model.load_state_dict(new_ckpt) + + + + + From b839590eb3a629cd1f87f84b8128593ddbca9adf Mon Sep 17 00:00:00 2001 From: shitao <2906698981@qq.com> Date: Mon, 2 Dec 2024 17:34:58 +0800 Subject: [PATCH 03/55] omnigen pipeline --- scripts/convert_omnigen_to_diffusers.py | 73 +++ src/diffusers/__init__.py | 6 +- src/diffusers/models/__init__.py | 4 +- src/diffusers/models/embeddings.py | 17 +- src/diffusers/models/transformers/__init__.py | 2 +- .../transformers/transformer_omnigen.py | 149 +++-- .../pipelines/lumina/pipeline_lumina.py | 4 +- src/diffusers/pipelines/omnigen/__init__.py | 37 +- .../pipelines/omnigen/kvcache_omnigen.py | 106 ++++ .../pipelines/omnigen/pipeline_omnigen.py | 533 ++++++------------ .../pipelines/omnigen/pipeline_output.py | 24 - .../pipelines/omnigen/processor_omnigen.py | 295 ++++++++++ src/diffusers/utils/dummy_pt_objects.py | 14 + .../dummy_torch_and_transformers_objects.py | 15 + test.py | 4 +- 15 files changed, 832 insertions(+), 451 deletions(-) create mode 100644 scripts/convert_omnigen_to_diffusers.py create mode 100644 src/diffusers/pipelines/omnigen/kvcache_omnigen.py delete mode 100644 src/diffusers/pipelines/omnigen/pipeline_output.py create mode 100644 src/diffusers/pipelines/omnigen/processor_omnigen.py diff --git a/scripts/convert_omnigen_to_diffusers.py b/scripts/convert_omnigen_to_diffusers.py new file mode 100644 index 000000000000..2b02f6a064dc --- /dev/null +++ b/scripts/convert_omnigen_to_diffusers.py @@ -0,0 +1,73 @@ +import argparse +import os + +import torch +from safetensors.torch import load_file +from transformers import AutoModel, AutoTokenizer, AutoConfig + +from diffusers import AutoencoderKL, FlowMatchEulerDiscreteScheduler, OmniGenTransformer2DModel, OmniGenPipeline + + +def main(args): + # checkpoint from https://huggingface.co/Shitao/OmniGen-v1 + ckpt = load_file(args.origin_ckpt_path, device="cpu") + + 
mapping_dict = { + "pos_embed": "patch_embedding.pos_embed", + "x_embedder.proj.weight": "patch_embedding.output_image_proj.weight", + "x_embedder.proj.bias": "patch_embedding.output_image_proj.bias", + "input_x_embedder.proj.weight": "patch_embedding.input_image_proj.weight", + "input_x_embedder.proj.bias": "patch_embedding.input_image_proj.bias", + "final_layer.adaLN_modulation.1.weight": "norm_out.linear.weight", + "final_layer.adaLN_modulation.1.bias": "norm_out.linear.bias", + "final_layer.linear.weight": "proj_out.weight", + "final_layer.linear.bias": "proj_out.bias", + + } + + converted_state_dict = {} + for k, v in ckpt.items(): + # new_ckpt[k] = v + if k in mapping_dict: + converted_state_dict[mapping_dict[k]] = v + else: + converted_state_dict[k] = v + + transformer_config = AutoConfig.from_pretrained(args.origin_ckpt_path) + + # Lumina-Next-SFT 2B + transformer = OmniGenTransformer2DModel( + transformer_config=transformer_config, + patch_size=2, + in_channels=4, + pos_embed_max_size=192, + ) + transformer.load_state_dict(converted_state_dict, strict=True) + + num_model_params = sum(p.numel() for p in transformer.parameters()) + print(f"Total number of transformer parameters: {num_model_params}") + + scheduler = FlowMatchEulerDiscreteScheduler() + + vae = AutoencoderKL.from_pretrained(args.origin_ckpt_path, torch_dtype=torch.float32) + + tokenizer = AutoTokenizer.from_pretrained(args.origin_ckpt_path) + + + pipeline = OmniGenPipeline( + tokenizer=tokenizer, transformer=transformer, vae=vae, scheduler=scheduler + ) + pipeline.save_pretrained(args.dump_path) + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + + parser.add_argument( + "--origin_ckpt_path", default=None, type=str, required=False, help="Path to the checkpoint to convert." 
+ ) + + parser.add_argument("--dump_path", default=None, type=str, required=True, help="Path to the output pipeline.") + + args = parser.parse_args() + main(args) diff --git a/src/diffusers/__init__.py b/src/diffusers/__init__.py index 2d5cd10fd66f..c7868cf1c304 100644 --- a/src/diffusers/__init__.py +++ b/src/diffusers/__init__.py @@ -108,7 +108,7 @@ "MotionAdapter", "MultiAdapter", "MultiControlNetModel", - "OmniGenTransformerModel", + "OmniGenTransformer2DModel", "PixArtTransformer2DModel", "PriorTransformer", "SD3ControlNetModel", @@ -321,6 +321,7 @@ "MarigoldNormalsPipeline", "MochiPipeline", "MusicLDMPipeline", + "OmniGenPipeline", "PaintByExamplePipeline", "PIAPipeline", "PixArtAlphaPipeline", @@ -600,7 +601,7 @@ MotionAdapter, MultiAdapter, MultiControlNetModel, - OmniGenTransformerModel, + OmniGenTransformer2DModel, PixArtTransformer2DModel, PriorTransformer, SD3ControlNetModel, @@ -792,6 +793,7 @@ MarigoldNormalsPipeline, MochiPipeline, MusicLDMPipeline, + OmniGenPipeline, PaintByExamplePipeline, PIAPipeline, PixArtAlphaPipeline, diff --git a/src/diffusers/models/__init__.py b/src/diffusers/models/__init__.py index faf9b97e827c..3ce71406059b 100644 --- a/src/diffusers/models/__init__.py +++ b/src/diffusers/models/__init__.py @@ -66,7 +66,7 @@ _import_structure["transformers.transformer_mochi"] = ["MochiTransformer3DModel"] _import_structure["transformers.transformer_sd3"] = ["SD3Transformer2DModel"] _import_structure["transformers.transformer_temporal"] = ["TransformerTemporalModel"] - _import_structure["transformers.transformer_omnigen"] = ["OmniGenTransformerModel"] + _import_structure["transformers.transformer_omnigen"] = ["OmniGenTransformer2DModel"] _import_structure["unets.unet_1d"] = ["UNet1DModel"] _import_structure["unets.unet_2d"] = ["UNet2DModel"] _import_structure["unets.unet_2d_condition"] = ["UNet2DConditionModel"] @@ -126,7 +126,7 @@ LatteTransformer3DModel, LuminaNextDiT2DModel, MochiTransformer3DModel, - OmniGenTransformerModel, + OmniGenTransformer2DModel, PixArtTransformer2DModel, PriorTransformer, SD3Transformer2DModel, diff --git a/src/diffusers/models/embeddings.py b/src/diffusers/models/embeddings.py index 1eedf55933e0..9681a84b8878 100644 --- a/src/diffusers/models/embeddings.py +++ b/src/diffusers/models/embeddings.py @@ -351,7 +351,20 @@ def patch_embeddings(self, latent, is_input_image: bool): latent = latent.flatten(2).transpose(1, 2) return latent - def forward(self, latent, is_input_image: bool, padding_latent=None): + def forward(self, + latent: torch.Tensor, + is_input_image: bool, + padding_latent: torch.Tensor = None + ): + """ + Args: + latent: + is_input_image: + padding_latent: When sizes of target images are inconsistent, use `padding_latent` to maintain consistent sequence length. 
+ + Returns: torch.Tensor + + """ if isinstance(latent, list): if padding_latent is None: padding_latent = [None] * len(latent) @@ -362,7 +375,7 @@ def forward(self, latent, is_input_image: bool, padding_latent=None): pos_embed = self.cropped_pos_embed(height, width) sub_latent = sub_latent + pos_embed if padding is not None: - sub_latent = torch.cat([sub_latent, padding], dim=-2) + sub_latent = torch.cat([sub_latent, padding.to(sub_latent.device)], dim=-2) patched_latents.append(sub_latent) else: height, width = latent.shape[-2:] diff --git a/src/diffusers/models/transformers/__init__.py b/src/diffusers/models/transformers/__init__.py index 7d0fd364a17d..9770ded5e31e 100644 --- a/src/diffusers/models/transformers/__init__.py +++ b/src/diffusers/models/transformers/__init__.py @@ -20,4 +20,4 @@ from .transformer_mochi import MochiTransformer3DModel from .transformer_sd3 import SD3Transformer2DModel from .transformer_temporal import TransformerTemporalModel - from .transformer_omnigen import OmniGenTransformerModel + from .transformer_omnigen import OmniGenTransformer2DModel diff --git a/src/diffusers/models/transformers/transformer_omnigen.py b/src/diffusers/models/transformers/transformer_omnigen.py index 2ea18512724c..2fa97146f1d3 100644 --- a/src/diffusers/models/transformers/transformer_omnigen.py +++ b/src/diffusers/models/transformers/transformer_omnigen.py @@ -12,29 +12,47 @@ # See the License for the specific language governing permissions and # limitations under the License. -from typing import Dict, Optional, Tuple, Union, List +from typing import Any, Dict, Optional, Tuple, Union, List +from dataclasses import dataclass import torch -from torch import nn import torch.utils.checkpoint - -from transformers.modeling_outputs import BaseModelOutputWithPast +from torch import nn +from transformers.cache_utils import DynamicCache from transformers import Phi3Model, Phi3Config from transformers.cache_utils import Cache, DynamicCache +from transformers.modeling_outputs import BaseModelOutputWithPast from ...configuration_utils import ConfigMixin, register_to_config from ...loaders import PeftAdapterMixin -from ...utils import logging +from ...utils import USE_PEFT_BACKEND, logging, scale_lora_layers from ..attention_processor import AttentionProcessor -from ..normalization import AdaLayerNorm from ..embeddings import OmniGenPatchEmbed, OmniGenTimestepEmbed +from ..modeling_outputs import Transformer2DModelOutput from ..modeling_utils import ModelMixin +from ..normalization import AdaLayerNorm logger = logging.get_logger(__name__) # pylint: disable=invalid-name +@dataclass +class OmniGen2DModelOutput(Transformer2DModelOutput): + """ + The output of [`Transformer2DModel`]. + + Args: + sample (`torch.Tensor` of shape `(batch_size, num_channels, height, width)` or `(batch size, num_vector_embeds - 1, num_latent_pixels)` if [`Transformer2DModel`] is discrete): + The hidden states output conditioned on the `encoder_hidden_states` input. If discrete, returns probability + distributions for the unnoised latent pixels. + past_key_values (`transformers.cache_utils.DynamicCache`) + """ + + sample: "torch.Tensor" # noqa: F821 + past_key_values: "DynamicCache" + + class OmniGenBaseTransformer(Phi3Model): """ Transformer used in OmniGen. The transformer block is from Ph3, and only modify the attention mask. 
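The `padding_latent` argument documented in the hunk above exists because OmniGen can batch target latents of different spatial sizes into one sequence. A minimal sketch of the token arithmetic involved, assuming the `patch_size=2` and 3072-dim hidden size used elsewhere in this patch (the shapes below are illustrative only):

import torch

patch_size, embed_dim = 2, 3072
latent_a = torch.randn(1, 4, 64, 64)  # (64/2) * (64/2) = 1024 patch tokens
latent_b = torch.randn(1, 4, 48, 64)  # (48/2) * (64/2) =  768 patch tokens

tokens_a = (latent_a.shape[-2] // patch_size) * (latent_a.shape[-1] // patch_size)
tokens_b = (latent_b.shape[-2] // patch_size) * (latent_b.shape[-1] // patch_size)

# The shorter sample gets zero embeddings appended (concatenated on dim=-2 above)
# so every element of the batch ends up with the same sequence length.
padding_latent = torch.zeros(1, tokens_a - tokens_b, embed_dim)
print(tokens_a, tokens_b, padding_latent.shape)  # 1024 768 torch.Size([1, 256, 3072])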
@@ -44,6 +62,37 @@ class OmniGenBaseTransformer(Phi3Model): config: Phi3Config """ + def prefetch_layer(self, layer_idx: int, device: torch.device): + "Starts prefetching the next layer cache" + with torch.cuda.stream(self.prefetch_stream): + # Prefetch next layer tensors to GPU + for name, param in self.layers[layer_idx].named_parameters(): + param.data = param.data.to(device, non_blocking=True) + + def evict_previous_layer(self, layer_idx: int): + "Moves the previous layer cache to the CPU" + prev_layer_idx = layer_idx - 1 + for name, param in self.layers[prev_layer_idx].named_parameters(): + param.data = param.data.to("cpu", non_blocking=True) + + def get_offload_layer(self, layer_idx: int, device: torch.device): + # init stream + if not hasattr(self, "prefetch_stream"): + self.prefetch_stream = torch.cuda.Stream() + + # delete previous layer + # main stream sync shouldn't be necessary since all computation on iter i-1 is finished by iter i + # torch.cuda.current_stream().synchronize() + # avoid extra eviction of last layer + if layer_idx > 0: + self.evict_previous_layer(layer_idx) + + # make sure the current layer is ready + self.prefetch_stream.synchronize() + + # load next layer + self.prefetch_layer((layer_idx + 1) % len(self.layers), device) + def forward( self, input_ids: torch.LongTensor = None, @@ -56,7 +105,7 @@ def forward( output_hidden_states: Optional[bool] = None, return_dict: Optional[bool] = None, cache_position: Optional[torch.LongTensor] = None, - offload_model: Optional[bool] = False, + offload_transformer_block: Optional[bool] = False, ) -> Union[Tuple, BaseModelOutputWithPast]: output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions output_hidden_states = ( @@ -124,6 +173,13 @@ def forward( cache_position, ) else: + if offload_transformer_block and not self.training: + if not not torch.cuda.is_available(): + logger.warning_once( + "We don't detecte any available GPU, so diable `offload_transformer_block`" + ) + else: + self.get_offload_layer(layer_idx, device=inputs_embeds.device) layer_outputs = decoder_layer( hidden_states, attention_mask=attention_mask, @@ -162,7 +218,7 @@ def forward( ) -class OmniGenTransformerModel(ModelMixin, ConfigMixin, PeftAdapterMixin): +class OmniGenTransformer2DModel(ModelMixin, ConfigMixin, PeftAdapterMixin): """ The Transformer model introduced in OmniGen. 
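The offloading hooks above overlap weight transfers with compute: the next decoder layer is prefetched on a side CUDA stream while the current layer runs, and finished layers are evicted back to the CPU. A self-contained sketch of that pattern — plain `nn.Linear` layers stand in for the Phi-3 decoder blocks, so this is only the idea, not the patch's code path:

import torch
from torch import nn

layers = nn.ModuleList([nn.Linear(1024, 1024) for _ in range(4)])  # stand-ins for decoder blocks

if torch.cuda.is_available():
    device = torch.device("cuda")
    prefetch_stream = torch.cuda.Stream()
    x = torch.randn(8, 1024, device=device)

    # Warm-up: the first layer has to be on the GPU before the loop starts.
    layers[0].to(device)
    for idx, layer in enumerate(layers):
        # Start copying the *next* layer's weights on the side stream so the copy
        # overlaps with this layer's compute on the default stream.
        with torch.cuda.stream(prefetch_stream):
            for p in layers[(idx + 1) % len(layers)].parameters():
                p.data = p.data.to(device, non_blocking=True)

        x = layer(x)

        # Evict the layer we just used, then wait for the prefetch to finish
        # before the next iteration touches the prefetched weights.
        for p in layer.parameters():
            p.data = p.data.to("cpu", non_blocking=True)
        prefetch_stream.synchronize()

    print(x.shape)  # torch.Size([8, 1024])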
@@ -197,7 +253,10 @@ def __init__( hidden_size = transformer_config.hidden_size - self.patch_embedding = OmniGenPatchEmbed(patch_size=patch_size, in_channels=in_channels, embed_dim=hidden_size, pos_embed_max_size=pos_embed_max_size) + self.patch_embedding = OmniGenPatchEmbed(patch_size=patch_size, + in_channels=in_channels, + embed_dim=hidden_size, + pos_embed_max_size=pos_embed_max_size) self.time_token = OmniGenTimestepEmbed(hidden_size) self.t_embedder = OmniGenTimestepEmbed(hidden_size) @@ -208,7 +267,6 @@ def __init__( self.llm = OmniGenBaseTransformer(config=transformer_config) self.llm.config.use_cache = False - def unpatchify(self, x, h, w): """ x: (N, T, patch_size**2 * C) @@ -222,11 +280,10 @@ def unpatchify(self, x, h, w): imgs = x.reshape(shape=(x.shape[0], c, h, w)) return imgs - - def prepare_condition_embeddings(self, input_ids, input_img_latents, input_image_sizes, padding_latent): + def prepare_condition_embeddings(self, input_ids, input_img_latents, input_image_sizes): condition_embeds = None if input_img_latents is not None: - input_latents = self.patch_embedding(input_img_latents, is_input_images=True, padding_latent=padding_latent) + input_latents = self.patch_embedding(input_img_latents, is_input_images=True) if input_ids is not None: condition_embeds = self.llm.embed_tokens(input_ids).clone() input_img_inx = 0 @@ -303,44 +360,54 @@ def _set_gradient_checkpointing(self, module, value=False): module.gradient_checkpointing = value def forward(self, - hidden_states, - timestep, - input_ids, - input_img_latents, - input_image_sizes, - attention_mask, - position_ids, - padding_latent=None, - past_key_values=None, - return_past_key_values=True, - offload_model: bool = False): - - height, width = hidden_states.size(-2) + hidden_states: torch.Tensor, + timestep: Union[int, float, torch.LongTensor], + condition_tokens: torch.Tensor, + attention_mask: torch.Tensor, + position_ids: torch.Tensor, + past_key_values: DynamicCache = None, + offload_transformer_block: bool = False, + attention_kwargs: Optional[Dict[str, Any]] = None, + return_dict: bool = True, + ): + + if attention_kwargs is not None: + attention_kwargs = attention_kwargs.copy() + lora_scale = attention_kwargs.pop("scale", 1.0) + else: + lora_scale = 1.0 + + if USE_PEFT_BACKEND: + # weight the lora layers by setting `lora_scale` for each PEFT layer + scale_lora_layers(self, lora_scale) + else: + if attention_kwargs is not None and attention_kwargs.get("scale", None) is not None: + logger.warning( + "Passing `scale` via `attention_kwargs` when not using the PEFT backend is ineffective." 
+ ) + + height, width = hidden_states.size(-2) hidden_states = self.patch_embedding(hidden_states, is_input_image=False) num_tokens_for_output_image = hidden_states.size(1) time_token = self.time_token(timestep, dtype=hidden_states.dtype).unsqueeze(1) - condition_embeds = self.prepare_condition_embeddings(input_ids, input_img_latents, input_image_sizes, padding_latent) - if condition_embeds is not None: - input_emb = torch.cat([condition_embeds, time_token, hidden_states], dim=1) + if condition_tokens is not None: + input_emb = torch.cat([condition_tokens, time_token, hidden_states], dim=1) else: input_emb = torch.cat([time_token, hidden_states], dim=1) - output = self.llm(inputs_embeds=input_emb, attention_mask=attention_mask, position_ids=position_ids, - past_key_values=past_key_values, offload_model=offload_model) + output = self.llm(inputs_embeds=input_emb, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_values=past_key_values, + offload_transformer_block=offload_transformer_block) output, past_key_values = output.last_hidden_state, output.past_key_values image_embedding = output[:, -num_tokens_for_output_image:] time_emb = self.t_embedder(timestep, dtype=hidden_states.dtype) x = self.final_layer(image_embedding, time_emb) - latents = self.unpatchify(x, height, width) - - if return_past_key_values: - return latents, past_key_values - return latents - - - - - + output = self.unpatchify(x, height, width) + if not return_dict: + return (output, past_key_values) + return OmniGen2DModelOutput(sample=output, past_key_values=past_key_values) diff --git a/src/diffusers/pipelines/lumina/pipeline_lumina.py b/src/diffusers/pipelines/lumina/pipeline_lumina.py index 018f2e8bf1bc..296ae5303b20 100644 --- a/src/diffusers/pipelines/lumina/pipeline_lumina.py +++ b/src/diffusers/pipelines/lumina/pipeline_lumina.py @@ -25,7 +25,7 @@ from ...image_processor import VaeImageProcessor from ...models import AutoencoderKL from ...models.embeddings import get_2d_rotary_pos_embed_lumina -from ...models.transformers.lumina_nextdit2d import LuminaNextDiT2DModel +from ...models.transformers.lumina_nextdit2d import LuminaextDiT2DModel from ...schedulers import FlowMatchEulerDiscreteScheduler from ...utils import ( BACKENDS_MAPPING, @@ -777,7 +777,7 @@ def __call__( # 4. Prepare timesteps timesteps, num_inference_steps = retrieve_timesteps( - self.scheduler, num_inference_steps, device, timesteps, sigmas + self.scheduler, num_inference_steps, device, timesteps, ) # 5. Prepare latents. 
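For reference, the final step of the forward pass above — taking the trailing image tokens from the LLM output and turning them back into a latent grid — is a DiT-style unpatchify. A shape-only sketch with illustrative dimensions, which should correspond to the reshape performed by the `unpatchify` helper shown in the earlier hunk:

import torch

batch, channels, patch = 1, 4, 2
height, width = 64, 64                                 # latent height/width
num_img_tokens = (height // patch) * (width // patch)  # 1024 tokens kept from the LLM output
hidden = torch.randn(batch, num_img_tokens, patch * patch * channels)

# Fold (N, T, p*p*C) patch tokens back into an (N, C, H, W) latent.
x = hidden.reshape(batch, height // patch, width // patch, patch, patch, channels)
x = torch.einsum("nhwpqc->nchpwq", x)
latents = x.reshape(batch, channels, height, width)
print(latents.shape)  # torch.Size([1, 4, 64, 64])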
diff --git a/src/diffusers/pipelines/omnigen/__init__.py b/src/diffusers/pipelines/omnigen/__init__.py index 3570368a5ca1..557e7c08dc22 100644 --- a/src/diffusers/pipelines/omnigen/__init__.py +++ b/src/diffusers/pipelines/omnigen/__init__.py @@ -11,8 +11,8 @@ _dummy_objects = {} -_additional_imports = {} -_import_structure = {"pipeline_output": ["FluxPipelineOutput", "FluxPriorReduxPipelineOutput"]} +_import_structure = {} + try: if not (is_transformers_available() and is_torch_available()): @@ -22,35 +22,20 @@ _dummy_objects.update(get_objects_from_module(dummy_torch_and_transformers_objects)) else: - _import_structure["modeling_flux"] = ["ReduxImageEncoder"] - _import_structure["pipeline_flux"] = ["FluxPipeline"] - _import_structure["pipeline_flux_control"] = ["FluxControlPipeline"] - _import_structure["pipeline_flux_control_img2img"] = ["FluxControlImg2ImgPipeline"] - _import_structure["pipeline_flux_controlnet"] = ["FluxControlNetPipeline"] - _import_structure["pipeline_flux_controlnet_image_to_image"] = ["FluxControlNetImg2ImgPipeline"] - _import_structure["pipeline_flux_controlnet_inpainting"] = ["FluxControlNetInpaintPipeline"] - _import_structure["pipeline_flux_fill"] = ["FluxFillPipeline"] - _import_structure["pipeline_flux_img2img"] = ["FluxImg2ImgPipeline"] - _import_structure["pipeline_flux_inpaint"] = ["FluxInpaintPipeline"] - _import_structure["pipeline_flux_prior_redux"] = ["FluxPriorReduxPipeline"] + _import_structure["pipeline_omnigen"] = ["OmniGenPipeline"] + + if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT: try: if not (is_transformers_available() and is_torch_available()): raise OptionalDependencyNotAvailable() + except OptionalDependencyNotAvailable: - from ...utils.dummy_torch_and_transformers_objects import * # noqa F403 + from ...utils.dummy_torch_and_transformers_objects import * else: - from .modeling_flux import ReduxImageEncoder - from .pipeline_flux import FluxPipeline - from .pipeline_flux_control import FluxControlPipeline - from .pipeline_flux_control_img2img import FluxControlImg2ImgPipeline - from .pipeline_flux_controlnet import FluxControlNetPipeline - from .pipeline_flux_controlnet_image_to_image import FluxControlNetImg2ImgPipeline - from .pipeline_flux_controlnet_inpainting import FluxControlNetInpaintPipeline - from .pipeline_flux_fill import FluxFillPipeline - from .pipeline_flux_img2img import FluxImg2ImgPipeline - from .pipeline_flux_inpaint import FluxInpaintPipeline - from .pipeline_flux_prior_redux import FluxPriorReduxPipeline + from .pipeline_omnigen import OmniGenPipeline + + else: import sys @@ -63,5 +48,3 @@ for name, value in _dummy_objects.items(): setattr(sys.modules[__name__], name, value) - for name, value in _additional_imports.items(): - setattr(sys.modules[__name__], name, value) diff --git a/src/diffusers/pipelines/omnigen/kvcache_omnigen.py b/src/diffusers/pipelines/omnigen/kvcache_omnigen.py new file mode 100644 index 000000000000..0270292c130f --- /dev/null +++ b/src/diffusers/pipelines/omnigen/kvcache_omnigen.py @@ -0,0 +1,106 @@ +from typing import Optional, Dict, Any, Tuple, List + +import torch +from transformers.cache_utils import DynamicCache + + +class OmniGenCache(DynamicCache): + def __init__(self, + num_tokens_for_img: int, offload_kv_cache: bool = False) -> None: + if not torch.cuda.is_available(): + raise RuntimeError( + "OmniGenCache can only be used with a GPU. 
If there is no GPU, you need to set use_kv_cache=False, which will result in longer inference time!") + super().__init__() + self.original_device = [] + self.prefetch_stream = torch.cuda.Stream() + self.num_tokens_for_img = num_tokens_for_img + self.offload_kv_cache = offload_kv_cache + + def prefetch_layer(self, layer_idx: int): + "Starts prefetching the next layer cache" + if layer_idx < len(self): + with torch.cuda.stream(self.prefetch_stream): + # Prefetch next layer tensors to GPU + device = self.original_device[layer_idx] + self.key_cache[layer_idx] = self.key_cache[layer_idx].to(device, non_blocking=True) + self.value_cache[layer_idx] = self.value_cache[layer_idx].to(device, non_blocking=True) + + def evict_previous_layer(self, layer_idx: int): + "Moves the previous layer cache to the CPU" + if len(self) > 2: + # We do it on the default stream so it occurs after all earlier computations on these tensors are done + if layer_idx == 0: + prev_layer_idx = -1 + else: + prev_layer_idx = (layer_idx - 1) % len(self) + self.key_cache[prev_layer_idx] = self.key_cache[prev_layer_idx].to("cpu", non_blocking=True) + self.value_cache[prev_layer_idx] = self.value_cache[prev_layer_idx].to("cpu", non_blocking=True) + + def __getitem__(self, layer_idx: int) -> List[Tuple[torch.Tensor]]: + "Gets the cache for this layer to the device. Prefetches the next and evicts the previous layer." + if layer_idx < len(self): + if self.offload_kv_cache: + # Evict the previous layer if necessary + torch.cuda.current_stream().synchronize() + self.evict_previous_layer(layer_idx) + # Load current layer cache to its original device if not already there + # original_device = self.original_device[layer_idx] + # self.prefetch_stream.synchronize(original_device) + self.prefetch_stream.synchronize() + key_tensor = self.key_cache[layer_idx] + value_tensor = self.value_cache[layer_idx] + + # Prefetch the next layer + self.prefetch_layer((layer_idx + 1) % len(self)) + else: + key_tensor = self.key_cache[layer_idx] + value_tensor = self.value_cache[layer_idx] + return (key_tensor, value_tensor) + else: + raise KeyError(f"Cache only has {len(self)} layers, attempted to access layer with index {layer_idx}") + + def update( + self, + key_states: torch.Tensor, + value_states: torch.Tensor, + layer_idx: int, + cache_kwargs: Optional[Dict[str, Any]] = None, + ) -> Tuple[torch.Tensor, torch.Tensor]: + """ + Updates the cache with the new `key_states` and `value_states` for the layer `layer_idx`. + Parameters: + key_states (`torch.Tensor`): + The new key states to cache. + value_states (`torch.Tensor`): + The new value states to cache. + layer_idx (`int`): + The index of the layer to cache the states for. + cache_kwargs (`Dict[str, Any]`, `optional`): + Additional arguments for the cache subclass. No additional arguments are used in `OffloadedCache`. + Return: + A tuple containing the updated key and value states. + """ + # Update the cache + if len(self.key_cache) < layer_idx: + raise ValueError("OffloadedCache does not support model usage where layers are skipped. 
Use DynamicCache.") + elif len(self.key_cache) == layer_idx: + # only cache the states for condition tokens + key_states = key_states[..., :-(self.num_tokens_for_img + 1), :] + value_states = value_states[..., :-(self.num_tokens_for_img + 1), :] + + # Update the number of seen tokens + if layer_idx == 0: + self._seen_tokens += key_states.shape[-2] + + self.key_cache.append(key_states) + self.value_cache.append(value_states) + self.original_device.append(key_states.device) + if self.offload_kv_cache: + self.evict_previous_layer(layer_idx) + return self.key_cache[layer_idx], self.value_cache[layer_idx] + else: + # only cache the states for condition tokens + key_tensor, value_tensor = self[layer_idx] + k = torch.cat([key_tensor, key_states], dim=-2) + v = torch.cat([value_tensor, value_states], dim=-2) + return k, v diff --git a/src/diffusers/pipelines/omnigen/pipeline_omnigen.py b/src/diffusers/pipelines/omnigen/pipeline_omnigen.py index f6f3752fe97d..9c9a67e53e9c 100644 --- a/src/diffusers/pipelines/omnigen/pipeline_omnigen.py +++ b/src/diffusers/pipelines/omnigen/pipeline_omnigen.py @@ -15,27 +15,23 @@ import inspect from typing import Any, Callable, Dict, List, Optional, Union -import numpy as np import torch -from transformers import CLIPTextModel, CLIPTokenizer, T5EncoderModel, T5TokenizerFast +from transformers import LlamaTokenizer from ...image_processor import VaeImageProcessor -from ...loaders import FluxLoraLoaderMixin, FromSingleFileMixin, TextualInversionLoaderMixin +from ...loaders import FromSingleFileMixin, TextualInversionLoaderMixin from ...models.autoencoders import AutoencoderKL -from ...models.transformers import FluxTransformer2DModel +from ...models.transformers import OmniGenTransformer2DModel from ...schedulers import FlowMatchEulerDiscreteScheduler from ...utils import ( - USE_PEFT_BACKEND, is_torch_xla_available, logging, replace_example_docstring, - scale_lora_layers, - unscale_lora_layers, ) from ...utils.torch_utils import randn_tensor -from ..pipeline_utils import DiffusionPipeline -from .pipeline_output import OmniGenPipelineOutput - +from ..pipeline_utils import DiffusionPipeline, ImagePipelineOutput +from .processor_omnigen import OmniGenMultiModalProcessor +from .kvcache_omnigen import OmniGenCache if is_torch_xla_available(): import torch_xla.core.xla_model as xm @@ -136,24 +132,15 @@ class OmniGenPipeline( Reference: https://arxiv.org/pdf/2409.11340 Args: - transformer ([`FluxTransformer2DModel`]): - Conditional Transformer (MMDiT) architecture to denoise the encoded image latents. + transformer ([`OmniGenTransformer2DModel`]): + Autoregressive Transformer architecture for OmniGen. scheduler ([`FlowMatchEulerDiscreteScheduler`]): A scheduler to be used in combination with `transformer` to denoise the encoded image latents. vae ([`AutoencoderKL`]): Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations. - text_encoder ([`CLIPTextModel`]): - [CLIP](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPTextModel), specifically - the [clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14) variant. - text_encoder_2 ([`T5EncoderModel`]): - [T5](https://huggingface.co/docs/transformers/en/model_doc/t5#transformers.T5EncoderModel), specifically - the [google/t5-v1_1-xxl](https://huggingface.co/google/t5-v1_1-xxl) variant. - tokenizer (`CLIPTokenizer`): - Tokenizer of class - [CLIPTokenizer](https://huggingface.co/docs/transformers/en/model_doc/clip#transformers.CLIPTokenizer). 
- tokenizer_2 (`T5TokenizerFast`): - Second Tokenizer of class - [T5TokenizerFast](https://huggingface.co/docs/transformers/en/model_doc/t5#transformers.T5TokenizerFast). + tokenizer (`LlamaTokenizer`): + Text tokenizer of class. + [LlamaTokenizer](https://huggingface.co/docs/transformers/main/model_doc/llama#transformers.LlamaTokenizer). """ model_cpu_offload_seq = "transformer->vae" @@ -162,216 +149,110 @@ class OmniGenPipeline( def __init__( self, + transformer: OmniGenTransformer2DModel, scheduler: FlowMatchEulerDiscreteScheduler, vae: AutoencoderKL, - tokenizer: CLIPTokenizer, - transformer: FluxTransformer2DModel, + tokenizer: LlamaTokenizer, ): super().__init__() self.register_modules( vae=vae, - text_encoder=text_encoder, - text_encoder_2=text_encoder_2, tokenizer=tokenizer, - tokenizer_2=tokenizer_2, transformer=transformer, scheduler=scheduler, ) self.vae_scale_factor = ( 2 ** (len(self.vae.config.block_out_channels) - 1) if hasattr(self, "vae") and self.vae is not None else 8 ) - # Flux latents are turned into 2x2 patches and packed. This means the latent width and height has to be divisible + # OmniGen latents are turned into 2x2 patches and packed. This means the latent width and height has to be divisible # by the patch size. So the vae scale factor is multiplied by the patch size to account for this self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor * 2) + + self.multimodal_processor = OmniGenMultiModalProcessor(tokenizer, max_image_size=1024) self.tokenizer_max_length = ( - self.tokenizer.model_max_length if hasattr(self, "tokenizer") and self.tokenizer is not None else 77 + self.tokenizer.model_max_length if hasattr(self, "tokenizer") and self.tokenizer is not None else 120000 ) self.default_sample_size = 128 - def _get_t5_prompt_embeds( - self, - prompt: Union[str, List[str]] = None, - num_images_per_prompt: int = 1, - max_sequence_length: int = 512, - device: Optional[torch.device] = None, - dtype: Optional[torch.dtype] = None, - ): - device = device or self._execution_device - dtype = dtype or self.text_encoder.dtype - - prompt = [prompt] if isinstance(prompt, str) else prompt - batch_size = len(prompt) - - if isinstance(self, TextualInversionLoaderMixin): - prompt = self.maybe_convert_prompt(prompt, self.tokenizer_2) - - text_inputs = self.tokenizer_2( - prompt, - padding="max_length", - max_length=max_sequence_length, - truncation=True, - return_length=False, - return_overflowing_tokens=False, - return_tensors="pt", - ) - text_input_ids = text_inputs.input_ids - untruncated_ids = self.tokenizer_2(prompt, padding="longest", return_tensors="pt").input_ids - - if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(text_input_ids, untruncated_ids): - removed_text = self.tokenizer_2.batch_decode(untruncated_ids[:, self.tokenizer_max_length - 1 : -1]) - logger.warning( - "The following part of your input was truncated because `max_sequence_length` is set to " - f" {max_sequence_length} tokens: {removed_text}" - ) - - prompt_embeds = self.text_encoder_2(text_input_ids.to(device), output_hidden_states=False)[0] - - dtype = self.text_encoder_2.dtype - prompt_embeds = prompt_embeds.to(dtype=dtype, device=device) - - _, seq_len, _ = prompt_embeds.shape - - # duplicate text embeddings and attention mask for each generation per prompt, using mps friendly method - prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1) - prompt_embeds = prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1) - - return 
prompt_embeds - - def _get_clip_prompt_embeds( - self, - prompt: Union[str, List[str]], - num_images_per_prompt: int = 1, - device: Optional[torch.device] = None, - ): - device = device or self._execution_device - - prompt = [prompt] if isinstance(prompt, str) else prompt - batch_size = len(prompt) - - if isinstance(self, TextualInversionLoaderMixin): - prompt = self.maybe_convert_prompt(prompt, self.tokenizer) - - text_inputs = self.tokenizer( - prompt, - padding="max_length", - max_length=self.tokenizer_max_length, - truncation=True, - return_overflowing_tokens=False, - return_length=False, - return_tensors="pt", - ) - - text_input_ids = text_inputs.input_ids - untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids - if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(text_input_ids, untruncated_ids): - removed_text = self.tokenizer.batch_decode(untruncated_ids[:, self.tokenizer_max_length - 1 : -1]) - logger.warning( - "The following part of your input was truncated because CLIP can only handle sequences up to" - f" {self.tokenizer_max_length} tokens: {removed_text}" - ) - prompt_embeds = self.text_encoder(text_input_ids.to(device), output_hidden_states=False) - - # Use pooled output of CLIPTextModel - prompt_embeds = prompt_embeds.pooler_output - prompt_embeds = prompt_embeds.to(dtype=self.text_encoder.dtype, device=device) - - # duplicate text embeddings for each generation per prompt, using mps friendly method - prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt) - prompt_embeds = prompt_embeds.view(batch_size * num_images_per_prompt, -1) - - return prompt_embeds - - def encode_prompt( - self, - prompt: Union[str, List[str]], - prompt_2: Union[str, List[str]], - device: Optional[torch.device] = None, - num_images_per_prompt: int = 1, - prompt_embeds: Optional[torch.FloatTensor] = None, - pooled_prompt_embeds: Optional[torch.FloatTensor] = None, - max_sequence_length: int = 512, - lora_scale: Optional[float] = None, + def encod_input_iamges( + self, + input_pixel_values: List[torch.Tensor], + device: Optional[torch.device] = None, ): - r""" - + """ + get the continues embedding of input images by VAE Args: - prompt (`str` or `List[str]`, *optional*): - prompt to be encoded - prompt_2 (`str` or `List[str]`, *optional*): - The prompt or prompts to be sent to the `tokenizer_2` and `text_encoder_2`. If not defined, `prompt` is - used in all text-encoders - device: (`torch.device`): - torch device - num_images_per_prompt (`int`): - number of images that should be generated per prompt - prompt_embeds (`torch.FloatTensor`, *optional*): - Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not - provided, text embeddings will be generated from `prompt` input argument. - pooled_prompt_embeds (`torch.FloatTensor`, *optional*): - Pre-generated pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. - If not provided, pooled text embeddings will be generated from `prompt` input argument. - lora_scale (`float`, *optional*): - A lora scale that will be applied to all LoRA layers of the text encoder if LoRA layers are loaded. 
+ input_pixel_values: normlized pixel of input images + device: + Returns: torch.Tensor """ device = device or self._execution_device + dtype = self.vae.dtype + + input_img_latents = [] + for img in input_pixel_values: + img = self.vae.encode(img.to(device, dtype)).latent_dist.sample().mul_(self.vae.config.scaling_factor) + input_img_latents.append(img) + return input_img_latents + + def get_multimodal_embeddings(self, + input_ids: torch.Tensor, + input_img_latents: List[torch.Tensor], + input_image_sizes: Dict, + device: Optional[torch.device] = None, + ): + """ + get the multi-modal conditional embeddings + Args: + input_ids: a sequence of text id + input_img_latents: continues embedding of input images + input_image_sizes: the index of the input image in the input_ids sequence. + device: - # set lora scale so that monkey patched LoRA - # function of text encoder can correctly access it - if lora_scale is not None and isinstance(self, FluxLoraLoaderMixin): - self._lora_scale = lora_scale - - # dynamically adjust the LoRA scale - if self.text_encoder is not None and USE_PEFT_BACKEND: - scale_lora_layers(self.text_encoder, lora_scale) - if self.text_encoder_2 is not None and USE_PEFT_BACKEND: - scale_lora_layers(self.text_encoder_2, lora_scale) - - prompt = [prompt] if isinstance(prompt, str) else prompt - - if prompt_embeds is None: - prompt_2 = prompt_2 or prompt - prompt_2 = [prompt_2] if isinstance(prompt_2, str) else prompt_2 - - # We only use the pooled prompt output from the CLIPTextModel - pooled_prompt_embeds = self._get_clip_prompt_embeds( - prompt=prompt, - device=device, - num_images_per_prompt=num_images_per_prompt, - ) - prompt_embeds = self._get_t5_prompt_embeds( - prompt=prompt_2, - num_images_per_prompt=num_images_per_prompt, - max_sequence_length=max_sequence_length, - device=device, - ) + Returns: torch.Tensor - if self.text_encoder is not None: - if isinstance(self, FluxLoraLoaderMixin) and USE_PEFT_BACKEND: - # Retrieve the original scale by scaling back the LoRA layers - unscale_lora_layers(self.text_encoder, lora_scale) + """ + device = device or self._execution_device - if self.text_encoder_2 is not None: - if isinstance(self, FluxLoraLoaderMixin) and USE_PEFT_BACKEND: - # Retrieve the original scale by scaling back the LoRA layers - unscale_lora_layers(self.text_encoder_2, lora_scale) + condition_tokens = None + if input_ids is not None: + condition_tokens = self.transformer.llm.embed_tokens(input_ids.to(device)) + input_img_inx = 0 + if input_img_latents is not None: + input_image_tokens = self.transformer.patch_embedding(input_img_latents, + is_input_images=True) - dtype = self.text_encoder.dtype if self.text_encoder is not None else self.transformer.dtype - text_ids = torch.zeros(prompt_embeds.shape[1], 3).to(device=device, dtype=dtype) + for b_inx in input_image_sizes.keys(): + for start_inx, end_inx in input_image_sizes[b_inx]: + # replace the placeholder in text tokens with the image embedding. 
+ condition_tokens[b_inx, start_inx: end_inx] = input_image_tokens[input_img_inx].to( + condition_tokens.dtype) + input_img_inx += 1 - return prompt_embeds, pooled_prompt_embeds, text_ids + return condition_tokens def check_inputs( self, prompt, - prompt_2, + input_images, height, width, - prompt_embeds=None, - pooled_prompt_embeds=None, callback_on_step_end_tensor_inputs=None, max_sequence_length=None, ): + + if input_images is not None: + if len(input_images) != len(prompt): + raise ValueError( + f"The number of prompts: {len(prompt)} does not match the number of input images: {len(input_images)}." + ) + for i in range(len(input_images)): + if not all(f"<|image_{k}|>" in prompt[i] for k in range(len(input_images[i]))): + raise ValueError( + f"prompt `{prompt[i]}` doesn't have enough placeholders for the input images `{input_images[i]}`" + ) + if height % (self.vae_scale_factor * 2) != 0 or width % (self.vae_scale_factor * 2) != 0: logger.warning( f"`height` and `width` have to be divisible by {self.vae_scale_factor * 2} but are {height} and {width}. Dimensions will be resized accordingly" @@ -384,33 +265,6 @@ def check_inputs( f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}" ) - if prompt is not None and prompt_embeds is not None: - raise ValueError( - f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to" - " only forward one of the two." - ) - elif prompt_2 is not None and prompt_embeds is not None: - raise ValueError( - f"Cannot forward both `prompt_2`: {prompt_2} and `prompt_embeds`: {prompt_embeds}. Please make sure to" - " only forward one of the two." - ) - elif prompt is None and prompt_embeds is None: - raise ValueError( - "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined." - ) - elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)): - raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}") - elif prompt_2 is not None and (not isinstance(prompt_2, str) and not isinstance(prompt_2, list)): - raise ValueError(f"`prompt_2` has to be of type `str` or `list` but is {type(prompt_2)}") - - if prompt_embeds is not None and pooled_prompt_embeds is None: - raise ValueError( - "If `prompt_embeds` are provided, `pooled_prompt_embeds` also have to be passed. Make sure to generate `pooled_prompt_embeds` from the same text encoder that was used to generate `prompt_embeds`." 
- ) - - if max_sequence_length is not None and max_sequence_length > 512: - raise ValueError(f"`max_sequence_length` cannot be greater than 512 but is {max_sequence_length}") - @staticmethod def _prepare_latent_image_ids(batch_size, height, width, device, dtype): latent_image_ids = torch.zeros(height, width, 3) @@ -517,10 +371,6 @@ def prepare_latents( def guidance_scale(self): return self._guidance_scale - @property - def joint_attention_kwargs(self): - return self._joint_attention_kwargs - @property def num_timesteps(self): return self._num_timesteps @@ -533,35 +383,38 @@ def interrupt(self): @replace_example_docstring(EXAMPLE_DOC_STRING) def __call__( self, - prompt: Union[str, List[str]] = None, - prompt_2: Optional[Union[str, List[str]]] = None, + prompt: Union[str, List[str]], + input_images: Optional[Union[List[str], List[List[str]]]] = None, height: Optional[int] = None, width: Optional[int] = None, - num_inference_steps: int = 28, + num_inference_steps: int = 50, + max_input_image_size: int = 1024, timesteps: List[int] = None, - guidance_scale: float = 3.5, + guidance_scale: float = 2.5, + img_guidance_scale: float = 1.6, + use_kv_cache: bool = True, + offload_kv_cache: bool = True, + offload_transformer_block: bool = False, + use_input_image_size_as_output: bool = False, num_images_per_prompt: Optional[int] = 1, generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, latents: Optional[torch.FloatTensor] = None, - prompt_embeds: Optional[torch.FloatTensor] = None, - pooled_prompt_embeds: Optional[torch.FloatTensor] = None, output_type: Optional[str] = "pil", return_dict: bool = True, - joint_attention_kwargs: Optional[Dict[str, Any]] = None, + attention_kwargs: Optional[Dict[str, Any]] = None, callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None, callback_on_step_end_tensor_inputs: List[str] = ["latents"], - max_sequence_length: int = 512, + max_sequence_length: int = 120000, ): r""" Function invoked when calling the pipeline for generation. Args: prompt (`str` or `List[str]`, *optional*): - The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`. - instead. - prompt_2 (`str` or `List[str]`, *optional*): - The prompt or prompts to be sent to `tokenizer_2` and `text_encoder_2`. If not defined, `prompt` is - will be used instead + The prompt or prompts to guide the image generation. + If the input includes images, need to add placeholders `<|image_i|>` in the prompt to indicate the position of the i-th images. + input_images (`List[str]` or `List[List[str]]`, *optional*): + The list of input images. We will replace the "<|image_i|>" in prompt with the i-th image in list. height (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor): The height in pixels of the generated image. This is set to 1024 by default for the best results. width (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor): @@ -569,16 +422,28 @@ def __call__( num_inference_steps (`int`, *optional*, defaults to 50): The number of denoising steps. More denoising steps usually lead to a higher quality image at the expense of slower inference. + max_input_image_size (`int`, *optional*, defaults to 1024): + the maximum size of input image, which will be used to crop the input image to the maximum size timesteps (`List[int]`, *optional*): Custom timesteps to use for the denoising process with schedulers which support a `timesteps` argument in their `set_timesteps` method. 
If not defined, the default behavior when `num_inference_steps` is passed will be used. Must be in descending order. - guidance_scale (`float`, *optional*, defaults to 7.0): + guidance_scale (`float`, *optional*, defaults to 2.5): Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598). `guidance_scale` is defined as `w` of equation 2. of [Imagen Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale > 1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`, usually at the expense of lower image quality. + img_guidance_scale (`float`, *optional*, defaults to 1.6): + Defined as equation 3 in [Instrucpix2pix](https://arxiv.org/pdf/2211.09800). + use_kv_cache (`bool`, *optional*, defaults to True): + enable kv cache to speed up the inference + offload_kv_cache (`bool`, *optional*, defaults to True): + offload the cached key and value to cpu, which can save memory but slow down the generation silightly + offload_transformer_block (`bool`, *optional*, defaults to False): + offload the transformer layers to cpu, which can save memory but slow down the generation + use_input_image_size_as_output (bool, defaults to False): + whether to use the input image size as the output image size, which can be used for single-image input, e.g., image editing task num_images_per_prompt (`int`, *optional*, defaults to 1): The number of images to generate per prompt. generator (`torch.Generator` or `List[torch.Generator]`, *optional*): @@ -588,18 +453,12 @@ def __call__( Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image generation. Can be used to tweak the same generation with different prompts. If not provided, a latents tensor will ge generated by sampling using the supplied random `generator`. - prompt_embeds (`torch.FloatTensor`, *optional*): - Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not - provided, text embeddings will be generated from `prompt` input argument. - pooled_prompt_embeds (`torch.FloatTensor`, *optional*): - Pre-generated pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. - If not provided, pooled text embeddings will be generated from `prompt` input argument. output_type (`str`, *optional*, defaults to `"pil"`): The output format of the generate image. Choose between [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`. return_dict (`bool`, *optional*, defaults to `True`): Whether or not to return a [`~pipelines.flux.FluxPipelineOutput`] instead of a plain tuple. - joint_attention_kwargs (`dict`, *optional*): + attention_kwargs (`dict`, *optional*): A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under `self.processor` in [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py). @@ -616,123 +475,116 @@ def __call__( Examples: - Returns: - [`~pipelines.flux.FluxPipelineOutput`] or `tuple`: [`~pipelines.flux.FluxPipelineOutput`] if `return_dict` - is True, otherwise a `tuple`. When returning a tuple, the first element is a list with the generated - images. 
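Taken together, the arguments documented above translate into calls like the following sketch (the checkpoint path and image file name are placeholders; a real checkpoint directory would come from `scripts/convert_omnigen_to_diffusers.py`):

import torch
from diffusers import OmniGenPipeline

# Hypothetical local path produced by the conversion script.
pipe = OmniGenPipeline.from_pretrained("./omnigen-diffusers", torch_dtype=torch.bfloat16).to("cuda")

# Text-to-image.
image = pipe(prompt="A photo of a red fox in the snow", guidance_scale=2.5).images[0]

# Image editing: `<|image_1|>` marks where the first entry of `input_images` is injected.
edited = pipe(
    prompt="<|image_1|> Make the sky in this picture stormy",
    input_images=["fox.png"],                  # placeholder file name
    guidance_scale=2.5,
    img_guidance_scale=1.6,
    use_input_image_size_as_output=True,       # keep the input resolution for editing
).images[0]
edited.save("edited.png")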
+ Returns: [`~pipelines.ImagePipelineOutput`] or `tuple`: + If `return_dict` is `True`, [`~pipelines.ImagePipelineOutput`] is returned, otherwise a `tuple` is returned + where the first element is a list with the generated images. """ height = height or self.default_sample_size * self.vae_scale_factor width = width or self.default_sample_size * self.vae_scale_factor + if isinstance(prompt, str): + prompt = [prompt] + input_images = [input_images] # 1. Check inputs. Raise error if not correct self.check_inputs( prompt, - prompt_2, + input_images, height, width, - prompt_embeds=prompt_embeds, - pooled_prompt_embeds=pooled_prompt_embeds, callback_on_step_end_tensor_inputs=callback_on_step_end_tensor_inputs, max_sequence_length=max_sequence_length, ) self._guidance_scale = guidance_scale - self._joint_attention_kwargs = joint_attention_kwargs + self._attention_kwargs = attention_kwargs self._interrupt = False # 2. Define call parameters - if prompt is not None and isinstance(prompt, str): - batch_size = 1 - elif prompt is not None and isinstance(prompt, list): - batch_size = len(prompt) - else: - batch_size = prompt_embeds.shape[0] - + batch_size = len(prompt) device = self._execution_device - lora_scale = ( - self.joint_attention_kwargs.get("scale", None) if self.joint_attention_kwargs is not None else None - ) - ( - prompt_embeds, - pooled_prompt_embeds, - text_ids, - ) = self.encode_prompt( - prompt=prompt, - prompt_2=prompt_2, - prompt_embeds=prompt_embeds, - pooled_prompt_embeds=pooled_prompt_embeds, - device=device, - num_images_per_prompt=num_images_per_prompt, - max_sequence_length=max_sequence_length, - lora_scale=lora_scale, + # 3. process multi-modal instructions + if max_input_image_size != self.multimodal_processor.max_image_size: + self.processor = OmniGenMultiModalProcessor(self.text_tokenizer, max_image_size=max_input_image_size) + processed_data = self.processor(prompt, + input_images, + height=height, + width=width, + use_input_image_size_as_output=use_input_image_size_as_output) + + # 4. Encode input images and obtain multi-modal conditional embeddings + input_img_latents = self.encod_input_iamges(processed_data['input_pixel_values'], device=device) + condition_tokens = self.get_multimodal_embeddings(input_ids=processed_data['input_ids'], + input_img_latents=input_img_latents, + input_image_sizes=processed_data['input_image_sizes'], + device=device, + ) + + # 5. Prepare timesteps + timesteps, num_inference_steps = retrieve_timesteps( + self.scheduler, num_inference_steps, device, timesteps, ) - # 4. Prepare latent variables - num_channels_latents = self.transformer.config.in_channels // 4 - latents, latent_image_ids = self.prepare_latents( + # 6. Prepare latents. + if use_input_image_size_as_output: + height, width = processed_data['input_pixel_values'][0].shape[-2:] + num_cfg = 2 if input_images is not None else 1 + latent_channels = self.transformer.config.in_channels + latents = self.prepare_latents( batch_size * num_images_per_prompt, - num_channels_latents, + latent_channels, height, width, - prompt_embeds.dtype, + condition_tokens.dtype, device, generator, latents, ) - # 5. 
Prepare timesteps - sigmas = np.linspace(1.0, 1 / num_inference_steps, num_inference_steps) - image_seq_len = latents.shape[1] - mu = calculate_shift( - image_seq_len, - self.scheduler.config.base_image_seq_len, - self.scheduler.config.max_image_seq_len, - self.scheduler.config.base_shift, - self.scheduler.config.max_shift, - ) - timesteps, num_inference_steps = retrieve_timesteps( - self.scheduler, - num_inference_steps, - device, - timesteps, - sigmas, - mu=mu, - ) - num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0) - self._num_timesteps = len(timesteps) - - # handle guidance - if self.transformer.config.guidance_embeds: - guidance = torch.full([1], guidance_scale, device=device, dtype=torch.float32) - guidance = guidance.expand(latents.shape[0]) - else: - guidance = None + # 7. Prepare OmniGenCache + num_tokens_for_output_img = latents.size(-1) * latents.size(-2) // (self.patch_size ** 2) + cache = OmniGenCache(num_tokens_for_output_img, offload_kv_cache) if use_kv_cache else None + self.transformer.llm.use_cache = use_kv_cache - # 6. Denoising loop + # 8. Denoising loop with self.progress_bar(total=num_inference_steps) as progress_bar: for i, t in enumerate(timesteps): - if self.interrupt: - continue + # expand the latents if we are doing classifier free guidance + latent_model_input = torch.cat([latents] * (num_cfg+1)) # broadcast to batch dimension in a way that's compatible with ONNX/Core ML - timestep = t.expand(latents.shape[0]).to(latents.dtype) - - noise_pred = self.transformer( - hidden_states=latents, - timestep=timestep / 1000, - guidance=guidance, - pooled_projections=pooled_prompt_embeds, - encoder_hidden_states=prompt_embeds, - txt_ids=text_ids, - img_ids=latent_image_ids, - joint_attention_kwargs=self.joint_attention_kwargs, + timestep = t.expand(latent_model_input.shape[0]) + + noise_pred, cache = self.transformer( + hidden_states=latent_model_input, + timestep=timestep, + condition_tokens=condition_tokens, + attention_mask=processed_data['attention_mask'], + position_ids=processed_data['position_ids'], + attention_kwargs=attention_kwargs, + past_key_values=cache, + offload_transformer_block=offload_transformer_block, + return_past_key_values=True, return_dict=False, - )[0] + ) + + if use_kv_cache: + if condition_tokens is not None: + condition_tokens = None + processed_data['attention_mask'] = processed_data['attention_mask'][..., -(num_tokens_for_output_img+1):, :] + processed_data['position_ids'] = processed_data['position_ids'][:, -(num_tokens_for_output_img + 1):] + + if num_cfg == 2: + cond, uncond, img_cond = torch.split(noise_pred, len(model_out) // 3, dim=0) + noise_pred = uncond + img_guidance_scale * (img_cond - uncond) + guidance_scale * (cond - img_cond) + else: + cond, uncond = torch.split(noise_pred, len(noise_pred) // 2, dim=0) + noise_pred = uncond + guidance_scale * (cond - uncond) # compute the previous noisy sample x_t -> x_t-1 latents_dtype = latents.dtype + noise_pred = -noise_pred # OmniGen uses standard rectified flow instead of denoise, different from FLUX and SD3 latents = self.scheduler.step(noise_pred, t, latents, return_dict=False)[0] if latents.dtype != latents_dtype: @@ -740,30 +592,14 @@ def __call__( # some platforms (eg. 
apple mps) misbehave due to a pytorch bug: https://github.com/pytorch/pytorch/pull/99272 latents = latents.to(latents_dtype) - if callback_on_step_end is not None: - callback_kwargs = {} - for k in callback_on_step_end_tensor_inputs: - callback_kwargs[k] = locals()[k] - callback_outputs = callback_on_step_end(self, i, t, callback_kwargs) - - latents = callback_outputs.pop("latents", latents) - prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds) + progress_bar.update() - # call the callback, if provided - if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0): - progress_bar.update() - - if XLA_AVAILABLE: - xm.mark_step() - - if output_type == "latent": - image = latents - - else: - latents = self._unpack_latents(latents, height, width, self.vae_scale_factor) - latents = (latents / self.vae.config.scaling_factor) + self.vae.config.shift_factor + if not output_type == "latent": + latents = latents / self.vae.config.scaling_factor image = self.vae.decode(latents, return_dict=False)[0] image = self.image_processor.postprocess(image, output_type=output_type) + else: + image = latents # Offload all models self.maybe_free_model_hooks() @@ -771,4 +607,5 @@ def __call__( if not return_dict: return (image,) - return FluxPipelineOutput(images=image) + return ImagePipelineOutput(images=image) + diff --git a/src/diffusers/pipelines/omnigen/pipeline_output.py b/src/diffusers/pipelines/omnigen/pipeline_output.py deleted file mode 100644 index 40ff24199900..000000000000 --- a/src/diffusers/pipelines/omnigen/pipeline_output.py +++ /dev/null @@ -1,24 +0,0 @@ -from dataclasses import dataclass -from typing import List, Union - -import numpy as np -import PIL.Image - -from ...utils import BaseOutput - - -@dataclass -class OmniGenPipelineOutput(BaseOutput): - """ - Output class for Stable Diffusion pipelines. - - Args: - images (`List[PIL.Image.Image]` or `np.ndarray`) - List of denoised PIL images of length `batch_size` or numpy array of shape `(batch_size, height, width, - num_channels)`. PIL images or numpy array present the denoised images of the diffusion pipeline. - """ - - images: Union[List[PIL.Image.Image], np.ndarray] - - - diff --git a/src/diffusers/pipelines/omnigen/processor_omnigen.py b/src/diffusers/pipelines/omnigen/processor_omnigen.py new file mode 100644 index 000000000000..e6a36bcd8df7 --- /dev/null +++ b/src/diffusers/pipelines/omnigen/processor_omnigen.py @@ -0,0 +1,295 @@ +import re +from typing import Dict, List + +import numpy as np +import torch +from PIL import Image +from torchvision import transforms + + + +def crop_image(pil_image, max_image_size): + """ + Crop the image so that its height and width does not exceed `max_image_size`, + while ensuring both the height and width are multiples of 16. 
+ """ + while min(*pil_image.size) >= 2 * max_image_size: + pil_image = pil_image.resize( + tuple(x // 2 for x in pil_image.size), resample=Image.BOX + ) + + if max(*pil_image.size) > max_image_size: + scale = max_image_size / max(*pil_image.size) + pil_image = pil_image.resize( + tuple(round(x * scale) for x in pil_image.size), resample=Image.BICUBIC + ) + + if min(*pil_image.size) < 16: + scale = 16 / min(*pil_image.size) + pil_image = pil_image.resize( + tuple(round(x * scale) for x in pil_image.size), resample=Image.BICUBIC + ) + + arr = np.array(pil_image) + crop_y1 = (arr.shape[0] % 16) // 2 + crop_y2 = arr.shape[0] % 16 - crop_y1 + + crop_x1 = (arr.shape[1] % 16) // 2 + crop_x2 = arr.shape[1] % 16 - crop_x1 + + arr = arr[crop_y1:arr.shape[0] - crop_y2, crop_x1:arr.shape[1] - crop_x2] + return Image.fromarray(arr) + + +class OmniGenMultiModalProcessor: + def __init__(self, + text_tokenizer, + max_image_size: int = 1024): + self.text_tokenizer = text_tokenizer + self.max_image_size = max_image_size + + self.image_transform = transforms.Compose([ + transforms.Lambda(lambda pil_image: crop_image(pil_image, max_image_size)), + transforms.ToTensor(), + transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5], inplace=True) + ]) + + self.collator = OmniGenCollator() + + def process_image(self, image): + image = Image.open(image).convert('RGB') + return self.image_transform(image) + + def process_multi_modal_prompt(self, text, input_images): + text = self.add_prefix_instruction(text) + if input_images is None or len(input_images) == 0: + model_inputs = self.text_tokenizer(text) + return {"input_ids": model_inputs.input_ids, "pixel_values": None, "image_sizes": None} + + pattern = r"<\|image_\d+\|>" + prompt_chunks = [self.text_tokenizer(chunk).input_ids for chunk in re.split(pattern, text)] + + for i in range(1, len(prompt_chunks)): + if prompt_chunks[i][0] == 1: + prompt_chunks[i] = prompt_chunks[i][1:] + + image_tags = re.findall(pattern, text) + image_ids = [int(s.split("|")[1].split("_")[-1]) for s in image_tags] + + unique_image_ids = sorted(list(set(image_ids))) + assert unique_image_ids == list(range(1, + len(unique_image_ids) + 1)), f"image_ids must start from 1, and must be continuous int, e.g. 
[1, 2, 3], cannot be {unique_image_ids}" + # total images must be the same as the number of image tags + assert len(unique_image_ids) == len( + input_images), f"total images must be the same as the number of image tags, got {len(unique_image_ids)} image tags and {len(input_images)} images" + + input_images = [input_images[x - 1] for x in image_ids] + + all_input_ids = [] + img_inx = [] + idx = 0 + for i in range(len(prompt_chunks)): + all_input_ids.extend(prompt_chunks[i]) + if i != len(prompt_chunks) - 1: + start_inx = len(all_input_ids) + size = input_images[i].size(-2) * input_images[i].size(-1) // 16 // 16 + img_inx.append([start_inx, start_inx + size]) + all_input_ids.extend([0] * size) + + return {"input_ids": all_input_ids, "pixel_values": input_images, "image_sizes": img_inx} + + def add_prefix_instruction(self, prompt): + user_prompt = '<|user|>\n' + generation_prompt = 'Generate an image according to the following instructions\n' + assistant_prompt = '<|assistant|>\n<|diffusion|>' + prompt_suffix = "<|end|>\n" + prompt = f"{user_prompt}{generation_prompt}{prompt}{prompt_suffix}{assistant_prompt}" + return prompt + + def __call__(self, + instructions: List[str], + input_images: List[List[str]] = None, + height: int = 1024, + width: int = 1024, + negative_prompt: str = "low quality, jpeg artifacts, ugly, duplicate, morbid, mutilated, extra fingers, mutated hands, poorly drawn hands, poorly drawn face, mutation, deformed, blurry, dehydrated, bad anatomy, bad proportions, extra limbs, cloned face, disfigured, gross proportions, malformed limbs, missing arms, missing legs, extra arms, extra legs, fused fingers, too many fingers.", + use_img_cfg: bool = True, + separate_cfg_input: bool = False, + use_input_image_size_as_output: bool = False, + ) -> Dict: + + if input_images is None: + use_img_cfg = False + if isinstance(instructions, str): + instructions = [instructions] + input_images = [input_images] + + input_data = [] + for i in range(len(instructions)): + cur_instruction = instructions[i] + cur_input_images = None if input_images is None else input_images[i] + if cur_input_images is not None and len(cur_input_images) > 0: + cur_input_images = [self.process_image(x) for x in cur_input_images] + else: + cur_input_images = None + assert "<|image_1|>" not in cur_instruction + + mllm_input = self.process_multi_modal_prompt(cur_instruction, cur_input_images) + + neg_mllm_input, img_cfg_mllm_input = None, None + neg_mllm_input = self.process_multi_modal_prompt(negative_prompt, None) + if use_img_cfg: + if cur_input_images is not None and len(cur_input_images) >= 1: + img_cfg_prompt = [f"<|image_{i + 1}|>" for i in range(len(cur_input_images))] + img_cfg_mllm_input = self.process_multi_modal_prompt(" ".join(img_cfg_prompt), cur_input_images) + else: + img_cfg_mllm_input = neg_mllm_input + + if use_input_image_size_as_output: + input_data.append((mllm_input, neg_mllm_input, img_cfg_mllm_input, + [mllm_input['pixel_values'][0].size(-2), mllm_input['pixel_values'][0].size(-1)])) + else: + input_data.append((mllm_input, neg_mllm_input, img_cfg_mllm_input, [height, width])) + + return self.collator(input_data) + + +class OmniGenCollator: + def __init__(self, pad_token_id=2, hidden_size=3072): + self.pad_token_id = pad_token_id + self.hidden_size = hidden_size + + def create_position(self, attention_mask, num_tokens_for_output_images): + position_ids = [] + text_length = attention_mask.size(-1) + img_length = max(num_tokens_for_output_images) + for mask in attention_mask: + temp_l = 
torch.sum(mask) + temp_position = [0] * (text_length - temp_l) + [i for i in range( + temp_l + img_length + 1)] # we add a time embedding into the sequence, so add one more token + position_ids.append(temp_position) + return torch.LongTensor(position_ids) + + def create_mask(self, attention_mask, num_tokens_for_output_images): + """ + OmniGen applies causal attention to each element in the sequence, but applies bidirectional attention within each image sequence + References: [OmniGen](https://arxiv.org/pdf/2409.11340) + """ + extended_mask = [] + padding_images = [] + text_length = attention_mask.size(-1) + img_length = max(num_tokens_for_output_images) + seq_len = text_length + img_length + 1 # we add a time embedding into the sequence, so add one more token + inx = 0 + for mask in attention_mask: + temp_l = torch.sum(mask) + pad_l = text_length - temp_l + + temp_mask = torch.tril(torch.ones(size=(temp_l + 1, temp_l + 1))) + + image_mask = torch.zeros(size=(temp_l + 1, img_length)) + temp_mask = torch.cat([temp_mask, image_mask], dim=-1) + + image_mask = torch.ones(size=(img_length, temp_l + img_length + 1)) + temp_mask = torch.cat([temp_mask, image_mask], dim=0) + + if pad_l > 0: + pad_mask = torch.zeros(size=(temp_l + 1 + img_length, pad_l)) + temp_mask = torch.cat([pad_mask, temp_mask], dim=-1) + + pad_mask = torch.ones(size=(pad_l, seq_len)) + temp_mask = torch.cat([pad_mask, temp_mask], dim=0) + + true_img_length = num_tokens_for_output_images[inx] + pad_img_length = img_length - true_img_length + if pad_img_length > 0: + temp_mask[:, -pad_img_length:] = 0 + temp_padding_imgs = torch.zeros(size=(1, pad_img_length, self.hidden_size)) + else: + temp_padding_imgs = None + + extended_mask.append(temp_mask.unsqueeze(0)) + padding_images.append(temp_padding_imgs) + inx += 1 + return torch.cat(extended_mask, dim=0), padding_images + + def adjust_attention_for_input_images(self, attention_mask, image_sizes): + for b_inx in image_sizes.keys(): + for start_inx, end_inx in image_sizes[b_inx]: + attention_mask[b_inx][start_inx:end_inx, start_inx:end_inx] = 1 + + return attention_mask + + def pad_input_ids(self, input_ids, image_sizes): + max_l = max([len(x) for x in input_ids]) + padded_ids = [] + attention_mask = [] + + for i in range(len(input_ids)): + temp_ids = input_ids[i] + temp_l = len(temp_ids) + pad_l = max_l - temp_l + if pad_l == 0: + attention_mask.append([1] * max_l) + padded_ids.append(temp_ids) + else: + attention_mask.append([0] * pad_l + [1] * temp_l) + padded_ids.append([self.pad_token_id] * pad_l + temp_ids) + + if i in image_sizes: + new_inx = [] + for old_inx in image_sizes[i]: + new_inx.append([x + pad_l for x in old_inx]) + image_sizes[i] = new_inx + + return torch.LongTensor(padded_ids), torch.LongTensor(attention_mask), image_sizes + + def process_mllm_input(self, mllm_inputs, target_img_size): + num_tokens_for_output_images = [] + for img_size in target_img_size: + num_tokens_for_output_images.append(img_size[0] * img_size[1] // 16 // 16) + + pixel_values, image_sizes = [], {} + b_inx = 0 + for x in mllm_inputs: + if x['pixel_values'] is not None: + pixel_values.extend(x['pixel_values']) + for size in x['image_sizes']: + if b_inx not in image_sizes: + image_sizes[b_inx] = [size] + else: + image_sizes[b_inx].append(size) + b_inx += 1 + pixel_values = [x.unsqueeze(0) for x in pixel_values] + + input_ids = [x['input_ids'] for x in mllm_inputs] + padded_input_ids, attention_mask, image_sizes = self.pad_input_ids(input_ids, image_sizes) + position_ids = 
self.create_position(attention_mask, num_tokens_for_output_images) + attention_mask, padding_images = self.create_mask(attention_mask, num_tokens_for_output_images) + attention_mask = self.adjust_attention_for_input_images(attention_mask, image_sizes) + + return padded_input_ids, position_ids, attention_mask, padding_images, pixel_values, image_sizes + + def __call__(self, features): + mllm_inputs = [f[0] for f in features] + cfg_mllm_inputs = [f[1] for f in features] + img_cfg_mllm_input = [f[2] for f in features] + target_img_size = [f[3] for f in features] + + if img_cfg_mllm_input[0] is not None: + mllm_inputs = mllm_inputs + cfg_mllm_inputs + img_cfg_mllm_input + target_img_size = target_img_size + target_img_size + target_img_size + else: + mllm_inputs = mllm_inputs + cfg_mllm_inputs + target_img_size = target_img_size + target_img_size + + all_padded_input_ids, all_position_ids, all_attention_mask, all_padding_images, all_pixel_values, all_image_sizes = self.process_mllm_input( + mllm_inputs, target_img_size) + + data = {"input_ids": all_padded_input_ids, + "attention_mask": all_attention_mask, + "position_ids": all_position_ids, + "input_pixel_values": all_pixel_values, + "input_image_sizes": all_image_sizes, + } + return data + diff --git a/src/diffusers/utils/dummy_pt_objects.py b/src/diffusers/utils/dummy_pt_objects.py index 5091ff318f1b..ce7da1bf8301 100644 --- a/src/diffusers/utils/dummy_pt_objects.py +++ b/src/diffusers/utils/dummy_pt_objects.py @@ -466,6 +466,20 @@ def from_config(cls, *args, **kwargs): def from_pretrained(cls, *args, **kwargs): requires_backends(cls, ["torch"]) +class OmniGenTransformer2DModel(metaclass=DummyObject): + _backends = ["torch"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch"]) + + @classmethod + def from_config(cls, *args, **kwargs): + requires_backends(cls, ["torch"]) + + @classmethod + def from_pretrained(cls, *args, **kwargs): + requires_backends(cls, ["torch"]) + class PixArtTransformer2DModel(metaclass=DummyObject): _backends = ["torch"] diff --git a/src/diffusers/utils/dummy_torch_and_transformers_objects.py b/src/diffusers/utils/dummy_torch_and_transformers_objects.py index b76ea3824060..b1440d33a6f7 100644 --- a/src/diffusers/utils/dummy_torch_and_transformers_objects.py +++ b/src/diffusers/utils/dummy_torch_and_transformers_objects.py @@ -1142,6 +1142,21 @@ def from_pretrained(cls, *args, **kwargs): requires_backends(cls, ["torch", "transformers"]) +class OmniGenPipeline(metaclass=DummyObject): + _backends = ["torch", "transformers"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch", "transformers"]) + + @classmethod + def from_config(cls, *args, **kwargs): + requires_backends(cls, ["torch", "transformers"]) + + @classmethod + def from_pretrained(cls, *args, **kwargs): + requires_backends(cls, ["torch", "transformers"]) + + class PaintByExamplePipeline(metaclass=DummyObject): _backends = ["torch", "transformers"] diff --git a/test.py b/test.py index e110b93a7b74..3648f7e9875e 100644 --- a/test.py +++ b/test.py @@ -3,7 +3,7 @@ from huggingface_hub import snapshot_download -from diffusers.models import OmniGenTransformerModel +from diffusers.models import OmniGenTransformer2DModel from transformers import Phi3Model, Phi3Config @@ -11,7 +11,7 @@ model_name = "Shitao/OmniGen-v1" config = Phi3Config.from_pretrained("Shitao/OmniGen-v1") -model = OmniGenTransformerModel(transformer_config=config) +model = OmniGenTransformer2DModel(transformer_config=config) cache_folder = 
os.getenv('HF_HUB_CACHE') model_name = snapshot_download(repo_id=model_name, cache_dir=cache_folder, From 0d041944dcd766db468e0568f8e12c02a64964ac Mon Sep 17 00:00:00 2001 From: shitao <2906698981@qq.com> Date: Mon, 2 Dec 2024 17:40:43 +0800 Subject: [PATCH 04/55] omnigen pipeline --- scripts/convert_omnigen_to_diffusers.py | 22 ++++++++++++++++------ 1 file changed, 16 insertions(+), 6 deletions(-) diff --git a/scripts/convert_omnigen_to_diffusers.py b/scripts/convert_omnigen_to_diffusers.py index 2b02f6a064dc..5d6d98bcb9a7 100644 --- a/scripts/convert_omnigen_to_diffusers.py +++ b/scripts/convert_omnigen_to_diffusers.py @@ -4,13 +4,25 @@ import torch from safetensors.torch import load_file from transformers import AutoModel, AutoTokenizer, AutoConfig +from huggingface_hub import snapshot_download from diffusers import AutoencoderKL, FlowMatchEulerDiscreteScheduler, OmniGenTransformer2DModel, OmniGenPipeline def main(args): # checkpoint from https://huggingface.co/Shitao/OmniGen-v1 - ckpt = load_file(args.origin_ckpt_path, device="cpu") + + if not os.path.exists(args.origin_ckpt_path): + print("Model not found, downloading...") + cache_folder = os.getenv('HF_HUB_CACHE') + args.origin_ckpt_path = snapshot_download(repo_id=args.origin_ckpt_path, + cache_dir=cache_folder, + ignore_patterns=['flax_model.msgpack', 'rust_model.ot', 'tf_model.h5', + 'model.pt']) + print(f"Downloaded model to {args.origin_ckpt_path}") + + ckpt = os.path.join(args.origin_ckpt_path, 'model.safetensors') + ckpt = load_file(ckpt, device="cpu") mapping_dict = { "pos_embed": "patch_embedding.pos_embed", @@ -27,7 +39,6 @@ def main(args): converted_state_dict = {} for k, v in ckpt.items(): - # new_ckpt[k] = v if k in mapping_dict: converted_state_dict[mapping_dict[k]] = v else: @@ -35,7 +46,6 @@ def main(args): transformer_config = AutoConfig.from_pretrained(args.origin_ckpt_path) - # Lumina-Next-SFT 2B transformer = OmniGenTransformer2DModel( transformer_config=transformer_config, patch_size=2, @@ -49,7 +59,7 @@ def main(args): scheduler = FlowMatchEulerDiscreteScheduler() - vae = AutoencoderKL.from_pretrained(args.origin_ckpt_path, torch_dtype=torch.float32) + vae = AutoencoderKL.from_pretrained(os.path.join(args.origin_ckpt_path, "vae"), torch_dtype=torch.float32) tokenizer = AutoTokenizer.from_pretrained(args.origin_ckpt_path) @@ -64,10 +74,10 @@ def main(args): parser = argparse.ArgumentParser() parser.add_argument( - "--origin_ckpt_path", default=None, type=str, required=False, help="Path to the checkpoint to convert." + "--origin_ckpt_path", default="Shitao/OmniGen-v1", type=str, required=False, help="Path to the checkpoint to convert." 
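Looking back at `OmniGenCollator.create_mask` in the processor file added above: text and timestep tokens are masked causally and cannot attend to the output-image tokens, while the image tokens attend to the full sequence, i.e. bidirectionally among themselves. A padding-free sketch of that mask layout (an illustration of the scheme, not part of the patch):

```python
import torch


def omnigen_attention_mask(text_len: int, img_len: int) -> torch.Tensor:
    # Rows/columns are ordered as [text tokens..., timestep token, image tokens...].
    causal_text = torch.tril(torch.ones(text_len + 1, text_len + 1))  # causal over text + time token
    text_rows = torch.cat([causal_text, torch.zeros(text_len + 1, img_len)], dim=-1)  # text cannot see the image
    img_rows = torch.ones(img_len, text_len + 1 + img_len)  # image tokens attend to everything
    return torch.cat([text_rows, img_rows], dim=0)


print(omnigen_attention_mask(text_len=3, img_len=2))
```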
) - parser.add_argument("--dump_path", default=None, type=str, required=True, help="Path to the output pipeline.") + parser.add_argument("--dump_path", default="OmniGen-v1-diffusers", type=str, required=True, help="Path to the output pipeline.") args = parser.parse_args() main(args) From 85abe5e5b077d7c49f220d0a78317cc413a6bf5c Mon Sep 17 00:00:00 2001 From: staoxiao <2906698981@qq.com> Date: Tue, 3 Dec 2024 17:15:07 +0800 Subject: [PATCH 05/55] update omnigen_pipeline --- scripts/convert_omnigen_to_diffusers.py | 141 ++++++++- src/diffusers/models/embeddings.py | 4 +- .../transformers/transformer_omnigen.py | 9 +- src/diffusers/pipelines/__init__.py | 2 + .../pipelines/omnigen/kvcache_omnigen.py | 45 +-- .../pipelines/omnigen/pipeline_omnigen.py | 108 +++---- .../pipelines/omnigen/processor_omnigen.py | 2 - .../scheduling_flow_match_euler_discrete.py | 4 +- test.py | 92 +++--- tests/pipelines/omnigen/__init__.py | 0 tests/pipelines/omnigen/test_pipeline_flux.py | 298 ++++++++++++++++++ 11 files changed, 570 insertions(+), 135 deletions(-) create mode 100644 tests/pipelines/omnigen/__init__.py create mode 100644 tests/pipelines/omnigen/test_pipeline_flux.py diff --git a/scripts/convert_omnigen_to_diffusers.py b/scripts/convert_omnigen_to_diffusers.py index 5d6d98bcb9a7..8c9bc8bdb457 100644 --- a/scripts/convert_omnigen_to_diffusers.py +++ b/scripts/convert_omnigen_to_diffusers.py @@ -1,5 +1,6 @@ import argparse import os +os.environ['HF_HUB_CACHE'] = '/share/shitao/downloaded_models2' import torch from safetensors.torch import load_file @@ -44,8 +45,141 @@ def main(args): else: converted_state_dict[k] = v - transformer_config = AutoConfig.from_pretrained(args.origin_ckpt_path) - + # transformer_config = AutoConfig.from_pretrained(args.origin_ckpt_path) + # print(type(transformer_config.__dict__)) + # print(transformer_config.__dict__) + + transformer_config = { + "_name_or_path": "Phi-3-vision-128k-instruct", + "architectures": [ + "Phi3ForCausalLM" + ], + "attention_dropout": 0.0, + "bos_token_id": 1, + "eos_token_id": 2, + "hidden_act": "silu", + "hidden_size": 3072, + "initializer_range": 0.02, + "intermediate_size": 8192, + "max_position_embeddings": 131072, + "model_type": "phi3", + "num_attention_heads": 32, + "num_hidden_layers": 32, + "num_key_value_heads": 32, + "original_max_position_embeddings": 4096, + "rms_norm_eps": 1e-05, + "rope_scaling": { + "long_factor": [ + 1.0299999713897705, + 1.0499999523162842, + 1.0499999523162842, + 1.0799999237060547, + 1.2299998998641968, + 1.2299998998641968, + 1.2999999523162842, + 1.4499999284744263, + 1.5999999046325684, + 1.6499998569488525, + 1.8999998569488525, + 2.859999895095825, + 3.68999981880188, + 5.419999599456787, + 5.489999771118164, + 5.489999771118164, + 9.09000015258789, + 11.579999923706055, + 15.65999984741211, + 15.769999504089355, + 15.789999961853027, + 18.360000610351562, + 21.989999771118164, + 23.079999923706055, + 30.009998321533203, + 32.35000228881836, + 32.590003967285156, + 35.56000518798828, + 39.95000457763672, + 53.840003967285156, + 56.20000457763672, + 57.95000457763672, + 59.29000473022461, + 59.77000427246094, + 59.920005798339844, + 61.190006256103516, + 61.96000671386719, + 62.50000762939453, + 63.3700065612793, + 63.48000717163086, + 63.48000717163086, + 63.66000747680664, + 63.850006103515625, + 64.08000946044922, + 64.760009765625, + 64.80001068115234, + 64.81001281738281, + 64.81001281738281 + ], + "short_factor": [ + 1.05, + 1.05, + 1.05, + 1.1, + 1.1, + 1.1, + 1.2500000000000002, + 
1.2500000000000002, + 1.4000000000000004, + 1.4500000000000004, + 1.5500000000000005, + 1.8500000000000008, + 1.9000000000000008, + 2.000000000000001, + 2.000000000000001, + 2.000000000000001, + 2.000000000000001, + 2.000000000000001, + 2.000000000000001, + 2.000000000000001, + 2.000000000000001, + 2.000000000000001, + 2.000000000000001, + 2.000000000000001, + 2.000000000000001, + 2.000000000000001, + 2.000000000000001, + 2.000000000000001, + 2.000000000000001, + 2.000000000000001, + 2.000000000000001, + 2.000000000000001, + 2.1000000000000005, + 2.1000000000000005, + 2.2, + 2.3499999999999996, + 2.3499999999999996, + 2.3499999999999996, + 2.3499999999999996, + 2.3999999999999995, + 2.3999999999999995, + 2.6499999999999986, + 2.6999999999999984, + 2.8999999999999977, + 2.9499999999999975, + 3.049999999999997, + 3.049999999999997, + 3.049999999999997 + ], + "type": "su" + }, + "rope_theta": 10000.0, + "sliding_window": 131072, + "tie_word_embeddings": False, + "torch_dtype": "bfloat16", + "transformers_version": "4.38.1", + "use_cache": True, + "vocab_size": 32064, + "_attn_implementation": "sdpa" + } transformer = OmniGenTransformer2DModel( transformer_config=transformer_config, patch_size=2, @@ -53,6 +187,7 @@ def main(args): pos_embed_max_size=192, ) transformer.load_state_dict(converted_state_dict, strict=True) + transformer.to(torch.bfloat16) num_model_params = sum(p.numel() for p in transformer.parameters()) print(f"Total number of transformer parameters: {num_model_params}") @@ -77,7 +212,7 @@ def main(args): "--origin_ckpt_path", default="Shitao/OmniGen-v1", type=str, required=False, help="Path to the checkpoint to convert." ) - parser.add_argument("--dump_path", default="OmniGen-v1-diffusers", type=str, required=True, help="Path to the output pipeline.") + parser.add_argument("--dump_path", default="/share/shitao/repos/OmniGen-v1-diffusers", type=str, required=False, help="Path to the output pipeline.") args = parser.parse_args() main(args) diff --git a/src/diffusers/models/embeddings.py b/src/diffusers/models/embeddings.py index 9681a84b8878..720a48f3f747 100644 --- a/src/diffusers/models/embeddings.py +++ b/src/diffusers/models/embeddings.py @@ -381,9 +381,9 @@ def forward(self, height, width = latent.shape[-2:] pos_embed = self.cropped_pos_embed(height, width) latent = self.patch_embeddings(latent, is_input_image) - latent = latent + pos_embed + patched_latents = latent + pos_embed - return latent + return patched_latents class LuminaPatchEmbed(nn.Module): diff --git a/src/diffusers/models/transformers/transformer_omnigen.py b/src/diffusers/models/transformers/transformer_omnigen.py index 2fa97146f1d3..031fc014e77a 100644 --- a/src/diffusers/models/transformers/transformer_omnigen.py +++ b/src/diffusers/models/transformers/transformer_omnigen.py @@ -125,7 +125,7 @@ def forward( ) use_cache = False - # kept for BC (non `Cache` `past_key_values` inputs) + # kept for BC (non `Cache` `past_key_values` inputs) return_legacy_cache = False if use_cache and not isinstance(past_key_values, Cache): return_legacy_cache = True @@ -240,7 +240,7 @@ class OmniGenTransformer2DModel(ModelMixin, ConfigMixin, PeftAdapterMixin): @register_to_config def __init__( self, - transformer_config: Phi3Config, + transformer_config: Dict, patch_size=2, in_channels=4, pos_embed_max_size: int = 192, @@ -251,6 +251,7 @@ def __init__( self.patch_size = patch_size self.pos_embed_max_size = pos_embed_max_size + transformer_config = Phi3Config(**transformer_config) hidden_size = transformer_config.hidden_size 
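With this change `OmniGenTransformer2DModel` takes `transformer_config` as a plain dict and rebuilds the `Phi3Config` internally via `Phi3Config(**transformer_config)`, which is what lets the conversion script above hard-code the config as a dictionary. A minimal sketch of that round-trip, with values taken from the dict above; unspecified fields fall back to `Phi3Config` defaults, and a transformers version that ships Phi-3 is assumed:

```python
from transformers import Phi3Config

# A trimmed-down version of the hard-coded dict in the conversion script.
transformer_config = {
    "hidden_size": 3072,
    "intermediate_size": 8192,
    "num_hidden_layers": 32,
    "num_attention_heads": 32,
}
phi3_config = Phi3Config(**transformer_config)
print(phi3_config.hidden_size)  # 3072, the value read as `hidden_size` in __init__ above
```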
self.patch_embedding = OmniGenPatchEmbed(patch_size=patch_size, @@ -386,7 +387,7 @@ def forward(self, "Passing `scale` via `attention_kwargs` when not using the PEFT backend is ineffective." ) - height, width = hidden_states.size(-2) + height, width = hidden_states.size()[-2:] hidden_states = self.patch_embedding(hidden_states, is_input_image=False) num_tokens_for_output_image = hidden_states.size(1) @@ -405,7 +406,7 @@ def forward(self, image_embedding = output[:, -num_tokens_for_output_image:] time_emb = self.t_embedder(timestep, dtype=hidden_states.dtype) - x = self.final_layer(image_embedding, time_emb) + x = self.proj_out(self.norm_out(image_embedding, temb=time_emb)) output = self.unpatchify(x, height, width) if not return_dict: diff --git a/src/diffusers/pipelines/__init__.py b/src/diffusers/pipelines/__init__.py index 5143b1114fd3..0b390444f5f2 100644 --- a/src/diffusers/pipelines/__init__.py +++ b/src/diffusers/pipelines/__init__.py @@ -254,6 +254,7 @@ ) _import_structure["mochi"] = ["MochiPipeline"] _import_structure["musicldm"] = ["MusicLDMPipeline"] + _import_structure["omnigen"] = ["OmniGenPipeline"] _import_structure["paint_by_example"] = ["PaintByExamplePipeline"] _import_structure["pia"] = ["PIAPipeline"] _import_structure["pixart_alpha"] = ["PixArtAlphaPipeline", "PixArtSigmaPipeline"] @@ -584,6 +585,7 @@ ) from .mochi import MochiPipeline from .musicldm import MusicLDMPipeline + from .omnigen import OmniGenPipeline from .pag import ( AnimateDiffPAGPipeline, HunyuanDiTPAGPipeline, diff --git a/src/diffusers/pipelines/omnigen/kvcache_omnigen.py b/src/diffusers/pipelines/omnigen/kvcache_omnigen.py index 0270292c130f..7f02588ce405 100644 --- a/src/diffusers/pipelines/omnigen/kvcache_omnigen.py +++ b/src/diffusers/pipelines/omnigen/kvcache_omnigen.py @@ -1,15 +1,20 @@ +from tqdm import tqdm from typing import Optional, Dict, Any, Tuple, List +import gc import torch -from transformers.cache_utils import DynamicCache +from transformers.cache_utils import Cache, DynamicCache, OffloadedCache + class OmniGenCache(DynamicCache): - def __init__(self, - num_tokens_for_img: int, offload_kv_cache: bool = False) -> None: + def __init__(self, + num_tokens_for_img: int, + offload_kv_cache: bool=False) -> None: if not torch.cuda.is_available(): - raise RuntimeError( - "OmniGenCache can only be used with a GPU. If there is no GPU, you need to set use_kv_cache=False, which will result in longer inference time!") + # print("No avaliable GPU, offload_kv_cache wiil be set to False, which will result in large memory usage and time cost when input multiple images!!!") + # offload_kv_cache = False + raise RuntimeError("OffloadedCache can only be used with a GPU. 
If there is no GPU, you need to set use_kv_cache=False, which will result in longer inference time!") super().__init__() self.original_device = [] self.prefetch_stream = torch.cuda.Stream() @@ -25,17 +30,19 @@ def prefetch_layer(self, layer_idx: int): self.key_cache[layer_idx] = self.key_cache[layer_idx].to(device, non_blocking=True) self.value_cache[layer_idx] = self.value_cache[layer_idx].to(device, non_blocking=True) + def evict_previous_layer(self, layer_idx: int): "Moves the previous layer cache to the CPU" if len(self) > 2: # We do it on the default stream so it occurs after all earlier computations on these tensors are done - if layer_idx == 0: + if layer_idx == 0: prev_layer_idx = -1 else: prev_layer_idx = (layer_idx - 1) % len(self) self.key_cache[prev_layer_idx] = self.key_cache[prev_layer_idx].to("cpu", non_blocking=True) self.value_cache[prev_layer_idx] = self.value_cache[prev_layer_idx].to("cpu", non_blocking=True) + def __getitem__(self, layer_idx: int) -> List[Tuple[torch.Tensor]]: "Gets the cache for this layer to the device. Prefetches the next and evicts the previous layer." if layer_idx < len(self): @@ -44,12 +51,12 @@ def __getitem__(self, layer_idx: int) -> List[Tuple[torch.Tensor]]: torch.cuda.current_stream().synchronize() self.evict_previous_layer(layer_idx) # Load current layer cache to its original device if not already there - # original_device = self.original_device[layer_idx] + original_device = self.original_device[layer_idx] # self.prefetch_stream.synchronize(original_device) - self.prefetch_stream.synchronize() + torch.cuda.synchronize(self.prefetch_stream) key_tensor = self.key_cache[layer_idx] value_tensor = self.value_cache[layer_idx] - + # Prefetch the next layer self.prefetch_layer((layer_idx + 1) % len(self)) else: @@ -58,13 +65,13 @@ def __getitem__(self, layer_idx: int) -> List[Tuple[torch.Tensor]]: return (key_tensor, value_tensor) else: raise KeyError(f"Cache only has {len(self)} layers, attempted to access layer with index {layer_idx}") - + def update( - self, - key_states: torch.Tensor, - value_states: torch.Tensor, - layer_idx: int, - cache_kwargs: Optional[Dict[str, Any]] = None, + self, + key_states: torch.Tensor, + value_states: torch.Tensor, + layer_idx: int, + cache_kwargs: Optional[Dict[str, Any]] = None, ) -> Tuple[torch.Tensor, torch.Tensor]: """ Updates the cache with the new `key_states` and `value_states` for the layer `layer_idx`. @@ -85,13 +92,13 @@ def update( raise ValueError("OffloadedCache does not support model usage where layers are skipped. 
Use DynamicCache.") elif len(self.key_cache) == layer_idx: # only cache the states for condition tokens - key_states = key_states[..., :-(self.num_tokens_for_img + 1), :] - value_states = value_states[..., :-(self.num_tokens_for_img + 1), :] + key_states = key_states[..., :-(self.num_tokens_for_img+1), :] + value_states = value_states[..., :-(self.num_tokens_for_img+1), :] - # Update the number of seen tokens + # Update the number of seen tokens if layer_idx == 0: self._seen_tokens += key_states.shape[-2] - + self.key_cache.append(key_states) self.value_cache.append(value_states) self.original_device.append(key_states.device) diff --git a/src/diffusers/pipelines/omnigen/pipeline_omnigen.py b/src/diffusers/pipelines/omnigen/pipeline_omnigen.py index 9c9a67e53e9c..500a440f4203 100644 --- a/src/diffusers/pipelines/omnigen/pipeline_omnigen.py +++ b/src/diffusers/pipelines/omnigen/pipeline_omnigen.py @@ -15,6 +15,7 @@ import inspect from typing import Any, Callable, Dict, List, Optional, Union +import numpy as np import torch from transformers import LlamaTokenizer @@ -179,6 +180,7 @@ def encod_input_iamges( self, input_pixel_values: List[torch.Tensor], device: Optional[torch.device] = None, + dtype: Optional[torch.dtype] = None, ): """ get the continues embedding of input images by VAE @@ -188,7 +190,7 @@ def encod_input_iamges( Returns: torch.Tensor """ device = device or self._execution_device - dtype = self.vae.dtype + dtype = dtype or self.vae.dtype input_img_latents = [] for img in input_pixel_values: @@ -215,13 +217,15 @@ def get_multimodal_embeddings(self, """ device = device or self._execution_device + input_img_latents = [x.to(self.transformer.dtype) for x in input_img_latents] + condition_tokens = None if input_ids is not None: - condition_tokens = self.transformer.llm.embed_tokens(input_ids.to(device)) + condition_tokens = self.transformer.llm.embed_tokens(input_ids.to(device)).clone() input_img_inx = 0 if input_img_latents is not None: input_image_tokens = self.transformer.patch_embedding(input_img_latents, - is_input_images=True) + is_input_image=True) for b_inx in input_image_sizes.keys(): for start_inx, end_inx in input_image_sizes[b_inx]: @@ -248,10 +252,11 @@ def check_inputs( f"The number of prompts: {len(prompt)} does not match the number of input images: {len(input_images)}." 
) for i in range(len(input_images)): - if not all(f"<|image_{k}|>" in prompt[i] for k in range(len(input_images[i]))): - raise ValueError( - f"prompt `{prompt[i]}` doesn't have enough placeholders for the input images `{input_images[i]}`" - ) + if input_images[i] is not None: + if not all(f"<|image_{k+1}|>" in prompt[i] for k in range(len(input_images[i]))): + raise ValueError( + f"prompt `{prompt[i]}` doesn't have enough placeholders for the input images `{input_images[i]}`" + ) if height % (self.vae_scale_factor * 2) != 0 or width % (self.vae_scale_factor * 2) != 0: logger.warning( @@ -279,30 +284,6 @@ def _prepare_latent_image_ids(batch_size, height, width, device, dtype): return latent_image_ids.to(device=device, dtype=dtype) - @staticmethod - def _pack_latents(latents, batch_size, num_channels_latents, height, width): - latents = latents.view(batch_size, num_channels_latents, height // 2, 2, width // 2, 2) - latents = latents.permute(0, 2, 4, 1, 3, 5) - latents = latents.reshape(batch_size, (height // 2) * (width // 2), num_channels_latents * 4) - - return latents - - @staticmethod - def _unpack_latents(latents, height, width, vae_scale_factor): - batch_size, num_patches, channels = latents.shape - - # VAE applies 8x compression on images but we must also account for packing which requires - # latent height and width to be divisible by 2. - height = 2 * (int(height) // (vae_scale_factor * 2)) - width = 2 * (int(width) // (vae_scale_factor * 2)) - - latents = latents.view(batch_size, height // 2, width // 2, channels // 4, 2, 2) - latents = latents.permute(0, 3, 1, 4, 2, 5) - - latents = latents.reshape(batch_size, channels // (2 * 2), height, width) - - return latents - def enable_vae_slicing(self): r""" Enable sliced VAE decoding. When this option is enabled, the VAE will split the input tensor in slices to @@ -343,16 +324,15 @@ def prepare_latents( generator, latents=None, ): - # VAE applies 8x compression on images but we must also account for packing which requires - # latent height and width to be divisible by 2. 
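The corrected check above requires every input image to be referenced in its prompt through a 1-indexed `<|image_k|>` placeholder, matching what the processor later splits on. A small illustration of the convention (the file names are hypothetical):

```python
prompt = ["A photo of the cat in <|image_1|> sitting next to the dog in <|image_2|>"]
input_images = [["cat.png", "dog.png"]]  # hypothetical paths, one inner list per prompt

# Mirrors the validation in check_inputs: placeholders are 1-indexed and one is
# required for each provided image.
for i, images in enumerate(input_images):
    assert all(f"<|image_{k + 1}|>" in prompt[i] for k in range(len(images)))
```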
- height = 2 * (int(height) // (self.vae_scale_factor * 2)) - width = 2 * (int(width) // (self.vae_scale_factor * 2)) - - shape = (batch_size, num_channels_latents, height, width) - if latents is not None: - latent_image_ids = self._prepare_latent_image_ids(batch_size, height // 2, width // 2, device, dtype) - return latents.to(device=device, dtype=dtype), latent_image_ids + return latents.to(device=device, dtype=dtype) + + shape = ( + batch_size, + num_channels_latents, + int(height) // self.vae_scale_factor, + int(width) // self.vae_scale_factor, + ) if isinstance(generator, list) and len(generator) != batch_size: raise ValueError( @@ -361,11 +341,8 @@ def prepare_latents( ) latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype) - latents = self._pack_latents(latents, batch_size, num_channels_latents, height, width) - - latent_image_ids = self._prepare_latent_image_ids(batch_size, height // 2, width // 2, device, dtype) - return latents, latent_image_ids + return latents @property def guidance_scale(self): @@ -482,9 +459,14 @@ def __call__( height = height or self.default_sample_size * self.vae_scale_factor width = width or self.default_sample_size * self.vae_scale_factor + num_cfg = 2 if input_images is not None else 1 + use_img_cfg = True if input_images is not None else False if isinstance(prompt, str): prompt = [prompt] input_images = [input_images] + + # using Float32 for the VAE doesn't take up much memory but can prevent potential black image outputs. + self.vae.to(torch.float32) # 1. Check inputs. Raise error if not correct self.check_inputs( @@ -506,12 +488,15 @@ def __call__( # 3. process multi-modal instructions if max_input_image_size != self.multimodal_processor.max_image_size: - self.processor = OmniGenMultiModalProcessor(self.text_tokenizer, max_image_size=max_input_image_size) - processed_data = self.processor(prompt, - input_images, - height=height, - width=width, - use_input_image_size_as_output=use_input_image_size_as_output) + self.multimodal_processor = OmniGenMultiModalProcessor(self.text_tokenizer, max_image_size=max_input_image_size) + processed_data = self.multimodal_processor(prompt, + input_images, + height=height, + width=width, + use_img_cfg=use_img_cfg, + use_input_image_size_as_output=use_input_image_size_as_output) + processed_data['attention_mask'] = processed_data['attention_mask'].to(device) + processed_data['position_ids'] = processed_data['position_ids'].to(device) # 4. Encode input images and obtain multi-modal conditional embeddings input_img_latents = self.encod_input_iamges(processed_data['input_pixel_values'], device=device) @@ -522,14 +507,14 @@ def __call__( ) # 5. Prepare timesteps + sigmas = np.linspace(1, 0, num_inference_steps+1)[:num_inference_steps] timesteps, num_inference_steps = retrieve_timesteps( - self.scheduler, num_inference_steps, device, timesteps, + self.scheduler, num_inference_steps, device, timesteps, sigmas=sigmas ) # 6. Prepare latents. if use_input_image_size_as_output: height, width = processed_data['input_pixel_values'][0].shape[-2:] - num_cfg = 2 if input_images is not None else 1 latent_channels = self.transformer.config.in_channels latents = self.prepare_latents( batch_size * num_images_per_prompt, @@ -542,10 +527,15 @@ def __call__( latents, ) + + generator = torch.Generator(device=device).manual_seed(0) + latents = torch.randn(1, 4, height//8, width//8, device=device, generator=generator).to(self.transformer.dtype) + # latents = torch.cat([latents]*(1+num_cfg), 0).to(dtype) + # 7. 
Prepare OmniGenCache - num_tokens_for_output_img = latents.size(-1) * latents.size(-2) // (self.patch_size ** 2) + num_tokens_for_output_img = latents.size(-1) * latents.size(-2) // (self.transformer.patch_size ** 2) cache = OmniGenCache(num_tokens_for_output_img, offload_kv_cache) if use_kv_cache else None - self.transformer.llm.use_cache = use_kv_cache + self.transformer.llm.config.use_cache = use_kv_cache # 8. Denoising loop with self.progress_bar(total=num_inference_steps) as progress_bar: @@ -565,18 +555,18 @@ def __call__( attention_kwargs=attention_kwargs, past_key_values=cache, offload_transformer_block=offload_transformer_block, - return_past_key_values=True, return_dict=False, ) - + + # if use kv cache, don't need attention mask and position ids of condition tokens for next step if use_kv_cache: if condition_tokens is not None: condition_tokens = None - processed_data['attention_mask'] = processed_data['attention_mask'][..., -(num_tokens_for_output_img+1):, :] + processed_data['attention_mask'] = processed_data['attention_mask'][..., -(num_tokens_for_output_img + 1):, :] # +1 is for the timestep token processed_data['position_ids'] = processed_data['position_ids'][:, -(num_tokens_for_output_img + 1):] if num_cfg == 2: - cond, uncond, img_cond = torch.split(noise_pred, len(model_out) // 3, dim=0) + cond, uncond, img_cond = torch.split(noise_pred, len(noise_pred) // 3, dim=0) noise_pred = uncond + img_guidance_scale * (img_cond - uncond) + guidance_scale * (cond - img_cond) else: cond, uncond = torch.split(noise_pred, len(noise_pred) // 2, dim=0) @@ -584,7 +574,6 @@ def __call__( # compute the previous noisy sample x_t -> x_t-1 latents_dtype = latents.dtype - noise_pred = -noise_pred # OmniGen uses standard rectified flow instead of denoise, different from FLUX and SD3 latents = self.scheduler.step(noise_pred, t, latents, return_dict=False)[0] if latents.dtype != latents_dtype: @@ -595,6 +584,7 @@ def __call__( progress_bar.update() if not output_type == "latent": + latents = latents.to(self.vae.dtype) latents = latents / self.vae.config.scaling_factor image = self.vae.decode(latents, return_dict=False)[0] image = self.image_processor.postprocess(image, output_type=output_type) diff --git a/src/diffusers/pipelines/omnigen/processor_omnigen.py b/src/diffusers/pipelines/omnigen/processor_omnigen.py index e6a36bcd8df7..b425d84269df 100644 --- a/src/diffusers/pipelines/omnigen/processor_omnigen.py +++ b/src/diffusers/pipelines/omnigen/processor_omnigen.py @@ -117,8 +117,6 @@ def __call__(self, use_input_image_size_as_output: bool = False, ) -> Dict: - if input_images is None: - use_img_cfg = False if isinstance(instructions, str): instructions = [instructions] input_images = [input_images] diff --git a/src/diffusers/schedulers/scheduling_flow_match_euler_discrete.py b/src/diffusers/schedulers/scheduling_flow_match_euler_discrete.py index c1096dbe0c29..7e01160b8626 100644 --- a/src/diffusers/schedulers/scheduling_flow_match_euler_discrete.py +++ b/src/diffusers/schedulers/scheduling_flow_match_euler_discrete.py @@ -212,7 +212,7 @@ def set_timesteps( else: sigmas = torch.cat([sigmas, torch.zeros(1, device=sigmas.device)]) - self.timesteps = timesteps.to(device=device) + self.timesteps = timesteps.to(device=device) self.sigmas = sigmas self._step_index = None self._begin_index = None @@ -300,7 +300,7 @@ def step( sigma = self.sigmas[self.step_index] sigma_next = self.sigmas[self.step_index + 1] - + prev_sample = sample + (sigma_next - sigma) * model_output # Cast sample back to model 
compatible dtype diff --git a/test.py b/test.py index 3648f7e9875e..fe1410122acd 100644 --- a/test.py +++ b/test.py @@ -1,50 +1,54 @@ -import os -os.environ['HF_HUB_CACHE'] = '/share/shitao/downloaded_models2' - -from huggingface_hub import snapshot_download - -from diffusers.models import OmniGenTransformer2DModel -from transformers import Phi3Model, Phi3Config - - -from safetensors.torch import load_file - -model_name = "Shitao/OmniGen-v1" -config = Phi3Config.from_pretrained("Shitao/OmniGen-v1") -model = OmniGenTransformer2DModel(transformer_config=config) -cache_folder = os.getenv('HF_HUB_CACHE') -model_name = snapshot_download(repo_id=model_name, - cache_dir=cache_folder, - ignore_patterns=['flax_model.msgpack', 'rust_model.ot', 'tf_model.h5']) -print(model_name) -model_path = os.path.join(model_name, 'model.safetensors') -ckpt = load_file(model_path, 'cpu') - - -mapping_dict = { - "pos_embed": "patch_embedding.pos_embed", - "x_embedder.proj.weight": "patch_embedding.output_image_proj.weight", - "x_embedder.proj.bias": "patch_embedding.output_image_proj.bias", - "input_x_embedder.proj.weight": "patch_embedding.input_image_proj.weight", - "input_x_embedder.proj.bias": "patch_embedding.input_image_proj.bias", - "final_layer.adaLN_modulation.1.weight": "norm_out.linear.weight", - "final_layer.adaLN_modulation.1.bias": "norm_out.linear.bias", - "final_layer.linear.weight": "proj_out.weight", - "final_layer.linear.bias": "proj_out.bias", - -} - -new_ckpt = {} -for k, v in ckpt.items(): - # new_ckpt[k] = v - if k in mapping_dict: - new_ckpt[mapping_dict[k]] = v - else: - new_ckpt[k] = v +# import os +# os.environ['HF_HUB_CACHE'] = '/share/shitao/downloaded_models2' + +# from huggingface_hub import snapshot_download + +# from diffusers.models import OmniGenTransformer2DModel +# from transformers import Phi3Model, Phi3Config + + +# from safetensors.torch import load_file + +# model_name = "Shitao/OmniGen-v1" +# config = Phi3Config.from_pretrained("Shitao/OmniGen-v1") +# model = OmniGenTransformer2DModel(transformer_config=config) +# cache_folder = os.getenv('HF_HUB_CACHE') +# model_name = snapshot_download(repo_id=model_name, +# cache_dir=cache_folder, +# ignore_patterns=['flax_model.msgpack', 'rust_model.ot', 'tf_model.h5']) +# print(model_name) +# model_path = os.path.join(model_name, 'model.safetensors') +# ckpt = load_file(model_path, 'cpu') + + +# mapping_dict = { +# "pos_embed": "patch_embedding.pos_embed", +# "x_embedder.proj.weight": "patch_embedding.output_image_proj.weight", +# "x_embedder.proj.bias": "patch_embedding.output_image_proj.bias", +# "input_x_embedder.proj.weight": "patch_embedding.input_image_proj.weight", +# "input_x_embedder.proj.bias": "patch_embedding.input_image_proj.bias", +# "final_layer.adaLN_modulation.1.weight": "norm_out.linear.weight", +# "final_layer.adaLN_modulation.1.bias": "norm_out.linear.bias", +# "final_layer.linear.weight": "proj_out.weight", +# "final_layer.linear.bias": "proj_out.bias", + +# } + +# new_ckpt = {} +# for k, v in ckpt.items(): +# # new_ckpt[k] = v +# if k in mapping_dict: +# new_ckpt[mapping_dict[k]] = v +# else: +# new_ckpt[k] = v -model.load_state_dict(new_ckpt) +# model.load_state_dict(new_ckpt) + + + + diff --git a/tests/pipelines/omnigen/__init__.py b/tests/pipelines/omnigen/__init__.py new file mode 100644 index 000000000000..e69de29bb2d1 diff --git a/tests/pipelines/omnigen/test_pipeline_flux.py b/tests/pipelines/omnigen/test_pipeline_flux.py new file mode 100644 index 000000000000..df9021ee0adb --- /dev/null +++ 
b/tests/pipelines/omnigen/test_pipeline_flux.py @@ -0,0 +1,298 @@ +import gc +import unittest + +import numpy as np +import pytest +import torch +from huggingface_hub import hf_hub_download +from transformers import AutoTokenizer, CLIPTextConfig, CLIPTextModel, CLIPTokenizer, T5EncoderModel + +from diffusers import AutoencoderKL, FlowMatchEulerDiscreteScheduler, FluxPipeline, FluxTransformer2DModel +from diffusers.utils.testing_utils import ( + numpy_cosine_similarity_distance, + require_big_gpu_with_torch_cuda, + slow, + torch_device, +) + +from ..test_pipelines_common import ( + PipelineTesterMixin, + check_qkv_fusion_matches_attn_procs_length, + check_qkv_fusion_processors_exist, +) + + +class FluxPipelineFastTests(unittest.TestCase, PipelineTesterMixin): + pipeline_class = FluxPipeline + params = frozenset(["prompt", "height", "width", "guidance_scale", "prompt_embeds", "pooled_prompt_embeds"]) + batch_params = frozenset(["prompt"]) + + # there is no xformers processor for Flux + test_xformers_attention = False + + def get_dummy_components(self): + torch.manual_seed(0) + transformer = FluxTransformer2DModel( + patch_size=1, + in_channels=4, + num_layers=1, + num_single_layers=1, + attention_head_dim=16, + num_attention_heads=2, + joint_attention_dim=32, + pooled_projection_dim=32, + axes_dims_rope=[4, 4, 8], + ) + clip_text_encoder_config = CLIPTextConfig( + bos_token_id=0, + eos_token_id=2, + hidden_size=32, + intermediate_size=37, + layer_norm_eps=1e-05, + num_attention_heads=4, + num_hidden_layers=5, + pad_token_id=1, + vocab_size=1000, + hidden_act="gelu", + projection_dim=32, + ) + + torch.manual_seed(0) + text_encoder = CLIPTextModel(clip_text_encoder_config) + + torch.manual_seed(0) + text_encoder_2 = T5EncoderModel.from_pretrained("hf-internal-testing/tiny-random-t5") + + tokenizer = CLIPTokenizer.from_pretrained("hf-internal-testing/tiny-random-clip") + tokenizer_2 = AutoTokenizer.from_pretrained("hf-internal-testing/tiny-random-t5") + + torch.manual_seed(0) + vae = AutoencoderKL( + sample_size=32, + in_channels=3, + out_channels=3, + block_out_channels=(4,), + layers_per_block=1, + latent_channels=1, + norm_num_groups=1, + use_quant_conv=False, + use_post_quant_conv=False, + shift_factor=0.0609, + scaling_factor=1.5035, + ) + + scheduler = FlowMatchEulerDiscreteScheduler() + + return { + "scheduler": scheduler, + "text_encoder": text_encoder, + "text_encoder_2": text_encoder_2, + "tokenizer": tokenizer, + "tokenizer_2": tokenizer_2, + "transformer": transformer, + "vae": vae, + } + + def get_dummy_inputs(self, device, seed=0): + if str(device).startswith("mps"): + generator = torch.manual_seed(seed) + else: + generator = torch.Generator(device="cpu").manual_seed(seed) + + inputs = { + "prompt": "A painting of a squirrel eating a burger", + "generator": generator, + "num_inference_steps": 2, + "guidance_scale": 5.0, + "height": 8, + "width": 8, + "max_sequence_length": 48, + "output_type": "np", + } + return inputs + + def test_flux_different_prompts(self): + pipe = self.pipeline_class(**self.get_dummy_components()).to(torch_device) + + inputs = self.get_dummy_inputs(torch_device) + output_same_prompt = pipe(**inputs).images[0] + + inputs = self.get_dummy_inputs(torch_device) + inputs["prompt_2"] = "a different prompt" + output_different_prompts = pipe(**inputs).images[0] + + max_diff = np.abs(output_same_prompt - output_different_prompts).max() + + # Outputs should be different here + # For some reasons, they don't show large differences + assert max_diff > 1e-6 + + def 
test_flux_prompt_embeds(self): + pipe = self.pipeline_class(**self.get_dummy_components()).to(torch_device) + inputs = self.get_dummy_inputs(torch_device) + + output_with_prompt = pipe(**inputs).images[0] + + inputs = self.get_dummy_inputs(torch_device) + prompt = inputs.pop("prompt") + + (prompt_embeds, pooled_prompt_embeds, text_ids) = pipe.encode_prompt( + prompt, + prompt_2=None, + device=torch_device, + max_sequence_length=inputs["max_sequence_length"], + ) + output_with_embeds = pipe( + prompt_embeds=prompt_embeds, + pooled_prompt_embeds=pooled_prompt_embeds, + **inputs, + ).images[0] + + max_diff = np.abs(output_with_prompt - output_with_embeds).max() + assert max_diff < 1e-4 + + def test_fused_qkv_projections(self): + device = "cpu" # ensure determinism for the device-dependent torch.Generator + components = self.get_dummy_components() + pipe = self.pipeline_class(**components) + pipe = pipe.to(device) + pipe.set_progress_bar_config(disable=None) + + inputs = self.get_dummy_inputs(device) + image = pipe(**inputs).images + original_image_slice = image[0, -3:, -3:, -1] + + # TODO (sayakpaul): will refactor this once `fuse_qkv_projections()` has been added + # to the pipeline level. + pipe.transformer.fuse_qkv_projections() + assert check_qkv_fusion_processors_exist( + pipe.transformer + ), "Something wrong with the fused attention processors. Expected all the attention processors to be fused." + assert check_qkv_fusion_matches_attn_procs_length( + pipe.transformer, pipe.transformer.original_attn_processors + ), "Something wrong with the attention processors concerning the fused QKV projections." + + inputs = self.get_dummy_inputs(device) + image = pipe(**inputs).images + image_slice_fused = image[0, -3:, -3:, -1] + + pipe.transformer.unfuse_qkv_projections() + inputs = self.get_dummy_inputs(device) + image = pipe(**inputs).images + image_slice_disabled = image[0, -3:, -3:, -1] + + assert np.allclose( + original_image_slice, image_slice_fused, atol=1e-3, rtol=1e-3 + ), "Fusion of QKV projections shouldn't affect the outputs." + assert np.allclose( + image_slice_fused, image_slice_disabled, atol=1e-3, rtol=1e-3 + ), "Outputs, with QKV projection fusion enabled, shouldn't change when fused QKV projections are disabled." + assert np.allclose( + original_image_slice, image_slice_disabled, atol=1e-2, rtol=1e-2 + ), "Original outputs should match when fused QKV projections are disabled." 
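Setting the test scaffolding aside for a moment: the pipeline hunk earlier in this patch now hands the flow-match scheduler an explicit sigma grid, `np.linspace(1, 0, num_inference_steps + 1)[:num_inference_steps]`. What that evaluates to for a small step count (a quick illustration, not part of the patch):

```python
import numpy as np

num_inference_steps = 4
sigmas = np.linspace(1, 0, num_inference_steps + 1)[:num_inference_steps]
print(sigmas)  # [1.   0.75 0.5  0.25] -- uniform from 1 toward 0; the scheduler appends the final 0 itself
```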
+ + def test_flux_image_output_shape(self): + pipe = self.pipeline_class(**self.get_dummy_components()).to(torch_device) + inputs = self.get_dummy_inputs(torch_device) + + height_width_pairs = [(32, 32), (72, 57)] + for height, width in height_width_pairs: + expected_height = height - height % (pipe.vae_scale_factor * 2) + expected_width = width - width % (pipe.vae_scale_factor * 2) + + inputs.update({"height": height, "width": width}) + image = pipe(**inputs).images[0] + output_height, output_width, _ = image.shape + assert (output_height, output_width) == (expected_height, expected_width) + + +@slow +@require_big_gpu_with_torch_cuda +@pytest.mark.big_gpu_with_torch_cuda +class FluxPipelineSlowTests(unittest.TestCase): + pipeline_class = FluxPipeline + repo_id = "black-forest-labs/FLUX.1-schnell" + + def setUp(self): + super().setUp() + gc.collect() + torch.cuda.empty_cache() + + def tearDown(self): + super().tearDown() + gc.collect() + torch.cuda.empty_cache() + + def get_inputs(self, device, seed=0): + if str(device).startswith("mps"): + generator = torch.manual_seed(seed) + else: + generator = torch.Generator(device="cpu").manual_seed(seed) + + prompt_embeds = torch.load( + hf_hub_download(repo_id="diffusers/test-slices", repo_type="dataset", filename="flux/prompt_embeds.pt") + ) + pooled_prompt_embeds = torch.load( + hf_hub_download( + repo_id="diffusers/test-slices", repo_type="dataset", filename="flux/pooled_prompt_embeds.pt" + ) + ) + return { + "prompt_embeds": prompt_embeds, + "pooled_prompt_embeds": pooled_prompt_embeds, + "num_inference_steps": 2, + "guidance_scale": 0.0, + "max_sequence_length": 256, + "output_type": "np", + "generator": generator, + } + + def test_flux_inference(self): + pipe = self.pipeline_class.from_pretrained( + self.repo_id, torch_dtype=torch.bfloat16, text_encoder=None, text_encoder_2=None + ) + pipe.enable_model_cpu_offload() + + inputs = self.get_inputs(torch_device) + + image = pipe(**inputs).images[0] + image_slice = image[0, :10, :10] + expected_slice = np.array( + [ + 0.3242, + 0.3203, + 0.3164, + 0.3164, + 0.3125, + 0.3125, + 0.3281, + 0.3242, + 0.3203, + 0.3301, + 0.3262, + 0.3242, + 0.3281, + 0.3242, + 0.3203, + 0.3262, + 0.3262, + 0.3164, + 0.3262, + 0.3281, + 0.3184, + 0.3281, + 0.3281, + 0.3203, + 0.3281, + 0.3281, + 0.3164, + 0.3320, + 0.3320, + 0.3203, + ], + dtype=np.float32, + ) + + max_diff = numpy_cosine_similarity_distance(expected_slice.flatten(), image_slice.flatten()) + + assert max_diff < 1e-4 From db92c6996d63417da5d8dcd4c97a5c16c7167f50 Mon Sep 17 00:00:00 2001 From: shitao <2906698981@qq.com> Date: Tue, 3 Dec 2024 21:37:16 +0800 Subject: [PATCH 06/55] test case for omnigen --- docs/source/en/api/pipelines/omnigen.md | 100 ++++++ .../pipelines/omnigen/pipeline_omnigen.py | 10 +- .../pipelines/omnigen/processor_omnigen.py | 3 +- tests/pipelines/omnigen/test_pipeline_flux.py | 298 ------------------ .../omnigen/test_pipeline_omnigen.py | 269 ++++++++++++++++ 5 files changed, 374 insertions(+), 306 deletions(-) create mode 100644 docs/source/en/api/pipelines/omnigen.md delete mode 100644 tests/pipelines/omnigen/test_pipeline_flux.py create mode 100644 tests/pipelines/omnigen/test_pipeline_omnigen.py diff --git a/docs/source/en/api/pipelines/omnigen.md b/docs/source/en/api/pipelines/omnigen.md new file mode 100644 index 000000000000..fe0bc5d496a0 --- /dev/null +++ b/docs/source/en/api/pipelines/omnigen.md @@ -0,0 +1,100 @@ + + +# OmniGen + +[OmniGen: Unified Image Generation](https://arxiv.org/pdf/2409.11340) from BAAI, by Shitao 
Xiao, Yueze Wang, Junjie Zhou, Huaying Yuan, Xingrun Xing, Ruiran Yan, Chaofan Li, Shuting Wang, Tiejun Huang, Zheng Liu. + +The abstract from the paper is: + +*The emergence of Large Language Models (LLMs) has unified language +generation tasks and revolutionized human-machine interaction. +However, in the realm of image generation, a unified model capable of handling various tasks +within a single framework remains largely unexplored. In +this work, we introduce OmniGen, a new diffusion model +for unified image generation. OmniGen is characterized +by the following features: 1) Unification: OmniGen not +only demonstrates text-to-image generation capabilities but +also inherently supports various downstream tasks, such +as image editing, subject-driven generation, and visual conditional generation. 2) Simplicity: The architecture of +OmniGen is highly simplified, eliminating the need for additional plugins. Moreover, compared to existing diffusion +models, it is more user-friendly and can complete complex +tasks end-to-end through instructions without the need for +extra intermediate steps, greatly simplifying the image generation workflow. 3) Knowledge Transfer: Benefit from +learning in a unified format, OmniGen effectively transfers +knowledge across different tasks, manages unseen tasks and +domains, and exhibits novel capabilities. We also explore +the model’s reasoning capabilities and potential applications of the chain-of-thought mechanism. This work represents the first attempt at a general-purpose image generation model, and we will release our resources at https: +//github.com/VectorSpaceLab/OmniGen to foster future advancements.* + + + +Make sure to check out the Schedulers [guide](../../using-diffusers/schedulers.md) to learn how to explore the tradeoff between scheduler speed and quality, and see the [reuse components across pipelines](../../using-diffusers/loading.md#reuse-a-pipeline) section to learn how to efficiently load the same components into multiple pipelines. + + + +This pipeline was contributed by [staoxiao](https://github.com/staoxiao). The original codebase can be found [here](https://github.com/VectorSpaceLab/OmniGen). The original weights can be found under [hf.co/shitao](https://huggingface.co/Shitao/OmniGen-v1). + + +## Inference + +Use [`torch.compile`](https://huggingface.co/docs/diffusers/main/en/tutorials/fast_diffusion#torchcompile) to reduce the inference latency. + +First, load the pipeline: + +```python +import torch +from diffusers import CogVideoXPipeline, CogVideoXImageToVideoPipeline +from diffusers.utils import export_to_video,load_image +pipe = CogVideoXPipeline.from_pretrained("THUDM/CogVideoX-5b").to("cuda") # or "THUDM/CogVideoX-2b" +``` + +If you are using the image-to-video pipeline, load it as follows: + +```python +pipe = CogVideoXImageToVideoPipeline.from_pretrained("THUDM/CogVideoX-5b-I2V").to("cuda") +``` + +Then change the memory layout of the pipelines `transformer` component to `torch.channels_last`: + +```python +pipe.transformer.to(memory_format=torch.channels_last) +``` + +Compile the components and run inference: + +```python +pipe.transformer = torch.compile(pipeline.transformer, mode="max-autotune", fullgraph=True) + +# CogVideoX works well with long and well-described prompts +prompt = "A panda, dressed in a small, red jacket and a tiny hat, sits on a wooden stool in a serene bamboo forest. The panda's fluffy paws strum a miniature acoustic guitar, producing soft, melodic tunes. 
Nearby, a few other pandas gather, watching curiously and some clapping in rhythm. Sunlight filters through the tall bamboo, casting a gentle glow on the scene. The panda's face is expressive, showing concentration and joy as it plays. The background includes a small, flowing stream and vibrant green foliage, enhancing the peaceful and magical atmosphere of this unique musical performance." +video = pipe(prompt=prompt, guidance_scale=6, num_inference_steps=50).frames[0] +``` + +The [T2V benchmark](https://gist.github.com/a-r-r-o-w/5183d75e452a368fd17448fcc810bd3f) results on an 80GB A100 machine are: + +``` +Without torch.compile(): Average inference time: 96.89 seconds. +With torch.compile(): Average inference time: 76.27 seconds. +``` + + +## CogVideoXPipeline + +[[autodoc]] CogVideoXPipeline + - all + - __call__ + + diff --git a/src/diffusers/pipelines/omnigen/pipeline_omnigen.py b/src/diffusers/pipelines/omnigen/pipeline_omnigen.py index 500a440f4203..65156d2935dc 100644 --- a/src/diffusers/pipelines/omnigen/pipeline_omnigen.py +++ b/src/diffusers/pipelines/omnigen/pipeline_omnigen.py @@ -15,6 +15,7 @@ import inspect from typing import Any, Callable, Dict, List, Optional, Union +import PIL import numpy as np import torch from transformers import LlamaTokenizer @@ -146,7 +147,7 @@ class OmniGenPipeline( model_cpu_offload_seq = "transformer->vae" _optional_components = [] - _callback_tensor_inputs = ["latents", "prompt_embeds"] + _callback_tensor_inputs = ["latents", "condition_tokens"] def __init__( self, @@ -361,7 +362,7 @@ def interrupt(self): def __call__( self, prompt: Union[str, List[str]], - input_images: Optional[Union[List[str], List[List[str]]]] = None, + input_images: Optional[Union[List[str], List[PIL.Image.Image], List[List[str]], List[List[PIL.Image.Image]]]] = None, height: Optional[int] = None, width: Optional[int] = None, num_inference_steps: int = 50, @@ -527,11 +528,6 @@ def __call__( latents, ) - - generator = torch.Generator(device=device).manual_seed(0) - latents = torch.randn(1, 4, height//8, width//8, device=device, generator=generator).to(self.transformer.dtype) - # latents = torch.cat([latents]*(1+num_cfg), 0).to(dtype) - # 7. 
Prepare OmniGenCache num_tokens_for_output_img = latents.size(-1) * latents.size(-2) // (self.transformer.patch_size ** 2) cache = OmniGenCache(num_tokens_for_output_img, offload_kv_cache) if use_kv_cache else None diff --git a/src/diffusers/pipelines/omnigen/processor_omnigen.py b/src/diffusers/pipelines/omnigen/processor_omnigen.py index b425d84269df..545c9b001c7d 100644 --- a/src/diffusers/pipelines/omnigen/processor_omnigen.py +++ b/src/diffusers/pipelines/omnigen/processor_omnigen.py @@ -57,7 +57,8 @@ def __init__(self, self.collator = OmniGenCollator() def process_image(self, image): - image = Image.open(image).convert('RGB') + if isinstance(image, str): + image = Image.open(image).convert('RGB') return self.image_transform(image) def process_multi_modal_prompt(self, text, input_images): diff --git a/tests/pipelines/omnigen/test_pipeline_flux.py b/tests/pipelines/omnigen/test_pipeline_flux.py deleted file mode 100644 index df9021ee0adb..000000000000 --- a/tests/pipelines/omnigen/test_pipeline_flux.py +++ /dev/null @@ -1,298 +0,0 @@ -import gc -import unittest - -import numpy as np -import pytest -import torch -from huggingface_hub import hf_hub_download -from transformers import AutoTokenizer, CLIPTextConfig, CLIPTextModel, CLIPTokenizer, T5EncoderModel - -from diffusers import AutoencoderKL, FlowMatchEulerDiscreteScheduler, FluxPipeline, FluxTransformer2DModel -from diffusers.utils.testing_utils import ( - numpy_cosine_similarity_distance, - require_big_gpu_with_torch_cuda, - slow, - torch_device, -) - -from ..test_pipelines_common import ( - PipelineTesterMixin, - check_qkv_fusion_matches_attn_procs_length, - check_qkv_fusion_processors_exist, -) - - -class FluxPipelineFastTests(unittest.TestCase, PipelineTesterMixin): - pipeline_class = FluxPipeline - params = frozenset(["prompt", "height", "width", "guidance_scale", "prompt_embeds", "pooled_prompt_embeds"]) - batch_params = frozenset(["prompt"]) - - # there is no xformers processor for Flux - test_xformers_attention = False - - def get_dummy_components(self): - torch.manual_seed(0) - transformer = FluxTransformer2DModel( - patch_size=1, - in_channels=4, - num_layers=1, - num_single_layers=1, - attention_head_dim=16, - num_attention_heads=2, - joint_attention_dim=32, - pooled_projection_dim=32, - axes_dims_rope=[4, 4, 8], - ) - clip_text_encoder_config = CLIPTextConfig( - bos_token_id=0, - eos_token_id=2, - hidden_size=32, - intermediate_size=37, - layer_norm_eps=1e-05, - num_attention_heads=4, - num_hidden_layers=5, - pad_token_id=1, - vocab_size=1000, - hidden_act="gelu", - projection_dim=32, - ) - - torch.manual_seed(0) - text_encoder = CLIPTextModel(clip_text_encoder_config) - - torch.manual_seed(0) - text_encoder_2 = T5EncoderModel.from_pretrained("hf-internal-testing/tiny-random-t5") - - tokenizer = CLIPTokenizer.from_pretrained("hf-internal-testing/tiny-random-clip") - tokenizer_2 = AutoTokenizer.from_pretrained("hf-internal-testing/tiny-random-t5") - - torch.manual_seed(0) - vae = AutoencoderKL( - sample_size=32, - in_channels=3, - out_channels=3, - block_out_channels=(4,), - layers_per_block=1, - latent_channels=1, - norm_num_groups=1, - use_quant_conv=False, - use_post_quant_conv=False, - shift_factor=0.0609, - scaling_factor=1.5035, - ) - - scheduler = FlowMatchEulerDiscreteScheduler() - - return { - "scheduler": scheduler, - "text_encoder": text_encoder, - "text_encoder_2": text_encoder_2, - "tokenizer": tokenizer, - "tokenizer_2": tokenizer_2, - "transformer": transformer, - "vae": vae, - } - - def 
get_dummy_inputs(self, device, seed=0): - if str(device).startswith("mps"): - generator = torch.manual_seed(seed) - else: - generator = torch.Generator(device="cpu").manual_seed(seed) - - inputs = { - "prompt": "A painting of a squirrel eating a burger", - "generator": generator, - "num_inference_steps": 2, - "guidance_scale": 5.0, - "height": 8, - "width": 8, - "max_sequence_length": 48, - "output_type": "np", - } - return inputs - - def test_flux_different_prompts(self): - pipe = self.pipeline_class(**self.get_dummy_components()).to(torch_device) - - inputs = self.get_dummy_inputs(torch_device) - output_same_prompt = pipe(**inputs).images[0] - - inputs = self.get_dummy_inputs(torch_device) - inputs["prompt_2"] = "a different prompt" - output_different_prompts = pipe(**inputs).images[0] - - max_diff = np.abs(output_same_prompt - output_different_prompts).max() - - # Outputs should be different here - # For some reasons, they don't show large differences - assert max_diff > 1e-6 - - def test_flux_prompt_embeds(self): - pipe = self.pipeline_class(**self.get_dummy_components()).to(torch_device) - inputs = self.get_dummy_inputs(torch_device) - - output_with_prompt = pipe(**inputs).images[0] - - inputs = self.get_dummy_inputs(torch_device) - prompt = inputs.pop("prompt") - - (prompt_embeds, pooled_prompt_embeds, text_ids) = pipe.encode_prompt( - prompt, - prompt_2=None, - device=torch_device, - max_sequence_length=inputs["max_sequence_length"], - ) - output_with_embeds = pipe( - prompt_embeds=prompt_embeds, - pooled_prompt_embeds=pooled_prompt_embeds, - **inputs, - ).images[0] - - max_diff = np.abs(output_with_prompt - output_with_embeds).max() - assert max_diff < 1e-4 - - def test_fused_qkv_projections(self): - device = "cpu" # ensure determinism for the device-dependent torch.Generator - components = self.get_dummy_components() - pipe = self.pipeline_class(**components) - pipe = pipe.to(device) - pipe.set_progress_bar_config(disable=None) - - inputs = self.get_dummy_inputs(device) - image = pipe(**inputs).images - original_image_slice = image[0, -3:, -3:, -1] - - # TODO (sayakpaul): will refactor this once `fuse_qkv_projections()` has been added - # to the pipeline level. - pipe.transformer.fuse_qkv_projections() - assert check_qkv_fusion_processors_exist( - pipe.transformer - ), "Something wrong with the fused attention processors. Expected all the attention processors to be fused." - assert check_qkv_fusion_matches_attn_procs_length( - pipe.transformer, pipe.transformer.original_attn_processors - ), "Something wrong with the attention processors concerning the fused QKV projections." - - inputs = self.get_dummy_inputs(device) - image = pipe(**inputs).images - image_slice_fused = image[0, -3:, -3:, -1] - - pipe.transformer.unfuse_qkv_projections() - inputs = self.get_dummy_inputs(device) - image = pipe(**inputs).images - image_slice_disabled = image[0, -3:, -3:, -1] - - assert np.allclose( - original_image_slice, image_slice_fused, atol=1e-3, rtol=1e-3 - ), "Fusion of QKV projections shouldn't affect the outputs." - assert np.allclose( - image_slice_fused, image_slice_disabled, atol=1e-3, rtol=1e-3 - ), "Outputs, with QKV projection fusion enabled, shouldn't change when fused QKV projections are disabled." - assert np.allclose( - original_image_slice, image_slice_disabled, atol=1e-2, rtol=1e-2 - ), "Original outputs should match when fused QKV projections are disabled." 
- - def test_flux_image_output_shape(self): - pipe = self.pipeline_class(**self.get_dummy_components()).to(torch_device) - inputs = self.get_dummy_inputs(torch_device) - - height_width_pairs = [(32, 32), (72, 57)] - for height, width in height_width_pairs: - expected_height = height - height % (pipe.vae_scale_factor * 2) - expected_width = width - width % (pipe.vae_scale_factor * 2) - - inputs.update({"height": height, "width": width}) - image = pipe(**inputs).images[0] - output_height, output_width, _ = image.shape - assert (output_height, output_width) == (expected_height, expected_width) - - -@slow -@require_big_gpu_with_torch_cuda -@pytest.mark.big_gpu_with_torch_cuda -class FluxPipelineSlowTests(unittest.TestCase): - pipeline_class = FluxPipeline - repo_id = "black-forest-labs/FLUX.1-schnell" - - def setUp(self): - super().setUp() - gc.collect() - torch.cuda.empty_cache() - - def tearDown(self): - super().tearDown() - gc.collect() - torch.cuda.empty_cache() - - def get_inputs(self, device, seed=0): - if str(device).startswith("mps"): - generator = torch.manual_seed(seed) - else: - generator = torch.Generator(device="cpu").manual_seed(seed) - - prompt_embeds = torch.load( - hf_hub_download(repo_id="diffusers/test-slices", repo_type="dataset", filename="flux/prompt_embeds.pt") - ) - pooled_prompt_embeds = torch.load( - hf_hub_download( - repo_id="diffusers/test-slices", repo_type="dataset", filename="flux/pooled_prompt_embeds.pt" - ) - ) - return { - "prompt_embeds": prompt_embeds, - "pooled_prompt_embeds": pooled_prompt_embeds, - "num_inference_steps": 2, - "guidance_scale": 0.0, - "max_sequence_length": 256, - "output_type": "np", - "generator": generator, - } - - def test_flux_inference(self): - pipe = self.pipeline_class.from_pretrained( - self.repo_id, torch_dtype=torch.bfloat16, text_encoder=None, text_encoder_2=None - ) - pipe.enable_model_cpu_offload() - - inputs = self.get_inputs(torch_device) - - image = pipe(**inputs).images[0] - image_slice = image[0, :10, :10] - expected_slice = np.array( - [ - 0.3242, - 0.3203, - 0.3164, - 0.3164, - 0.3125, - 0.3125, - 0.3281, - 0.3242, - 0.3203, - 0.3301, - 0.3262, - 0.3242, - 0.3281, - 0.3242, - 0.3203, - 0.3262, - 0.3262, - 0.3164, - 0.3262, - 0.3281, - 0.3184, - 0.3281, - 0.3281, - 0.3203, - 0.3281, - 0.3281, - 0.3164, - 0.3320, - 0.3320, - 0.3203, - ], - dtype=np.float32, - ) - - max_diff = numpy_cosine_similarity_distance(expected_slice.flatten(), image_slice.flatten()) - - assert max_diff < 1e-4 diff --git a/tests/pipelines/omnigen/test_pipeline_omnigen.py b/tests/pipelines/omnigen/test_pipeline_omnigen.py new file mode 100644 index 000000000000..73124576ee04 --- /dev/null +++ b/tests/pipelines/omnigen/test_pipeline_omnigen.py @@ -0,0 +1,269 @@ +import gc +import unittest + +import numpy as np +import torch +from transformers import AutoTokenizer, GemmaConfig, GemmaForCausalLM + +from diffusers import AutoencoderKL, FlowMatchEulerDiscreteScheduler, OmniGenTransformer2DModel, OmniGenPipeline +from diffusers.utils.testing_utils import ( + numpy_cosine_similarity_distance, + require_torch_gpu, + slow, + torch_device, +) + +from ..test_pipelines_common import PipelineTesterMixin + + +class OmniGenPipelineFastTests(unittest.TestCase, PipelineTesterMixin): + pipeline_class = OmniGenPipeline + params = frozenset( + [ + "prompt", + "guidance_scale", + ] + ) + batch_params = frozenset(["prompt", ]) + + def get_dummy_components(self): + torch.manual_seed(0) + + transformer_config = { + "_name_or_path": "Phi-3-vision-128k-instruct", + 
"architectures": [ + "Phi3ForCausalLM" + ], + "attention_dropout": 0.0, + "bos_token_id": 1, + "eos_token_id": 2, + "hidden_act": "silu", + "hidden_size": 3072, + "initializer_range": 0.02, + "intermediate_size": 8192, + "max_position_embeddings": 131072, + "model_type": "phi3", + "num_attention_heads": 32, + "num_hidden_layers": 32, + "num_key_value_heads": 32, + "original_max_position_embeddings": 4096, + "rms_norm_eps": 1e-05, + "rope_scaling": { + "long_factor": [ + 1.0299999713897705, + 1.0499999523162842, + 1.0499999523162842, + 1.0799999237060547, + 1.2299998998641968, + 1.2299998998641968, + 1.2999999523162842, + 1.4499999284744263, + 1.5999999046325684, + 1.6499998569488525, + 1.8999998569488525, + 2.859999895095825, + 3.68999981880188, + 5.419999599456787, + 5.489999771118164, + 5.489999771118164, + 9.09000015258789, + 11.579999923706055, + 15.65999984741211, + 15.769999504089355, + 15.789999961853027, + 18.360000610351562, + 21.989999771118164, + 23.079999923706055, + 30.009998321533203, + 32.35000228881836, + 32.590003967285156, + 35.56000518798828, + 39.95000457763672, + 53.840003967285156, + 56.20000457763672, + 57.95000457763672, + 59.29000473022461, + 59.77000427246094, + 59.920005798339844, + 61.190006256103516, + 61.96000671386719, + 62.50000762939453, + 63.3700065612793, + 63.48000717163086, + 63.48000717163086, + 63.66000747680664, + 63.850006103515625, + 64.08000946044922, + 64.760009765625, + 64.80001068115234, + 64.81001281738281, + 64.81001281738281 + ], + "short_factor": [ + 1.05, + 1.05, + 1.05, + 1.1, + 1.1, + 1.1, + 1.2500000000000002, + 1.2500000000000002, + 1.4000000000000004, + 1.4500000000000004, + 1.5500000000000005, + 1.8500000000000008, + 1.9000000000000008, + 2.000000000000001, + 2.000000000000001, + 2.000000000000001, + 2.000000000000001, + 2.000000000000001, + 2.000000000000001, + 2.000000000000001, + 2.000000000000001, + 2.000000000000001, + 2.000000000000001, + 2.000000000000001, + 2.000000000000001, + 2.000000000000001, + 2.000000000000001, + 2.000000000000001, + 2.000000000000001, + 2.000000000000001, + 2.000000000000001, + 2.000000000000001, + 2.1000000000000005, + 2.1000000000000005, + 2.2, + 2.3499999999999996, + 2.3499999999999996, + 2.3499999999999996, + 2.3499999999999996, + 2.3999999999999995, + 2.3999999999999995, + 2.6499999999999986, + 2.6999999999999984, + 2.8999999999999977, + 2.9499999999999975, + 3.049999999999997, + 3.049999999999997, + 3.049999999999997 + ], + "type": "su" + }, + "rope_theta": 10000.0, + "sliding_window": 131072, + "tie_word_embeddings": False, + "torch_dtype": "bfloat16", + "transformers_version": "4.38.1", + "use_cache": True, + "vocab_size": 32064, + "_attn_implementation": "sdpa" + } + transformer = OmniGenTransformer2DModel( + transformer_config=transformer_config, + patch_size=2, + in_channels=4, + pos_embed_max_size=192, + ) + + + torch.manual_seed(0) + vae = AutoencoderKL() + + scheduler = FlowMatchEulerDiscreteScheduler() + tokenizer = AutoTokenizer.from_pretrained("hf-internal-testing/llama-tokenizer") + + components = { + "transformer": transformer.eval(), + "vae": vae.eval(), + "scheduler": scheduler, + "tokenizer": tokenizer, + } + return components + + def get_dummy_inputs(self, device, seed=0): + if str(device).startswith("mps"): + generator = torch.manual_seed(seed) + else: + generator = torch.Generator(device="cpu").manual_seed(seed) + + inputs = { + "prompt": "A painting of a squirrel eating a burger", + "generator": generator, + "num_inference_steps": 2, + "guidance_scale": 5.0, + "output_type": 
"np", + "height": 16, + "width": 16, + } + return inputs + + def test_inference(self): + pipe = self.pipeline_class(**self.get_dummy_components()).to(torch_device) + + inputs = self.get_dummy_inputs(torch_device) + generated_image = pipe(**inputs).images[0] + + self.assertEqual(generated_image.shape, (1, 3, 16, 16)) + + + + +@slow +@require_torch_gpu +class OmniGenPipelineSlowTests(unittest.TestCase): + pipeline_class = OmniGenPipeline + repo_id = "Shitao/OmniGen-v1-diffusers" + + def setUp(self): + super().setUp() + gc.collect() + torch.cuda.empty_cache() + + def tearDown(self): + super().tearDown() + gc.collect() + torch.cuda.empty_cache() + + def get_inputs(self, device, seed=0): + if str(device).startswith("mps"): + generator = torch.manual_seed(seed) + else: + generator = torch.Generator(device="cpu").manual_seed(seed) + + return { + "prompt": "A photo of a cat", + "num_inference_steps": 2, + "guidance_scale": 5.0, + "output_type": "np", + "generator": generator, + } + + def test_omnigen_inference(self): + pipe = self.pipeline_class.from_pretrained(self.repo_id, torch_dtype=torch.bfloat16) + pipe.enable_model_cpu_offload() + + inputs = self.get_inputs(torch_device) + + image = pipe(**inputs).images[0] + image_slice = image[0, :10, :10] + expected_slice = np.array( + [ + [0.17773438, 0.18554688, 0.22070312], + [0.046875, 0.06640625, 0.10351562], + [0.0, 0.0, 0.02148438], + [0.0, 0.0, 0.0], + [0.0, 0.0, 0.0], + [0.0, 0.0, 0.0], + [0.0, 0.0, 0.0], + [0.0, 0.0, 0.0], + [0.0, 0.0, 0.0], + [0.0, 0.0, 0.0], + ], + dtype=np.float32, + ) + + max_diff = numpy_cosine_similarity_distance(expected_slice.flatten(), image_slice.flatten()) + + assert max_diff < 1e-4 From 308766cf99cdc85464fba593b97631e07f54adc2 Mon Sep 17 00:00:00 2001 From: staoxiao <2906698981@qq.com> Date: Wed, 4 Dec 2024 15:24:03 +0800 Subject: [PATCH 07/55] update omnigenpipeline --- .../transformers/transformer_omnigen.py | 56 +++++++++++----- .../pipelines/omnigen/pipeline_omnigen.py | 67 ++++++------------- test.py | 11 +-- .../omnigen/test_pipeline_omnigen.py | 43 +++++++----- 4 files changed, 93 insertions(+), 84 deletions(-) diff --git a/src/diffusers/models/transformers/transformer_omnigen.py b/src/diffusers/models/transformers/transformer_omnigen.py index 031fc014e77a..64790a2024e5 100644 --- a/src/diffusers/models/transformers/transformer_omnigen.py +++ b/src/diffusers/models/transformers/transformer_omnigen.py @@ -281,20 +281,6 @@ def unpatchify(self, x, h, w): imgs = x.reshape(shape=(x.shape[0], c, h, w)) return imgs - def prepare_condition_embeddings(self, input_ids, input_img_latents, input_image_sizes): - condition_embeds = None - if input_img_latents is not None: - input_latents = self.patch_embedding(input_img_latents, is_input_images=True) - if input_ids is not None: - condition_embeds = self.llm.embed_tokens(input_ids).clone() - input_img_inx = 0 - for b_inx in input_image_sizes.keys(): - for start_inx, end_inx in input_image_sizes[b_inx]: - condition_embeds[b_inx, start_inx: end_inx] = input_latents[input_img_inx] - input_img_inx += 1 - if input_img_latents is not None: - assert input_img_inx == len(input_latents) - return condition_embeds @property # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.attn_processors @@ -359,11 +345,46 @@ def fn_recursive_attn_processor(name: str, module: torch.nn.Module, processor): def _set_gradient_checkpointing(self, module, value=False): if hasattr(module, "gradient_checkpointing"): module.gradient_checkpointing = value + + def 
get_multimodal_embeddings(self, + input_ids: torch.Tensor, + input_img_latents: List[torch.Tensor], + input_image_sizes: Dict, + ): + """ + get the multi-modal conditional embeddings + Args: + input_ids: a sequence of text id + input_img_latents: continues embedding of input images + input_image_sizes: the index of the input image in the input_ids sequence. + + Returns: torch.Tensor + + """ + input_img_latents = [x.to(self.dtype) for x in input_img_latents] + condition_tokens = None + if input_ids is not None: + condition_tokens = self.llm.embed_tokens(input_ids) + input_img_inx = 0 + if input_img_latents is not None: + input_image_tokens = self.patch_embedding(input_img_latents, + is_input_image=True) + + for b_inx in input_image_sizes.keys(): + for start_inx, end_inx in input_image_sizes[b_inx]: + # replace the placeholder in text tokens with the image embedding. + condition_tokens[b_inx, start_inx: end_inx] = input_image_tokens[input_img_inx].to( + condition_tokens.dtype) + input_img_inx += 1 + + return condition_tokens def forward(self, hidden_states: torch.Tensor, timestep: Union[int, float, torch.LongTensor], - condition_tokens: torch.Tensor, + input_ids: torch.Tensor, + input_img_latents: List[torch.Tensor], + input_image_sizes: Dict[int, List[int]], attention_mask: torch.Tensor, position_ids: torch.Tensor, past_key_values: DynamicCache = None, @@ -386,13 +407,16 @@ def forward(self, logger.warning( "Passing `scale` via `attention_kwargs` when not using the PEFT backend is ineffective." ) - height, width = hidden_states.size()[-2:] hidden_states = self.patch_embedding(hidden_states, is_input_image=False) num_tokens_for_output_image = hidden_states.size(1) time_token = self.time_token(timestep, dtype=hidden_states.dtype).unsqueeze(1) + condition_tokens = self.get_multimodal_embeddings(input_ids=input_ids, + input_img_latents=input_img_latents, + input_image_sizes=input_image_sizes, + ) if condition_tokens is not None: input_emb = torch.cat([condition_tokens, time_token, hidden_states], dim=1) else: diff --git a/src/diffusers/pipelines/omnigen/pipeline_omnigen.py b/src/diffusers/pipelines/omnigen/pipeline_omnigen.py index 65156d2935dc..ff10cade50f3 100644 --- a/src/diffusers/pipelines/omnigen/pipeline_omnigen.py +++ b/src/diffusers/pipelines/omnigen/pipeline_omnigen.py @@ -147,7 +147,7 @@ class OmniGenPipeline( model_cpu_offload_seq = "transformer->vae" _optional_components = [] - _callback_tensor_inputs = ["latents", "condition_tokens"] + _callback_tensor_inputs = ["latents", "input_images_latents"] def __init__( self, @@ -199,43 +199,6 @@ def encod_input_iamges( input_img_latents.append(img) return input_img_latents - def get_multimodal_embeddings(self, - input_ids: torch.Tensor, - input_img_latents: List[torch.Tensor], - input_image_sizes: Dict, - device: Optional[torch.device] = None, - ): - """ - get the multi-modal conditional embeddings - Args: - input_ids: a sequence of text id - input_img_latents: continues embedding of input images - input_image_sizes: the index of the input image in the input_ids sequence. 
- device: - - Returns: torch.Tensor - - """ - device = device or self._execution_device - - input_img_latents = [x.to(self.transformer.dtype) for x in input_img_latents] - - condition_tokens = None - if input_ids is not None: - condition_tokens = self.transformer.llm.embed_tokens(input_ids.to(device)).clone() - input_img_inx = 0 - if input_img_latents is not None: - input_image_tokens = self.transformer.patch_embedding(input_img_latents, - is_input_image=True) - - for b_inx in input_image_sizes.keys(): - for start_inx, end_inx in input_image_sizes[b_inx]: - # replace the placeholder in text tokens with the image embedding. - condition_tokens[b_inx, start_inx: end_inx] = input_image_tokens[input_img_inx].to( - condition_tokens.dtype) - input_img_inx += 1 - - return condition_tokens def check_inputs( self, @@ -243,6 +206,8 @@ def check_inputs( input_images, height, width, + use_kv_cache, + offload_kv_cache, callback_on_step_end_tensor_inputs=None, max_sequence_length=None, ): @@ -263,6 +228,12 @@ def check_inputs( logger.warning( f"`height` and `width` have to be divisible by {self.vae_scale_factor * 2} but are {height} and {width}. Dimensions will be resized accordingly" ) + + if use_kv_cache and offload_kv_cache: + if not torch.cuda.is_available(): + raise ValueError( + f"Don't fine avaliable GPUs. `offload_kv_cache` can't be used when there is no GPU. please set it to False: `use_kv_cache=False, offload_kv_cache=False`" + ) if callback_on_step_end_tensor_inputs is not None and not all( k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs @@ -475,6 +446,8 @@ def __call__( input_images, height, width, + use_kv_cache, + offload_kv_cache, callback_on_step_end_tensor_inputs=callback_on_step_end_tensor_inputs, max_sequence_length=max_sequence_length, ) @@ -496,16 +469,12 @@ def __call__( width=width, use_img_cfg=use_img_cfg, use_input_image_size_as_output=use_input_image_size_as_output) + processed_data['input_ids'] = processed_data['input_ids'].to(device) processed_data['attention_mask'] = processed_data['attention_mask'].to(device) processed_data['position_ids'] = processed_data['position_ids'].to(device) - # 4. Encode input images and obtain multi-modal conditional embeddings + # 4. Encode input images input_img_latents = self.encod_input_iamges(processed_data['input_pixel_values'], device=device) - condition_tokens = self.get_multimodal_embeddings(input_ids=processed_data['input_ids'], - input_img_latents=input_img_latents, - input_image_sizes=processed_data['input_image_sizes'], - device=device, - ) # 5. 
Prepare timesteps sigmas = np.linspace(1, 0, num_inference_steps+1)[:num_inference_steps] @@ -522,7 +491,7 @@ def __call__( latent_channels, height, width, - condition_tokens.dtype, + self.transformer.dtype, device, generator, latents, @@ -545,7 +514,9 @@ def __call__( noise_pred, cache = self.transformer( hidden_states=latent_model_input, timestep=timestep, - condition_tokens=condition_tokens, + input_ids=processed_data['input_ids'], + input_img_latents=input_img_latents, + input_image_sizes=processed_data['input_image_sizes'], attention_mask=processed_data['attention_mask'], position_ids=processed_data['position_ids'], attention_kwargs=attention_kwargs, @@ -556,8 +527,8 @@ def __call__( # if use kv cache, don't need attention mask and position ids of condition tokens for next step if use_kv_cache: - if condition_tokens is not None: - condition_tokens = None + if processed_data['input_ids'] is not None: + processed_data['input_ids'] = None processed_data['attention_mask'] = processed_data['attention_mask'][..., -(num_tokens_for_output_img + 1):, :] # +1 is for the timestep token processed_data['position_ids'] = processed_data['position_ids'][:, -(num_tokens_for_output_img + 1):] diff --git a/test.py b/test.py index fe1410122acd..b27d99a1066b 100644 --- a/test.py +++ b/test.py @@ -1,5 +1,5 @@ -# import os -# os.environ['HF_HUB_CACHE'] = '/share/shitao/downloaded_models2' +import os +os.environ['HF_HUB_CACHE'] = '/share/shitao/downloaded_models2' # from huggingface_hub import snapshot_download @@ -47,10 +47,13 @@ # model.load_state_dict(new_ckpt) +from tests.pipelines.omnigen.test_pipeline_omnigen import OmniGenPipelineFastTests, OmniGenPipelineSlowTests +test1 = OmniGenPipelineFastTests() +test1.test_inference() - - +test2 = OmniGenPipelineSlowTests() +test2.test_omnigen_inference() diff --git a/tests/pipelines/omnigen/test_pipeline_omnigen.py b/tests/pipelines/omnigen/test_pipeline_omnigen.py index 73124576ee04..93870f9da31d 100644 --- a/tests/pipelines/omnigen/test_pipeline_omnigen.py +++ b/tests/pipelines/omnigen/test_pipeline_omnigen.py @@ -169,10 +169,20 @@ def get_dummy_components(self): torch.manual_seed(0) - vae = AutoencoderKL() + vae = AutoencoderKL( + sample_size=32, + in_channels=3, + out_channels=3, + block_out_channels=(4, 4, 4, 4), + layers_per_block=1, + latent_channels=4, + norm_num_groups=1, + up_block_types = ["UpDecoderBlock2D", "UpDecoderBlock2D", "UpDecoderBlock2D", "UpDecoderBlock2D"], + ) scheduler = FlowMatchEulerDiscreteScheduler() tokenizer = AutoTokenizer.from_pretrained("hf-internal-testing/llama-tokenizer") + # tokenizer = AutoTokenizer.from_pretrained("Shitao/OmniGen-v1") components = { "transformer": transformer.eval(), @@ -192,10 +202,12 @@ def get_dummy_inputs(self, device, seed=0): "prompt": "A painting of a squirrel eating a burger", "generator": generator, "num_inference_steps": 2, - "guidance_scale": 5.0, + "guidance_scale": 3.0, "output_type": "np", "height": 16, "width": 16, + "use_kv_cache": False, + "offload_kv_cache": False, } return inputs @@ -205,7 +217,7 @@ def test_inference(self): inputs = self.get_dummy_inputs(torch_device) generated_image = pipe(**inputs).images[0] - self.assertEqual(generated_image.shape, (1, 3, 16, 16)) + self.assertEqual(generated_image.shape, (16, 16, 3)) @@ -235,7 +247,7 @@ def get_inputs(self, device, seed=0): return { "prompt": "A photo of a cat", "num_inference_steps": 2, - "guidance_scale": 5.0, + "guidance_scale": 2.5, "output_type": "np", "generator": generator, } @@ -248,19 +260,18 @@ def 
test_omnigen_inference(self): image = pipe(**inputs).images[0] image_slice = image[0, :10, :10] + expected_slice = np.array( - [ - [0.17773438, 0.18554688, 0.22070312], - [0.046875, 0.06640625, 0.10351562], - [0.0, 0.0, 0.02148438], - [0.0, 0.0, 0.0], - [0.0, 0.0, 0.0], - [0.0, 0.0, 0.0], - [0.0, 0.0, 0.0], - [0.0, 0.0, 0.0], - [0.0, 0.0, 0.0], - [0.0, 0.0, 0.0], - ], + [[0.25806782, 0.28012177, 0.27807158], + [0.25740036, 0.2677201, 0.26857468], + [0.258638, 0.27035138, 0.26633185], + [0.2541029, 0.2636156, 0.26373306], + [0.24975497, 0.2608987, 0.2617477 ], + [0.25102, 0.26192215, 0.262023 ], + [0.24452701, 0.25664824, 0.259144 ], + [0.2419573, 0.2574909, 0.25996095], + [0.23953134, 0.25292695, 0.25652167], + [0.23523712, 0.24710432, 0.25460982]], dtype=np.float32, ) From 4c5e8c572118ede4bf402648c16616deea35856e Mon Sep 17 00:00:00 2001 From: shitao <2906698981@qq.com> Date: Thu, 5 Dec 2024 15:27:59 +0800 Subject: [PATCH 08/55] update docs --- .../en/using-diffusers/multimodal2img.md | 605 ++++++++++++++++++ docs/source/en/using-diffusers/omnigen.md | 70 ++ 2 files changed, 675 insertions(+) create mode 100644 docs/source/en/using-diffusers/multimodal2img.md create mode 100644 docs/source/en/using-diffusers/omnigen.md diff --git a/docs/source/en/using-diffusers/multimodal2img.md b/docs/source/en/using-diffusers/multimodal2img.md new file mode 100644 index 000000000000..4618731830df --- /dev/null +++ b/docs/source/en/using-diffusers/multimodal2img.md @@ -0,0 +1,605 @@ + + +# Image-to-image + +[[open-in-colab]] + +Image-to-image is similar to [text-to-image](conditional_image_generation), but in addition to a prompt, you can also pass an initial image as a starting point for the diffusion process. The initial image is encoded to latent space and noise is added to it. Then the latent diffusion model takes a prompt and the noisy latent image, predicts the added noise, and removes the predicted noise from the initial latent image to get the new latent image. Lastly, a decoder decodes the new latent image back into an image. + +With 🤗 Diffusers, this is as easy as 1-2-3: + +1. Load a checkpoint into the [`AutoPipelineForImage2Image`] class; this pipeline automatically handles loading the correct pipeline class based on the checkpoint: + +```py +import torch +from diffusers import AutoPipelineForImage2Image +from diffusers.utils import load_image, make_image_grid + +pipeline = AutoPipelineForImage2Image.from_pretrained( + "kandinsky-community/kandinsky-2-2-decoder", torch_dtype=torch.float16, use_safetensors=True +) +pipeline.enable_model_cpu_offload() +# remove following line if xFormers is not installed or you have PyTorch 2.0 or higher installed +pipeline.enable_xformers_memory_efficient_attention() +``` + + + +You'll notice throughout the guide, we use [`~DiffusionPipeline.enable_model_cpu_offload`] and [`~DiffusionPipeline.enable_xformers_memory_efficient_attention`], to save memory and increase inference speed. If you're using PyTorch 2.0, then you don't need to call [`~DiffusionPipeline.enable_xformers_memory_efficient_attention`] on your pipeline because it'll already be using PyTorch 2.0's native [scaled-dot product attention](../optimization/torch2.0#scaled-dot-product-attention). + + + +2. Load an image to pass to the pipeline: + +```py +init_image = load_image("https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/cat.png") +``` + +3. 
Pass a prompt and image to the pipeline to generate an image: + +```py +prompt = "cat wizard, gandalf, lord of the rings, detailed, fantasy, cute, adorable, Pixar, Disney, 8k" +image = pipeline(prompt, image=init_image).images[0] +make_image_grid([init_image, image], rows=1, cols=2) +``` + +
+*(figure: initial image | generated image)*
+ +## Popular models + +The most popular image-to-image models are [Stable Diffusion v1.5](https://huggingface.co/stable-diffusion-v1-5/stable-diffusion-v1-5), [Stable Diffusion XL (SDXL)](https://huggingface.co/stabilityai/stable-diffusion-xl-base-1.0), and [Kandinsky 2.2](https://huggingface.co/kandinsky-community/kandinsky-2-2-decoder). The results from the Stable Diffusion and Kandinsky models vary due to their architecture differences and training process; you can generally expect SDXL to produce higher quality images than Stable Diffusion v1.5. Let's take a quick look at how to use each of these models and compare their results. + +### Stable Diffusion v1.5 + +Stable Diffusion v1.5 is a latent diffusion model initialized from an earlier checkpoint, and further finetuned for 595K steps on 512x512 images. To use this pipeline for image-to-image, you'll need to prepare an initial image to pass to the pipeline. Then you can pass a prompt and the image to the pipeline to generate a new image: + +```py +import torch +from diffusers import AutoPipelineForImage2Image +from diffusers.utils import make_image_grid, load_image + +pipeline = AutoPipelineForImage2Image.from_pretrained( + "stable-diffusion-v1-5/stable-diffusion-v1-5", torch_dtype=torch.float16, variant="fp16", use_safetensors=True +) +pipeline.enable_model_cpu_offload() +# remove following line if xFormers is not installed or you have PyTorch 2.0 or higher installed +pipeline.enable_xformers_memory_efficient_attention() + +# prepare image +url = "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/img2img-init.png" +init_image = load_image(url) + +prompt = "Astronaut in a jungle, cold color palette, muted colors, detailed, 8k" + +# pass prompt and image to pipeline +image = pipeline(prompt, image=init_image).images[0] +make_image_grid([init_image, image], rows=1, cols=2) +``` + +
+*(figure: initial image | generated image)*
+ +### Stable Diffusion XL (SDXL) + +SDXL is a more powerful version of the Stable Diffusion model. It uses a larger base model, and an additional refiner model to increase the quality of the base model's output. Read the [SDXL](sdxl) guide for a more detailed walkthrough of how to use this model, and other techniques it uses to produce high quality images. + +```py +import torch +from diffusers import AutoPipelineForImage2Image +from diffusers.utils import make_image_grid, load_image + +pipeline = AutoPipelineForImage2Image.from_pretrained( + "stabilityai/stable-diffusion-xl-refiner-1.0", torch_dtype=torch.float16, variant="fp16", use_safetensors=True +) +pipeline.enable_model_cpu_offload() +# remove following line if xFormers is not installed or you have PyTorch 2.0 or higher installed +pipeline.enable_xformers_memory_efficient_attention() + +# prepare image +url = "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/img2img-sdxl-init.png" +init_image = load_image(url) + +prompt = "Astronaut in a jungle, cold color palette, muted colors, detailed, 8k" + +# pass prompt and image to pipeline +image = pipeline(prompt, image=init_image, strength=0.5).images[0] +make_image_grid([init_image, image], rows=1, cols=2) +``` + +
+*(figure: initial image | generated image)*
+ +### Kandinsky 2.2 + +The Kandinsky model is different from the Stable Diffusion models because it uses an image prior model to create image embeddings. The embeddings help create a better alignment between text and images, allowing the latent diffusion model to generate better images. + +The simplest way to use Kandinsky 2.2 is: + +```py +import torch +from diffusers import AutoPipelineForImage2Image +from diffusers.utils import make_image_grid, load_image + +pipeline = AutoPipelineForImage2Image.from_pretrained( + "kandinsky-community/kandinsky-2-2-decoder", torch_dtype=torch.float16, use_safetensors=True +) +pipeline.enable_model_cpu_offload() +# remove following line if xFormers is not installed or you have PyTorch 2.0 or higher installed +pipeline.enable_xformers_memory_efficient_attention() + +# prepare image +url = "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/img2img-init.png" +init_image = load_image(url) + +prompt = "Astronaut in a jungle, cold color palette, muted colors, detailed, 8k" + +# pass prompt and image to pipeline +image = pipeline(prompt, image=init_image).images[0] +make_image_grid([init_image, image], rows=1, cols=2) +``` + +
+*(figure: initial image | generated image)*
+ +## Configure pipeline parameters + +There are several important parameters you can configure in the pipeline that'll affect the image generation process and image quality. Let's take a closer look at what these parameters do and how changing them affects the output. + +### Strength + +`strength` is one of the most important parameters to consider and it'll have a huge impact on your generated image. It determines how much the generated image resembles the initial image. In other words: + +- 📈 a higher `strength` value gives the model more "creativity" to generate an image that's different from the initial image; a `strength` value of 1.0 means the initial image is more or less ignored +- 📉 a lower `strength` value means the generated image is more similar to the initial image + +The `strength` and `num_inference_steps` parameters are related because `strength` determines the number of noise steps to add. For example, if the `num_inference_steps` is 50 and `strength` is 0.8, then this means adding 40 (50 * 0.8) steps of noise to the initial image and then denoising for 40 steps to get the newly generated image. + +```py +import torch +from diffusers import AutoPipelineForImage2Image +from diffusers.utils import make_image_grid, load_image + +pipeline = AutoPipelineForImage2Image.from_pretrained( + "stable-diffusion-v1-5/stable-diffusion-v1-5", torch_dtype=torch.float16, variant="fp16", use_safetensors=True +) +pipeline.enable_model_cpu_offload() +# remove following line if xFormers is not installed or you have PyTorch 2.0 or higher installed +pipeline.enable_xformers_memory_efficient_attention() + +# prepare image +url = "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/img2img-init.png" +init_image = load_image(url) + +prompt = "Astronaut in a jungle, cold color palette, muted colors, detailed, 8k" + +# pass prompt and image to pipeline +image = pipeline(prompt, image=init_image, strength=0.8).images[0] +make_image_grid([init_image, image], rows=1, cols=2) +``` + +
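+
+The step arithmetic described above is easy to verify on its own. Below is a minimal sketch (plain Python, no model or checkpoint involved) that only illustrates how `strength` scales `num_inference_steps`:
+
+```py
+# how `strength` determines the number of noise/denoising steps
+num_inference_steps = 50
+
+for strength in [0.4, 0.6, 0.8, 1.0]:
+    effective_steps = int(num_inference_steps * strength)
+    print(f"strength={strength}: add noise for {effective_steps} steps, then denoise for {effective_steps} steps")
+```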
+*(figure: strength = 0.4 | strength = 0.6 | strength = 1.0)*
+ +### Guidance scale + +The `guidance_scale` parameter is used to control how closely aligned the generated image and text prompt are. A higher `guidance_scale` value means your generated image is more aligned with the prompt, while a lower `guidance_scale` value means your generated image has more space to deviate from the prompt. + +You can combine `guidance_scale` with `strength` for even more precise control over how expressive the model is. For example, combine a high `strength + guidance_scale` for maximum creativity or use a combination of low `strength` and low `guidance_scale` to generate an image that resembles the initial image but is not as strictly bound to the prompt. + +```py +import torch +from diffusers import AutoPipelineForImage2Image +from diffusers.utils import make_image_grid, load_image + +pipeline = AutoPipelineForImage2Image.from_pretrained( + "stable-diffusion-v1-5/stable-diffusion-v1-5", torch_dtype=torch.float16, variant="fp16", use_safetensors=True +) +pipeline.enable_model_cpu_offload() +# remove following line if xFormers is not installed or you have PyTorch 2.0 or higher installed +pipeline.enable_xformers_memory_efficient_attention() + +# prepare image +url = "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/img2img-init.png" +init_image = load_image(url) + +prompt = "Astronaut in a jungle, cold color palette, muted colors, detailed, 8k" + +# pass prompt and image to pipeline +image = pipeline(prompt, image=init_image, guidance_scale=8.0).images[0] +make_image_grid([init_image, image], rows=1, cols=2) +``` + +
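+
+To combine the two parameters as described above, pass `strength` and `guidance_scale` together. This is only a sketch that reuses `pipeline`, `init_image`, and `prompt` from the previous snippet; the exact values are illustrative, not recommendations:
+
+```py
+# higher strength + higher guidance_scale: more creative, more prompt-driven
+creative_image = pipeline(prompt, image=init_image, strength=0.75, guidance_scale=10.5).images[0]
+
+# lower strength + lower guidance_scale: stays closer to the initial image
+conservative_image = pipeline(prompt, image=init_image, strength=0.3, guidance_scale=4.0).images[0]
+
+make_image_grid([init_image, creative_image, conservative_image], rows=1, cols=3)
+```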
+*(figure: guidance_scale = 0.1 | guidance_scale = 5.0 | guidance_scale = 10.0)*
+ +### Negative prompt + +A negative prompt conditions the model to *not* include things in an image, and it can be used to improve image quality or modify an image. For example, you can improve image quality by including negative prompts like "poor details" or "blurry" to encourage the model to generate a higher quality image. Or you can modify an image by specifying things to exclude from an image. + +```py +import torch +from diffusers import AutoPipelineForImage2Image +from diffusers.utils import make_image_grid, load_image + +pipeline = AutoPipelineForImage2Image.from_pretrained( + "stabilityai/stable-diffusion-xl-refiner-1.0", torch_dtype=torch.float16, variant="fp16", use_safetensors=True +) +pipeline.enable_model_cpu_offload() +# remove following line if xFormers is not installed or you have PyTorch 2.0 or higher installed +pipeline.enable_xformers_memory_efficient_attention() + +# prepare image +url = "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/img2img-init.png" +init_image = load_image(url) + +prompt = "Astronaut in a jungle, cold color palette, muted colors, detailed, 8k" +negative_prompt = "ugly, deformed, disfigured, poor details, bad anatomy" + +# pass prompt and image to pipeline +image = pipeline(prompt, negative_prompt=negative_prompt, image=init_image).images[0] +make_image_grid([init_image, image], rows=1, cols=2) +``` + +
+*(figure: negative_prompt = "ugly, deformed, disfigured, poor details, bad anatomy" | negative_prompt = "jungle")*
+ +## Chained image-to-image pipelines + +There are some other interesting ways you can use an image-to-image pipeline aside from just generating an image (although that is pretty cool too). You can take it a step further and chain it with other pipelines. + +### Text-to-image-to-image + +Chaining a text-to-image and image-to-image pipeline allows you to generate an image from text and use the generated image as the initial image for the image-to-image pipeline. This is useful if you want to generate an image entirely from scratch. For example, let's chain a Stable Diffusion and a Kandinsky model. + +Start by generating an image with the text-to-image pipeline: + +```py +from diffusers import AutoPipelineForText2Image, AutoPipelineForImage2Image +import torch +from diffusers.utils import make_image_grid + +pipeline = AutoPipelineForText2Image.from_pretrained( + "stable-diffusion-v1-5/stable-diffusion-v1-5", torch_dtype=torch.float16, variant="fp16", use_safetensors=True +) +pipeline.enable_model_cpu_offload() +# remove following line if xFormers is not installed or you have PyTorch 2.0 or higher installed +pipeline.enable_xformers_memory_efficient_attention() + +text2image = pipeline("Astronaut in a jungle, cold color palette, muted colors, detailed, 8k").images[0] +text2image +``` + +Now you can pass this generated image to the image-to-image pipeline: + +```py +pipeline = AutoPipelineForImage2Image.from_pretrained( + "kandinsky-community/kandinsky-2-2-decoder", torch_dtype=torch.float16, use_safetensors=True +) +pipeline.enable_model_cpu_offload() +# remove following line if xFormers is not installed or you have PyTorch 2.0 or higher installed +pipeline.enable_xformers_memory_efficient_attention() + +image2image = pipeline("Astronaut in a jungle, cold color palette, muted colors, detailed, 8k", image=text2image).images[0] +make_image_grid([text2image, image2image], rows=1, cols=2) +``` + +### Image-to-image-to-image + +You can also chain multiple image-to-image pipelines together to create more interesting images. This can be useful for iteratively performing style transfer on an image, generating short GIFs, restoring color to an image, or restoring missing areas of an image. + +Start by generating an image: + +```py +import torch +from diffusers import AutoPipelineForImage2Image +from diffusers.utils import make_image_grid, load_image + +pipeline = AutoPipelineForImage2Image.from_pretrained( + "stable-diffusion-v1-5/stable-diffusion-v1-5", torch_dtype=torch.float16, variant="fp16", use_safetensors=True +) +pipeline.enable_model_cpu_offload() +# remove following line if xFormers is not installed or you have PyTorch 2.0 or higher installed +pipeline.enable_xformers_memory_efficient_attention() + +# prepare image +url = "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/img2img-init.png" +init_image = load_image(url) + +prompt = "Astronaut in a jungle, cold color palette, muted colors, detailed, 8k" + +# pass prompt and image to pipeline +image = pipeline(prompt, image=init_image, output_type="latent").images[0] +``` + + + +It is important to specify `output_type="latent"` in the pipeline to keep all the outputs in latent space to avoid an unnecessary decode-encode step. This only works if the chained pipelines are using the same VAE. 
+ + + +Pass the latent output from this pipeline to the next pipeline to generate an image in a [comic book art style](https://huggingface.co/ogkalu/Comic-Diffusion): + +```py +pipeline = AutoPipelineForImage2Image.from_pretrained( + "ogkalu/Comic-Diffusion", torch_dtype=torch.float16 +) +pipeline.enable_model_cpu_offload() +# remove following line if xFormers is not installed or you have PyTorch 2.0 or higher installed +pipeline.enable_xformers_memory_efficient_attention() + +# need to include the token "charliebo artstyle" in the prompt to use this checkpoint +image = pipeline("Astronaut in a jungle, charliebo artstyle", image=image, output_type="latent").images[0] +``` + +Repeat one more time to generate the final image in a [pixel art style](https://huggingface.co/kohbanye/pixel-art-style): + +```py +pipeline = AutoPipelineForImage2Image.from_pretrained( + "kohbanye/pixel-art-style", torch_dtype=torch.float16 +) +pipeline.enable_model_cpu_offload() +# remove following line if xFormers is not installed or you have PyTorch 2.0 or higher installed +pipeline.enable_xformers_memory_efficient_attention() + +# need to include the token "pixelartstyle" in the prompt to use this checkpoint +image = pipeline("Astronaut in a jungle, pixelartstyle", image=image).images[0] +make_image_grid([init_image, image], rows=1, cols=2) +``` + +### Image-to-upscaler-to-super-resolution + +Another way you can chain your image-to-image pipeline is with an upscaler and super-resolution pipeline to really increase the level of details in an image. + +Start with an image-to-image pipeline: + +```py +import torch +from diffusers import AutoPipelineForImage2Image +from diffusers.utils import make_image_grid, load_image + +pipeline = AutoPipelineForImage2Image.from_pretrained( + "stable-diffusion-v1-5/stable-diffusion-v1-5", torch_dtype=torch.float16, variant="fp16", use_safetensors=True +) +pipeline.enable_model_cpu_offload() +# remove following line if xFormers is not installed or you have PyTorch 2.0 or higher installed +pipeline.enable_xformers_memory_efficient_attention() + +# prepare image +url = "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/img2img-init.png" +init_image = load_image(url) + +prompt = "Astronaut in a jungle, cold color palette, muted colors, detailed, 8k" + +# pass prompt and image to pipeline +image_1 = pipeline(prompt, image=init_image, output_type="latent").images[0] +``` + + + +It is important to specify `output_type="latent"` in the pipeline to keep all the outputs in *latent* space to avoid an unnecessary decode-encode step. This only works if the chained pipelines are using the same VAE. 
+ + + +Chain it to an upscaler pipeline to increase the image resolution: + +```py +from diffusers import StableDiffusionLatentUpscalePipeline + +upscaler = StableDiffusionLatentUpscalePipeline.from_pretrained( + "stabilityai/sd-x2-latent-upscaler", torch_dtype=torch.float16, variant="fp16", use_safetensors=True +) +upscaler.enable_model_cpu_offload() +upscaler.enable_xformers_memory_efficient_attention() + +image_2 = upscaler(prompt, image=image_1, output_type="latent").images[0] +``` + +Finally, chain it to a super-resolution pipeline to further enhance the resolution: + +```py +from diffusers import StableDiffusionUpscalePipeline + +super_res = StableDiffusionUpscalePipeline.from_pretrained( + "stabilityai/stable-diffusion-x4-upscaler", torch_dtype=torch.float16, variant="fp16", use_safetensors=True +) +super_res.enable_model_cpu_offload() +super_res.enable_xformers_memory_efficient_attention() + +image_3 = super_res(prompt, image=image_2).images[0] +make_image_grid([init_image, image_3.resize((512, 512))], rows=1, cols=2) +``` + +## Control image generation + +Trying to generate an image that looks exactly the way you want can be difficult, which is why controlled generation techniques and models are so useful. While you can use the `negative_prompt` to partially control image generation, there are more robust methods like prompt weighting and ControlNets. + +### Prompt weighting + +Prompt weighting allows you to scale the representation of each concept in a prompt. For example, in a prompt like "Astronaut in a jungle, cold color palette, muted colors, detailed, 8k", you can choose to increase or decrease the embeddings of "astronaut" and "jungle". The [Compel](https://github.com/damian0815/compel) library provides a simple syntax for adjusting prompt weights and generating the embeddings. You can learn how to create the embeddings in the [Prompt weighting](weighted_prompts) guide. + +[`AutoPipelineForImage2Image`] has a `prompt_embeds` (and `negative_prompt_embeds` if you're using a negative prompt) parameter where you can pass the embeddings which replaces the `prompt` parameter. + +```py +from diffusers import AutoPipelineForImage2Image +import torch + +pipeline = AutoPipelineForImage2Image.from_pretrained( + "stable-diffusion-v1-5/stable-diffusion-v1-5", torch_dtype=torch.float16, variant="fp16", use_safetensors=True +) +pipeline.enable_model_cpu_offload() +# remove following line if xFormers is not installed or you have PyTorch 2.0 or higher installed +pipeline.enable_xformers_memory_efficient_attention() + +image = pipeline(prompt_embeds=prompt_embeds, # generated from Compel + negative_prompt_embeds=negative_prompt_embeds, # generated from Compel + image=init_image, +).images[0] +``` + +### ControlNet + +ControlNets provide a more flexible and accurate way to control image generation because you can use an additional conditioning image. The conditioning image can be a canny image, depth map, image segmentation, and even scribbles! Whatever type of conditioning image you choose, the ControlNet generates an image that preserves the information in it. + +For example, let's condition an image with a depth map to keep the spatial information in the image. 
+ +```py +from diffusers.utils import load_image, make_image_grid + +# prepare image +url = "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/img2img-init.png" +init_image = load_image(url) +init_image = init_image.resize((958, 960)) # resize to depth image dimensions +depth_image = load_image("https://huggingface.co/lllyasviel/control_v11f1p_sd15_depth/resolve/main/images/control.png") +make_image_grid([init_image, depth_image], rows=1, cols=2) +``` + +Load a ControlNet model conditioned on depth maps and the [`AutoPipelineForImage2Image`]: + +```py +from diffusers import ControlNetModel, AutoPipelineForImage2Image +import torch + +controlnet = ControlNetModel.from_pretrained("lllyasviel/control_v11f1p_sd15_depth", torch_dtype=torch.float16, variant="fp16", use_safetensors=True) +pipeline = AutoPipelineForImage2Image.from_pretrained( + "stable-diffusion-v1-5/stable-diffusion-v1-5", controlnet=controlnet, torch_dtype=torch.float16, variant="fp16", use_safetensors=True +) +pipeline.enable_model_cpu_offload() +# remove following line if xFormers is not installed or you have PyTorch 2.0 or higher installed +pipeline.enable_xformers_memory_efficient_attention() +``` + +Now generate a new image conditioned on the depth map, initial image, and prompt: + +```py +prompt = "Astronaut in a jungle, cold color palette, muted colors, detailed, 8k" +image_control_net = pipeline(prompt, image=init_image, control_image=depth_image).images[0] +make_image_grid([init_image, depth_image, image_control_net], rows=1, cols=3) +``` + +
+*(figure: initial image | depth image | ControlNet image)*
+ +Let's apply a new [style](https://huggingface.co/nitrosocke/elden-ring-diffusion) to the image generated from the ControlNet by chaining it with an image-to-image pipeline: + +```py +pipeline = AutoPipelineForImage2Image.from_pretrained( + "nitrosocke/elden-ring-diffusion", torch_dtype=torch.float16, +) +pipeline.enable_model_cpu_offload() +# remove following line if xFormers is not installed or you have PyTorch 2.0 or higher installed +pipeline.enable_xformers_memory_efficient_attention() + +prompt = "elden ring style astronaut in a jungle" # include the token "elden ring style" in the prompt +negative_prompt = "ugly, deformed, disfigured, poor details, bad anatomy" + +image_elden_ring = pipeline(prompt, negative_prompt=negative_prompt, image=image_control_net, strength=0.45, guidance_scale=10.5).images[0] +make_image_grid([init_image, depth_image, image_control_net, image_elden_ring], rows=2, cols=2) +``` + +
+*(figure: Elden Ring style result)*
+ +## Optimize + +Running diffusion models is computationally expensive and intensive, but with a few optimization tricks, it is entirely possible to run them on consumer and free-tier GPUs. For example, you can use a more memory-efficient form of attention such as PyTorch 2.0's [scaled-dot product attention](../optimization/torch2.0#scaled-dot-product-attention) or [xFormers](../optimization/xformers) (you can use one or the other, but there's no need to use both). You can also offload the model to the GPU while the other pipeline components wait on the CPU. + +```diff ++ pipeline.enable_model_cpu_offload() ++ pipeline.enable_xformers_memory_efficient_attention() +``` + +With [`torch.compile`](../optimization/torch2.0#torchcompile), you can boost your inference speed even more by wrapping your UNet with it: + +```py +pipeline.unet = torch.compile(pipeline.unet, mode="reduce-overhead", fullgraph=True) +``` + +To learn more, take a look at the [Reduce memory usage](../optimization/memory) and [Torch 2.0](../optimization/torch2.0) guides. diff --git a/docs/source/en/using-diffusers/omnigen.md b/docs/source/en/using-diffusers/omnigen.md new file mode 100644 index 000000000000..5b5e7119a94c --- /dev/null +++ b/docs/source/en/using-diffusers/omnigen.md @@ -0,0 +1,70 @@ + +# OmniGen + +OmniGen is an image generation model. Unlike existing text-to-image models, OmniGen is designed to handle a variety of tasks (e.g., text-to-image, image editing, controllable generation) within a single model. It has the following features: +- Minimalist model architecture, consisting of only a VAE and a transformer module, for joint modeling of text and images. +- Support for multimodal inputs. It can process any text-image mixed data as instructions for image generation, rather than relying solely on text. + +This guide will walk you through using OmniGen for various tasks and use cases. + +## Load model checkpoints +Model weights may be stored in separate subfolders on the Hub or locally, in which case, you should use the [`~DiffusionPipeline.from_pretrained`] method. + +```py +import torch +from diffusers import OmniGenPipeline +pipe = OmniGenPipeline.from_pretrained( + "Shitao/OmniGen-v1-diffusers", + torch_dtype=torch.bfloat16 +) +``` + + +## Text-to-Image + + +## Text-to-image + +For text-to-image, pass a text prompt. By default, OmniGen generates a 1024x1024 image. +You can try setting the `height` and `width` parameters to generate images with different size. + +```py +import torch +from diffusers import OmniGenPipeline +pipe = OmniGenPipeline.from_pretrained( + "Shitao/OmniGen-v1-diffusers", + torch_dtype=torch.bfloat16 +) + +prompt = "An elderly gentleman, with a serene expression, sits at the water's edge, a steaming cup of tea by his side. He is engrossed in his artwork, brush in hand, as he renders an oil painting on a canvas that's propped up against a small, weathered table. The sea breeze whispers through his silver hair, gently billowing his loose-fitting white shirt, while the salty air adds an intangible element to his masterpiece in progress. The scene is one of tranquility and inspiration, with the artist's canvas capturing the vibrant hues of the setting sun reflecting off the tranquil sea." +pipe.enable_model_cpu_offload() + +image = pipe( + prompt=prompt, + generator=torch.Generator(device="cuda").manual_seed(42), +).images[0] + +image +``` +
+*(figure: generated image)*
+For text-to-image, pass a text prompt. By default, CogVideoX generates a 720x480 video for the best results. + + + +## Optimization + +### inference speed + +### Memory \ No newline at end of file From d9f80fcdbdda912d21873b7117d4abaffaa3cb7c Mon Sep 17 00:00:00 2001 From: shitao <2906698981@qq.com> Date: Thu, 5 Dec 2024 22:41:14 +0800 Subject: [PATCH 09/55] update docs --- docs/source/en/_toctree.yml | 8 ++++++++ .../en/api/models/omnigen_transformer.md | 19 +++++++++++++++++++ docs/source/en/using-diffusers/omnigen.md | 2 +- src/diffusers/models/transformers/__init__.py | 2 +- .../transformers/transformer_omnigen.py | 9 ++------- .../pipelines/lumina/pipeline_lumina.py | 2 +- 6 files changed, 32 insertions(+), 10 deletions(-) create mode 100644 docs/source/en/api/models/omnigen_transformer.md diff --git a/docs/source/en/_toctree.yml b/docs/source/en/_toctree.yml index 2faabfec30ce..c181f0bd82a0 100644 --- a/docs/source/en/_toctree.yml +++ b/docs/source/en/_toctree.yml @@ -51,6 +51,8 @@ title: Text or image-to-video - local: using-diffusers/depth2img title: Depth-to-image + - local: using-diffusers/multimodal2img + title: Multimodal-to-image title: Generative tasks - sections: - local: using-diffusers/overview_techniques @@ -87,6 +89,8 @@ title: Kandinsky - local: using-diffusers/ip_adapter title: IP-Adapter + - local: using-diffusers/omnigen + title: OmniGen - local: using-diffusers/pag title: PAG - local: using-diffusers/controlnet @@ -274,6 +278,8 @@ title: LuminaNextDiT2DModel - local: api/models/mochi_transformer3d title: MochiTransformer3DModel + - local: api/models/omnigen_transformer + title: OmniGenTransformer2DModel - local: api/models/pixart_transformer2d title: PixArtTransformer2DModel - local: api/models/prior_transformer @@ -412,6 +418,8 @@ title: MultiDiffusion - local: api/pipelines/musicldm title: MusicLDM + - local: api/pipelines/omnigen + title: OmniGen - local: api/pipelines/pag title: PAG - local: api/pipelines/paint_by_example diff --git a/docs/source/en/api/models/omnigen_transformer.md b/docs/source/en/api/models/omnigen_transformer.md new file mode 100644 index 000000000000..d2df6c55e68b --- /dev/null +++ b/docs/source/en/api/models/omnigen_transformer.md @@ -0,0 +1,19 @@ + + +# OmniGenTransformer2DModel + +A Transformer model accept multi-modal instruction to generate image from [OmniGen](https://github.com/VectorSpaceLab/OmniGen/). + +## OmniGenTransformer2DModel + +[[autodoc]] OmniGenTransformer2DModel diff --git a/docs/source/en/using-diffusers/omnigen.md b/docs/source/en/using-diffusers/omnigen.md index 5b5e7119a94c..ff398eb4f88e 100644 --- a/docs/source/en/using-diffusers/omnigen.md +++ b/docs/source/en/using-diffusers/omnigen.md @@ -46,7 +46,7 @@ pipe = OmniGenPipeline.from_pretrained( torch_dtype=torch.bfloat16 ) -prompt = "An elderly gentleman, with a serene expression, sits at the water's edge, a steaming cup of tea by his side. He is engrossed in his artwork, brush in hand, as he renders an oil painting on a canvas that's propped up against a small, weathered table. The sea breeze whispers through his silver hair, gently billowing his loose-fitting white shirt, while the salty air adds an intangible element to his masterpiece in progress. The scene is one of tranquility and inspiration, with the artist's canvas capturing the vibrant hues of the setting sun reflecting off the tranquil sea." +prompt = "A young woman sits on a sofa, holding a book and facing the camera. 
She wears delicate silver hoop earrings adorned with tiny, sparkling diamonds that catch the light, with her long chestnut hair cascading over her shoulders. Her eyes are focused and gentle, framed by long, dark lashes. She is dressed in a cozy cream sweater, which complements her warm, inviting smile. Behind her, there is a table with a cup of water in a sleek, minimalist blue mug. The background is a serene indoor setting with soft natural light filtering through a window, adorned with tasteful art and flowers, creating a cozy and peaceful ambiance. 4K, HD." pipe.enable_model_cpu_offload() image = pipe( diff --git a/src/diffusers/models/transformers/__init__.py b/src/diffusers/models/transformers/__init__.py index 9770ded5e31e..1f16146fce22 100644 --- a/src/diffusers/models/transformers/__init__.py +++ b/src/diffusers/models/transformers/__init__.py @@ -18,6 +18,6 @@ from .transformer_cogview3plus import CogView3PlusTransformer2DModel from .transformer_flux import FluxTransformer2DModel from .transformer_mochi import MochiTransformer3DModel + from .transformer_omnigen import OmniGenTransformer2DModel from .transformer_sd3 import SD3Transformer2DModel from .transformer_temporal import TransformerTemporalModel - from .transformer_omnigen import OmniGenTransformer2DModel diff --git a/src/diffusers/models/transformers/transformer_omnigen.py b/src/diffusers/models/transformers/transformer_omnigen.py index 64790a2024e5..5095516e57ab 100644 --- a/src/diffusers/models/transformers/transformer_omnigen.py +++ b/src/diffusers/models/transformers/transformer_omnigen.py @@ -225,15 +225,10 @@ class OmniGenTransformer2DModel(ModelMixin, ConfigMixin, PeftAdapterMixin): Reference: https://arxiv.org/pdf/2409.11340 Parameters: + transformer_config (`dict`): config for transformer layers. OmniGen-v1 use Phi3 as transformer backbone patch_size (`int`, defaults to 2): Patch size to turn the input data into small patches. in_channels (`int`, *optional*, defaults to 4): The number of channels in the input. - num_layers (`int`, *optional*, defaults to 18): The number of layers of MMDiT blocks to use. - num_single_layers (`int`, *optional*, defaults to 18): The number of layers of single DiT blocks to use. - attention_head_dim (`int`, *optional*, defaults to 64): The number of channels in each head. - num_attention_heads (`int`, *optional*, defaults to 18): The number of heads to use for multi-head attention. - joint_attention_dim (`int`, *optional*): The number of `encoder_hidden_states` dimensions to use. - pooled_projection_dim (`int`): Number of dimensions to use when projecting the `pooled_projections`. - guidance_embeds (`bool`, defaults to False): Whether to use guidance embeddings. + pos_embed_max_size (`int`, *optional*, defaults to 192): The max size of pos emb. 
""" _supports_gradient_checkpointing = True diff --git a/src/diffusers/pipelines/lumina/pipeline_lumina.py b/src/diffusers/pipelines/lumina/pipeline_lumina.py index 296ae5303b20..1db3d58808d2 100644 --- a/src/diffusers/pipelines/lumina/pipeline_lumina.py +++ b/src/diffusers/pipelines/lumina/pipeline_lumina.py @@ -25,7 +25,7 @@ from ...image_processor import VaeImageProcessor from ...models import AutoencoderKL from ...models.embeddings import get_2d_rotary_pos_embed_lumina -from ...models.transformers.lumina_nextdit2d import LuminaextDiT2DModel +from ...models.transformers.lumina_nextdit2d import LuminaNextDiT2DModel from ...schedulers import FlowMatchEulerDiscreteScheduler from ...utils import ( BACKENDS_MAPPING, From c78d1f4b490950bcf940df722b76e5c541f3a587 Mon Sep 17 00:00:00 2001 From: shitao <2906698981@qq.com> Date: Fri, 6 Dec 2024 18:24:58 +0800 Subject: [PATCH 10/55] offload_transformer --- docs/source/en/using-diffusers/omnigen.md | 278 +++++++++++++++++- .../transformers/transformer_omnigen.py | 4 +- 2 files changed, 270 insertions(+), 12 deletions(-) diff --git a/docs/source/en/using-diffusers/omnigen.md b/docs/source/en/using-diffusers/omnigen.md index ff398eb4f88e..22a14bd95efe 100644 --- a/docs/source/en/using-diffusers/omnigen.md +++ b/docs/source/en/using-diffusers/omnigen.md @@ -15,6 +15,7 @@ OmniGen is an image generation model. Unlike existing text-to-image models, Omni - Minimalist model architecture, consisting of only a VAE and a transformer module, for joint modeling of text and images. - Support for multimodal inputs. It can process any text-image mixed data as instructions for image generation, rather than relying solely on text. +For more information, please refer to the [paper](https://arxiv.org/pdf/2409.11340). This guide will walk you through using OmniGen for various tasks and use cases. ## Load model checkpoints @@ -30,8 +31,6 @@ pipe = OmniGenPipeline.from_pretrained( ``` -## Text-to-Image - ## Text-to-image @@ -41,30 +40,289 @@ You can try setting the `height` and `width` parameters to generate images with ```py import torch from diffusers import OmniGenPipeline + pipe = OmniGenPipeline.from_pretrained( "Shitao/OmniGen-v1-diffusers", torch_dtype=torch.bfloat16 ) +pipe.to("cuda") -prompt = "A young woman sits on a sofa, holding a book and facing the camera. She wears delicate silver hoop earrings adorned with tiny, sparkling diamonds that catch the light, with her long chestnut hair cascading over her shoulders. Her eyes are focused and gentle, framed by long, dark lashes. She is dressed in a cozy cream sweater, which complements her warm, inviting smile. Behind her, there is a table with a cup of water in a sleek, minimalist blue mug. The background is a serene indoor setting with soft natural light filtering through a window, adorned with tasteful art and flowers, creating a cozy and peaceful ambiance. 4K, HD." -pipe.enable_model_cpu_offload() - +prompt = "Realistic photo. A young woman sits on a sofa, holding a book and facing the camera. She wears delicate silver hoop earrings adorned with tiny, sparkling diamonds that catch the light, with her long chestnut hair cascading over her shoulders. Her eyes are focused and gentle, framed by long, dark lashes. She is dressed in a cozy cream sweater, which complements her warm, inviting smile. Behind her, there is a table with a cup of water in a sleek, minimalist blue mug. 
The background is a serene indoor setting with soft natural light filtering through a window, adorned with tasteful art and flowers, creating a cozy and peaceful ambiance. 4K, HD." image = pipe( prompt=prompt, - generator=torch.Generator(device="cuda").manual_seed(42), + height=1024, + width=1024, + guidance_scale=3, + generator=torch.Generator(device="cpu").manual_seed(111), ).images[0] +image +``` +
+ generated image +
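The guide above notes that you can set `height` and `width` to generate images at other sizes. The snippet below is a small sketch (not part of the original guide) requesting a landscape resolution; the assumption is that the values only need to be multiples of the model's spatial compression factor, so adjust them if the pipeline rejects a size.

```py
import torch
from diffusers import OmniGenPipeline

pipe = OmniGenPipeline.from_pretrained(
    "Shitao/OmniGen-v1-diffusers",
    torch_dtype=torch.bfloat16,
)
pipe.to("cuda")

# Ask for a 1344x768 landscape image instead of the default 1024x1024.
image = pipe(
    prompt="A lighthouse on a rocky coast at sunset, photorealistic, golden hour lighting",
    height=768,
    width=1344,
    guidance_scale=3,
    generator=torch.Generator(device="cpu").manual_seed(111),
).images[0]
image.save("lighthouse_1344x768.png")
```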
+
+## Image edit
+
+OmniGen supports multimodal inputs.
+When the input includes an image, you need to add a placeholder `<|image_1|>` in the text prompt to represent the image.
+It is recommended to enable `use_input_image_size_as_output` to keep the edited image the same size as the original image.
+
+```py
+import torch
+from diffusers import OmniGenPipeline
+from diffusers.utils import load_image
+pipe = OmniGenPipeline.from_pretrained(
+    "Shitao/OmniGen-v1-diffusers",
+    torch_dtype=torch.bfloat16
+)
+pipe.to("cuda")
+
+prompt="<|image_1|> Remove the woman's earrings. Replace the mug with a clear glass filled with sparkling iced cola."
+input_images=[load_image("https://raw.githubusercontent.com/VectorSpaceLab/OmniGen/main/imgs/docs_img/t2i_woman_with_book.png")]
+image = pipe(
+    prompt=prompt,
+    input_images=input_images,
+    guidance_scale=2,
+    img_guidance_scale=1.6,
+    use_input_image_size_as_output=True,
+    generator=torch.Generator(device="cpu").manual_seed(222)).images[0]
+image
+```
+
+
+ +
original image
+
+
+ +
edited image
+
+
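Before moving on, here is a hedged sketch of how a prompt can reference more than one input image: the assumption, consistent with the multi-image examples later in this guide, is that `<|image_1|>` maps to `input_images[0]`, `<|image_2|>` to `input_images[1]`, and so on. The instruction below is hypothetical; the image URLs are the ones used elsewhere in this guide.

```py
import torch
from diffusers import OmniGenPipeline
from diffusers.utils import load_image

pipe = OmniGenPipeline.from_pretrained(
    "Shitao/OmniGen-v1-diffusers",
    torch_dtype=torch.bfloat16,
)
pipe.to("cuda")

# <|image_1|> -> input_images[0], <|image_2|> -> input_images[1] (assumed ordering)
first = load_image("https://raw.githubusercontent.com/VectorSpaceLab/OmniGen/main/imgs/docs_img/t2i_woman_with_book.png")
second = load_image("https://raw.githubusercontent.com/VectorSpaceLab/OmniGen/main/imgs/docs_img/edit.png")

prompt = "The woman in <|image_1|> is standing in the room shown in <|image_2|>."  # hypothetical instruction
image = pipe(
    prompt=prompt,
    input_images=[first, second],
    height=1024,
    width=1024,
    guidance_scale=2.5,
    img_guidance_scale=1.6,
    generator=torch.Generator(device="cpu").manual_seed(0),
).images[0]
```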
+ +OmniGen has some interesting features, such as the ability to infer user needs, as shown in the example below. +```py +prompt="If the woman is thirsty, what should she take? Find it in the image and highlight it in blue. <|image_1|>" +input_images=[load_image("https://raw.githubusercontent.com/VectorSpaceLab/OmniGen/main/imgs/docs_img/edit.png")] +image = pipe( + prompt=prompt, + input_images=input_images, + guidance_scale=2, + img_guidance_scale=1.6, + use_input_image_size_as_output=True, + generator=torch.Generator(device="cpu").manual_seed(0)).images[0] image ```
- generated image of an astronaut in a jungle + generated image
-For text-to-image, pass a text prompt. By default, CogVideoX generates a 720x480 video for the best results. +## Controllable generation + + OmniGen can handle several classic computer vision tasks. + As shown below, OmniGen can detect human skeletons in input images, which can be used as control conditions to generate new images. + +```py +import torch +from diffusers import OmniGenPipeline +from diffusers.utils import load_image + +pipe = OmniGenPipeline.from_pretrained( + "Shitao/OmniGen-v1-diffusers", + torch_dtype=torch.bfloat16 +) +pipe.to("cuda") + +prompt="Detect the skeleton of human in this image: <|image_1|>" +input_images=[load_image("https://raw.githubusercontent.com/VectorSpaceLab/OmniGen/main/imgs/docs_img/edit.png")] +image1 = pipe( + prompt=prompt, + input_images=input_images, + guidance_scale=2, + img_guidance_scale=1.6, + use_input_image_size_as_output=True, + generator=torch.Generator(device="cpu").manual_seed(333)).images[0] +image1 + +prompt="Generate a new photo using the following picture and text as conditions: <|image_1|>\n A young boy is sitting on a sofa in the library, holding a book. His hair is neatly combed, and a faint smile plays on his lips, with a few freckles scattered across his cheeks. The library is quiet, with rows of shelves filled with books stretching out behind him." +input_images=[load_image("https://raw.githubusercontent.com/VectorSpaceLab/OmniGen/main/imgs/docs_img/skeletal.png")] +image2 = pipe( + prompt=prompt, + input_images=input_images, + guidance_scale=2, + img_guidance_scale=1.6, + use_input_image_size_as_output=True, + generator=torch.Generator(device="cpu").manual_seed(333)).images[0] +image2 +``` + +
+
+ +
original image
+
+
+ +
detected skeleton
+
+
+ +
skeleton to image
+
+
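The two calls above pass the detected skeleton through an intermediate file. Because the pipeline returns a PIL image, the steps can also be chained directly; this is a sketch under the assumption that a generated image is accepted anywhere a `load_image` result is.

```py
import torch
from diffusers import OmniGenPipeline
from diffusers.utils import load_image

pipe = OmniGenPipeline.from_pretrained(
    "Shitao/OmniGen-v1-diffusers",
    torch_dtype=torch.bfloat16,
)
pipe.to("cuda")

source = load_image("https://raw.githubusercontent.com/VectorSpaceLab/OmniGen/main/imgs/docs_img/edit.png")

# Step 1: detect the skeleton of the person in the source image.
skeleton = pipe(
    prompt="Detect the skeleton of human in this image: <|image_1|>",
    input_images=[source],
    guidance_scale=2,
    img_guidance_scale=1.6,
    use_input_image_size_as_output=True,
    generator=torch.Generator(device="cpu").manual_seed(333),
).images[0]

# Step 2: feed the detected skeleton back in as the control condition.
new_photo = pipe(
    prompt="Generate a new photo using the following picture and text as conditions: <|image_1|>\n A young boy is sitting on a sofa in the library, holding a book.",
    input_images=[skeleton],
    guidance_scale=2,
    img_guidance_scale=1.6,
    use_input_image_size_as_output=True,
    generator=torch.Generator(device="cpu").manual_seed(333),
).images[0]
```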
+ + +OmniGen can also directly use relevant information from input images to generate new images. +```py +import torch +from diffusers import OmniGenPipeline +from diffusers.utils import load_image + +pipe = OmniGenPipeline.from_pretrained( + "Shitao/OmniGen-v1-diffusers", + torch_dtype=torch.bfloat16 +) +pipe.to("cuda") + +prompt="Following the pose of this image <|image_1|>, generate a new photo: A young boy is sitting on a sofa in the library, holding a book. His hair is neatly combed, and a faint smile plays on his lips, with a few freckles scattered across his cheeks. The library is quiet, with rows of shelves filled with books stretching out behind him." +input_images=[load_image("https://raw.githubusercontent.com/VectorSpaceLab/OmniGen/main/imgs/docs_img/edit.png")] +image = pipe( + prompt=prompt, + input_images=input_images, + guidance_scale=2, + img_guidance_scale=1.6, + use_input_image_size_as_output=True, + generator=torch.Generator(device="cpu").manual_seed(0)).images[0] +image +``` +
+
+ +
generated image
+
+
+ + +## ID and object preserving + +OmniGen can generate multiple images based on the people and objects in the input image and supports inputting multiple images simultaneously. +Additionally, OmniGen can extract desired objects from an image containing multiple objects based on instructions. + +```py +import torch +from diffusers import OmniGenPipeline +from diffusers.utils import load_image + +pipe = OmniGenPipeline.from_pretrained( + "Shitao/OmniGen-v1-diffusers", + torch_dtype=torch.bfloat16 +) +pipe.to("cuda") + +prompt="A man and a woman are sitting at a classroom desk. The man is the man with yellow hair in <|image_1|>. The woman is the woman on the left of <|image_2|>" +input_image_1 = load_image("https://raw.githubusercontent.com/VectorSpaceLab/OmniGen/main/imgs/docs_img/3.jpg") +input_image_2 = load_image("https://raw.githubusercontent.com/VectorSpaceLab/OmniGen/main/imgs/docs_img/4.jpg") +input_images=[input_image_1, input_image_2] +image = pipe( + prompt=prompt, + input_images=input_images, + height=1024, + width=1024, + guidance_scale=2.5, + img_guidance_scale=1.6, + generator=torch.Generator(device="cpu").manual_seed(666)).images[0] +image +``` +
+
+ +
input_image_1
+
+
+ +
input_image_2
+
+
+ +
generated image
+
+
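The next snippet applies the same idea to a person image and a clothing image. Note that, as written, it loads both inputs from local file paths, so point `load_image` at files or URLs you actually have; the file names below are placeholders.

```py
from diffusers.utils import load_image

# Placeholder file names; substitute any person photo and any clothing photo you have access to.
input_image_1 = load_image("person.jpeg")
input_image_2 = load_image("clothes.jpg")
```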
+ + +```py +import torch +from diffusers import OmniGenPipeline +from diffusers.utils import load_image + +pipe = OmniGenPipeline.from_pretrained( + "Shitao/OmniGen-v1-diffusers", + torch_dtype=torch.bfloat16 +) +pipe.to("cuda") + + +prompt="A woman is walking down the street, wearing a white long-sleeve blouse with lace details on the sleeves, paired with a blue pleated skirt. The woman is <|image_1|>. The long-sleeve blouse and a pleated skirt are <|image_2|>." +input_image_1 = load_image("/share/junjie/code/VISTA2/produce_data/laion_net/diffgpt/OmniGen/docs_img/emma.jpeg") +input_image_2 = load_image("/share/junjie/code/VISTA2/produce_data/laion_net/diffgpt/OmniGen/docs_img/dress.jpg") +input_images=[input_image_1, input_image_2] +image = pipe( + prompt=prompt, + input_images=input_images, + height=1024, + width=1024, + guidance_scale=2.5, + img_guidance_scale=1.6, + generator=torch.Generator(device="cpu").manual_seed(666)).images[0] +``` + +
+
+ +
person image
+
+
+ +
clothe image
+
+
+ +
generated image
+
+
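The optimization notes that follow list several knobs for multi-image inputs. The sketch below shows how they could be combined in one call; the parameter names (`use_kv_cache`, `max_input_image_size`, `offload_transformer_block`) are taken from this PR's pipeline as described in the notes, but treat the exact combination and values as assumptions rather than a tuned recipe.

```py
import torch
from diffusers import OmniGenPipeline
from diffusers.utils import load_image

pipe = OmniGenPipeline.from_pretrained(
    "Shitao/OmniGen-v1-diffusers",
    torch_dtype=torch.bfloat16,
)
pipe.to("cuda")
# When VRAM is tight, one of the offloading hooks discussed below can replace pipe.to("cuda"):
# pipe.enable_model_cpu_offload() or pipe.enable_sequential_cpu_offload()

input_images = [
    load_image("https://raw.githubusercontent.com/VectorSpaceLab/OmniGen/main/imgs/docs_img/3.jpg"),
    load_image("https://raw.githubusercontent.com/VectorSpaceLab/OmniGen/main/imgs/docs_img/4.jpg"),
]
prompt = "A man and a woman are sitting at a classroom desk. The man is the man with yellow hair in <|image_1|>. The woman is the woman on the left of <|image_2|>"

image = pipe(
    prompt=prompt,
    input_images=input_images,
    height=1024,
    width=1024,
    guidance_scale=2.5,
    img_guidance_scale=1.6,
    use_kv_cache=True,          # cache key/value states of the condition tokens across steps
    max_input_image_size=512,   # smaller inputs -> fewer image tokens and faster steps
    # offload_transformer_block=True,  # uncomment to stream transformer blocks from the CPU
    generator=torch.Generator(device="cpu").manual_seed(666),
).images[0]
```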
+ + +## Optimization when inputting multiple images + +For text-to-image task, OmniGen requires minimal memory and time costs (9G memory and 31s for a 1024*1024 image on A800 GPU). +However, when using input images, the computational cost increases. + +Here are some guidelines to help you reduce computational costs when input multiple images. The experiments are conducted on A800 GPU and input two images to OmniGen. -## Optimization ### inference speed -### Memory \ No newline at end of file +- `use_kv_cache=True`: + `use_kv_cache` will store key and value states of the input conditions to compute attention without redundant computations. + The default value is True, and OmniGen will offload the kv cache to cpu default. + - `use_kv_cache=False`: the inference time is 3m21s. + - `use_kv_cache=True`: the inference time is 1m30s. + +- `max_input_image_size`: + the maximum size of input image, which will be used to crop the input image + - `max_input_image_size=1024`: the inference time is 1m30s. + - `max_input_image_size=512`: the inference time is 58s. + +### Memory + +- `pipe.enable_model_cpu_offload()`: + - Without enabling cpu offloading, memory usage is `31 GB` + - With enabling cpu offloading, memory usage is `28 GB` + +- `offload_transformer_block=True`: + - 17G + +- `pipe.enable_sequential_cpu_offload()`: + - 11G + + diff --git a/src/diffusers/models/transformers/transformer_omnigen.py b/src/diffusers/models/transformers/transformer_omnigen.py index 5095516e57ab..eb61dfc3e592 100644 --- a/src/diffusers/models/transformers/transformer_omnigen.py +++ b/src/diffusers/models/transformers/transformer_omnigen.py @@ -174,7 +174,7 @@ def forward( ) else: if offload_transformer_block and not self.training: - if not not torch.cuda.is_available(): + if not torch.cuda.is_available(): logger.warning_once( "We don't detecte any available GPU, so diable `offload_transformer_block`" ) @@ -363,7 +363,7 @@ def get_multimodal_embeddings(self, input_img_inx = 0 if input_img_latents is not None: input_image_tokens = self.patch_embedding(input_img_latents, - is_input_image=True) + is_input_image=True) for b_inx in input_image_sizes.keys(): for start_inx, end_inx in input_image_sizes[b_inx]: From 236f14b703d933092c088095ec03f60e94e8d4b1 Mon Sep 17 00:00:00 2001 From: staoxiao <2906698981@qq.com> Date: Sun, 8 Dec 2024 15:00:28 +0800 Subject: [PATCH 11/55] enable_transformer_block_cpu_offload --- .../transformers/transformer_omnigen.py | 13 +++++------- .../pipelines/omnigen/pipeline_omnigen.py | 20 +++++++++++++++++-- 2 files changed, 23 insertions(+), 10 deletions(-) diff --git a/src/diffusers/models/transformers/transformer_omnigen.py b/src/diffusers/models/transformers/transformer_omnigen.py index eb61dfc3e592..882065faae3d 100644 --- a/src/diffusers/models/transformers/transformer_omnigen.py +++ b/src/diffusers/models/transformers/transformer_omnigen.py @@ -74,21 +74,18 @@ def evict_previous_layer(self, layer_idx: int): prev_layer_idx = layer_idx - 1 for name, param in self.layers[prev_layer_idx].named_parameters(): param.data = param.data.to("cpu", non_blocking=True) - + def get_offload_layer(self, layer_idx: int, device: torch.device): # init stream if not hasattr(self, "prefetch_stream"): self.prefetch_stream = torch.cuda.Stream() # delete previous layer - # main stream sync shouldn't be necessary since all computation on iter i-1 is finished by iter i - # torch.cuda.current_stream().synchronize() - # avoid extra eviction of last layer - if layer_idx > 0: - self.evict_previous_layer(layer_idx) - + 
torch.cuda.current_stream().synchronize() + self.evict_previous_layer(layer_idx) + # make sure the current layer is ready - self.prefetch_stream.synchronize() + torch.cuda.synchronize(self.prefetch_stream) # load next layer self.prefetch_layer((layer_idx + 1) % len(self.layers), device) diff --git a/src/diffusers/pipelines/omnigen/pipeline_omnigen.py b/src/diffusers/pipelines/omnigen/pipeline_omnigen.py index ff10cade50f3..902453aae0f6 100644 --- a/src/diffusers/pipelines/omnigen/pipeline_omnigen.py +++ b/src/diffusers/pipelines/omnigen/pipeline_omnigen.py @@ -327,6 +327,18 @@ def num_timesteps(self): @property def interrupt(self): return self._interrupt + + def enable_transformer_block_cpu_offload(self, device: Union[torch.device, str] = "cuda"): + torch_device = torch.device(device) + for name, param in self.transformer.named_parameters(): + if 'layers' in name and 'layers.0' not in name: + param.data = param.data.cpu() + else: + param.data = param.data.to(torch_device) + for buffer_name, buffer in self.transformer.patch_embedding.named_buffers(): + setattr(self.transformer.patch_embedding, buffer_name, buffer.to(torch_device)) + self.vae.to(torch_device) + self.offload_transformer_block = True @torch.no_grad() @replace_example_docstring(EXAMPLE_DOC_STRING) @@ -440,6 +452,9 @@ def __call__( # using Float32 for the VAE doesn't take up much memory but can prevent potential black image outputs. self.vae.to(torch.float32) + if offload_transformer_block: + self.enable_transformer_block_cpu_offload() + # 1. Check inputs. Raise error if not correct self.check_inputs( prompt, @@ -460,9 +475,10 @@ def __call__( batch_size = len(prompt) device = self._execution_device + # 3. process multi-modal instructions if max_input_image_size != self.multimodal_processor.max_image_size: - self.multimodal_processor = OmniGenMultiModalProcessor(self.text_tokenizer, max_image_size=max_input_image_size) + self.multimodal_processor = OmniGenMultiModalProcessor(self.tokenizer, max_image_size=max_input_image_size) processed_data = self.multimodal_processor(prompt, input_images, height=height, @@ -521,7 +537,7 @@ def __call__( position_ids=processed_data['position_ids'], attention_kwargs=attention_kwargs, past_key_values=cache, - offload_transformer_block=offload_transformer_block, + offload_transformer_block=self.offload_transformer_block if hasattr(self, 'offload_transformer_block') else offload_transformer_block, return_dict=False, ) From 6b52547ebcf83ea564fd59801c80affe5b53d944 Mon Sep 17 00:00:00 2001 From: shitao <2906698981@qq.com> Date: Sun, 8 Dec 2024 16:06:38 +0800 Subject: [PATCH 12/55] update docs --- docs/source/en/api/pipelines/omnigen.md | 66 +- .../en/using-diffusers/multimodal2img.md | 612 ++---------------- docs/source/en/using-diffusers/omnigen.md | 7 +- .../transformers/transformer_omnigen.py | 58 +- .../pipelines/omnigen/kvcache_omnigen.py | 54 +- .../pipelines/omnigen/pipeline_omnigen.py | 155 +++-- .../pipelines/omnigen/processor_omnigen.py | 16 +- 7 files changed, 268 insertions(+), 700 deletions(-) diff --git a/docs/source/en/api/pipelines/omnigen.md b/docs/source/en/api/pipelines/omnigen.md index fe0bc5d496a0..b7ea5488156a 100644 --- a/docs/source/en/api/pipelines/omnigen.md +++ b/docs/source/en/api/pipelines/omnigen.md @@ -56,44 +56,50 @@ First, load the pipeline: ```python import torch -from diffusers import CogVideoXPipeline, CogVideoXImageToVideoPipeline -from diffusers.utils import export_to_video,load_image -pipe = CogVideoXPipeline.from_pretrained("THUDM/CogVideoX-5b").to("cuda") # or 
"THUDM/CogVideoX-2b" +from diffusers import OmniGenPipeline +pipe = OmniGenPipeline.from_pretrained( + "Shitao/OmniGen-v1-diffusers", + torch_dtype=torch.bfloat16 +) +pipe.to("cuda") ``` -If you are using the image-to-video pipeline, load it as follows: - -```python -pipe = CogVideoXImageToVideoPipeline.from_pretrained("THUDM/CogVideoX-5b-I2V").to("cuda") +For text-to-image, pass a text prompt. By default, OmniGen generates a 1024x1024 image. +You can try setting the `height` and `width` parameters to generate images with different size. + +```py +prompt = "Realistic photo. A young woman sits on a sofa, holding a book and facing the camera. She wears delicate silver hoop earrings adorned with tiny, sparkling diamonds that catch the light, with her long chestnut hair cascading over her shoulders. Her eyes are focused and gentle, framed by long, dark lashes. She is dressed in a cozy cream sweater, which complements her warm, inviting smile. Behind her, there is a table with a cup of water in a sleek, minimalist blue mug. The background is a serene indoor setting with soft natural light filtering through a window, adorned with tasteful art and flowers, creating a cozy and peaceful ambiance. 4K, HD." +image = pipe( + prompt=prompt, + height=1024, + width=1024, + guidance_scale=3, + generator=torch.Generator(device="cpu").manual_seed(111), +).images[0] +image ``` -Then change the memory layout of the pipelines `transformer` component to `torch.channels_last`: - -```python -pipe.transformer.to(memory_format=torch.channels_last) -``` - -Compile the components and run inference: - -```python -pipe.transformer = torch.compile(pipeline.transformer, mode="max-autotune", fullgraph=True) - -# CogVideoX works well with long and well-described prompts -prompt = "A panda, dressed in a small, red jacket and a tiny hat, sits on a wooden stool in a serene bamboo forest. The panda's fluffy paws strum a miniature acoustic guitar, producing soft, melodic tunes. Nearby, a few other pandas gather, watching curiously and some clapping in rhythm. Sunlight filters through the tall bamboo, casting a gentle glow on the scene. The panda's face is expressive, showing concentration and joy as it plays. The background includes a small, flowing stream and vibrant green foliage, enhancing the peaceful and magical atmosphere of this unique musical performance." -video = pipe(prompt=prompt, guidance_scale=6, num_inference_steps=50).frames[0] -``` - -The [T2V benchmark](https://gist.github.com/a-r-r-o-w/5183d75e452a368fd17448fcc810bd3f) results on an 80GB A100 machine are: - -``` -Without torch.compile(): Average inference time: 96.89 seconds. -With torch.compile(): Average inference time: 76.27 seconds. +OmniGen supports for multimodal inputs. +When the input includes an image, you need to add a placeholder `<|image_1|>` in the text prompt to represent the image. +It is recommended to enable 'use_input_image_size_as_output' to keep the edited image the same size as the original image. + +```py +prompt="<|image_1|> Remove the woman's earrings. Replace the mug with a clear glass filled with sparkling iced cola." 
+input_images=[load_image("https://raw.githubusercontent.com/VectorSpaceLab/OmniGen/main/imgs/docs_img/t2i_woman_with_book.png")] +image = pipe( + prompt=prompt, + input_images=input_images, + guidance_scale=2, + img_guidance_scale=1.6, + use_input_image_size_as_output=True, + generator=torch.Generator(device="cpu").manual_seed(222)).images[0] +image ``` -## CogVideoXPipeline +## OmniGenPipeline -[[autodoc]] CogVideoXPipeline +[[autodoc]] OmniGenPipeline - all - __call__ diff --git a/docs/source/en/using-diffusers/multimodal2img.md b/docs/source/en/using-diffusers/multimodal2img.md index 4618731830df..10567d49ae19 100644 --- a/docs/source/en/using-diffusers/multimodal2img.md +++ b/docs/source/en/using-diffusers/multimodal2img.md @@ -10,596 +10,104 @@ an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express o specific language governing permissions and limitations under the License. --> -# Image-to-image +# Multi-modal instruction to image [[open-in-colab]] -Image-to-image is similar to [text-to-image](conditional_image_generation), but in addition to a prompt, you can also pass an initial image as a starting point for the diffusion process. The initial image is encoded to latent space and noise is added to it. Then the latent diffusion model takes a prompt and the noisy latent image, predicts the added noise, and removes the predicted noise from the initial latent image to get the new latent image. Lastly, a decoder decodes the new latent image back into an image. -With 🤗 Diffusers, this is as easy as 1-2-3: -1. Load a checkpoint into the [`AutoPipelineForImage2Image`] class; this pipeline automatically handles loading the correct pipeline class based on the checkpoint: +Multimodal instructions mean you can input any sequence of mixed text and images to guide image generation. You can input multiple images and use prompts to describe the desired output. This approach is more flexible than using only text or images. -```py -import torch -from diffusers import AutoPipelineForImage2Image -from diffusers.utils import load_image, make_image_grid - -pipeline = AutoPipelineForImage2Image.from_pretrained( - "kandinsky-community/kandinsky-2-2-decoder", torch_dtype=torch.float16, use_safetensors=True -) -pipeline.enable_model_cpu_offload() -# remove following line if xFormers is not installed or you have PyTorch 2.0 or higher installed -pipeline.enable_xformers_memory_efficient_attention() -``` - - - -You'll notice throughout the guide, we use [`~DiffusionPipeline.enable_model_cpu_offload`] and [`~DiffusionPipeline.enable_xformers_memory_efficient_attention`], to save memory and increase inference speed. If you're using PyTorch 2.0, then you don't need to call [`~DiffusionPipeline.enable_xformers_memory_efficient_attention`] on your pipeline because it'll already be using PyTorch 2.0's native [scaled-dot product attention](../optimization/torch2.0#scaled-dot-product-attention). - - - -2. Load an image to pass to the pipeline: - -```py -init_image = load_image("https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/cat.png") -``` - -3. Pass a prompt and image to the pipeline to generate an image: - -```py -prompt = "cat wizard, gandalf, lord of the rings, detailed, fantasy, cute, adorable, Pixar, Disney, 8k" -image = pipeline(prompt, image=init_image).images[0] -make_image_grid([init_image, image], rows=1, cols=2) -``` - -
-
- -
initial image
-
-
- -
generated image
-
-
- -## Popular models - -The most popular image-to-image models are [Stable Diffusion v1.5](https://huggingface.co/stable-diffusion-v1-5/stable-diffusion-v1-5), [Stable Diffusion XL (SDXL)](https://huggingface.co/stabilityai/stable-diffusion-xl-base-1.0), and [Kandinsky 2.2](https://huggingface.co/kandinsky-community/kandinsky-2-2-decoder). The results from the Stable Diffusion and Kandinsky models vary due to their architecture differences and training process; you can generally expect SDXL to produce higher quality images than Stable Diffusion v1.5. Let's take a quick look at how to use each of these models and compare their results. +## Examples -### Stable Diffusion v1.5 -Stable Diffusion v1.5 is a latent diffusion model initialized from an earlier checkpoint, and further finetuned for 595K steps on 512x512 images. To use this pipeline for image-to-image, you'll need to prepare an initial image to pass to the pipeline. Then you can pass a prompt and the image to the pipeline to generate a new image: +Take `OmniGenPipeline` as an example: the input can be a text-image sequence to create new images, he input can be a text-image sequence, with images inserted into the text prompt via special placeholder `<|image_i|>`. -```py +```py import torch -from diffusers import AutoPipelineForImage2Image -from diffusers.utils import make_image_grid, load_image - -pipeline = AutoPipelineForImage2Image.from_pretrained( - "stable-diffusion-v1-5/stable-diffusion-v1-5", torch_dtype=torch.float16, variant="fp16", use_safetensors=True -) -pipeline.enable_model_cpu_offload() -# remove following line if xFormers is not installed or you have PyTorch 2.0 or higher installed -pipeline.enable_xformers_memory_efficient_attention() - -# prepare image -url = "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/img2img-init.png" -init_image = load_image(url) - -prompt = "Astronaut in a jungle, cold color palette, muted colors, detailed, 8k" - -# pass prompt and image to pipeline -image = pipeline(prompt, image=init_image).images[0] -make_image_grid([init_image, image], rows=1, cols=2) +from diffusers import OmniGenPipeline +from diffusers.utils import load_image + +pipe = OmniGenPipeline.from_pretrained( + "Shitao/OmniGen-v1-diffusers", + torch_dtype=torch.bfloat16 +) +pipe.to("cuda") + +prompt="A man and a woman are sitting at a classroom desk. The man is the man with yellow hair in <|image_1|>. The woman is the woman on the left of <|image_2|>" +input_image_1 = load_image("https://raw.githubusercontent.com/VectorSpaceLab/OmniGen/main/imgs/docs_img/3.jpg") +input_image_2 = load_image("https://raw.githubusercontent.com/VectorSpaceLab/OmniGen/main/imgs/docs_img/4.jpg") +input_images=[input_image_1, input_image_2] +image = pipe( + prompt=prompt, + input_images=input_images, + height=1024, + width=1024, + guidance_scale=2.5, + img_guidance_scale=1.6, + generator=torch.Generator(device="cpu").manual_seed(666)).images[0] +image ``` - -
-
- -
initial image
-
-
- -
generated image
-
-
- -### Stable Diffusion XL (SDXL) - -SDXL is a more powerful version of the Stable Diffusion model. It uses a larger base model, and an additional refiner model to increase the quality of the base model's output. Read the [SDXL](sdxl) guide for a more detailed walkthrough of how to use this model, and other techniques it uses to produce high quality images. - -```py -import torch -from diffusers import AutoPipelineForImage2Image -from diffusers.utils import make_image_grid, load_image - -pipeline = AutoPipelineForImage2Image.from_pretrained( - "stabilityai/stable-diffusion-xl-refiner-1.0", torch_dtype=torch.float16, variant="fp16", use_safetensors=True -) -pipeline.enable_model_cpu_offload() -# remove following line if xFormers is not installed or you have PyTorch 2.0 or higher installed -pipeline.enable_xformers_memory_efficient_attention() - -# prepare image -url = "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/img2img-sdxl-init.png" -init_image = load_image(url) - -prompt = "Astronaut in a jungle, cold color palette, muted colors, detailed, 8k" - -# pass prompt and image to pipeline -image = pipeline(prompt, image=init_image, strength=0.5).images[0] -make_image_grid([init_image, image], rows=1, cols=2) -``` - -
-
- -
initial image
-
-
- -
generated image
-
-
- -### Kandinsky 2.2 - -The Kandinsky model is different from the Stable Diffusion models because it uses an image prior model to create image embeddings. The embeddings help create a better alignment between text and images, allowing the latent diffusion model to generate better images. - -The simplest way to use Kandinsky 2.2 is: - -```py -import torch -from diffusers import AutoPipelineForImage2Image -from diffusers.utils import make_image_grid, load_image - -pipeline = AutoPipelineForImage2Image.from_pretrained( - "kandinsky-community/kandinsky-2-2-decoder", torch_dtype=torch.float16, use_safetensors=True -) -pipeline.enable_model_cpu_offload() -# remove following line if xFormers is not installed or you have PyTorch 2.0 or higher installed -pipeline.enable_xformers_memory_efficient_attention() - -# prepare image -url = "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/img2img-init.png" -init_image = load_image(url) - -prompt = "Astronaut in a jungle, cold color palette, muted colors, detailed, 8k" - -# pass prompt and image to pipeline -image = pipeline(prompt, image=init_image).images[0] -make_image_grid([init_image, image], rows=1, cols=2) -``` - -
-
- -
initial image
-
-
- -
generated image
-
-
- -## Configure pipeline parameters - -There are several important parameters you can configure in the pipeline that'll affect the image generation process and image quality. Let's take a closer look at what these parameters do and how changing them affects the output. - -### Strength - -`strength` is one of the most important parameters to consider and it'll have a huge impact on your generated image. It determines how much the generated image resembles the initial image. In other words: - -- 📈 a higher `strength` value gives the model more "creativity" to generate an image that's different from the initial image; a `strength` value of 1.0 means the initial image is more or less ignored -- 📉 a lower `strength` value means the generated image is more similar to the initial image - -The `strength` and `num_inference_steps` parameters are related because `strength` determines the number of noise steps to add. For example, if the `num_inference_steps` is 50 and `strength` is 0.8, then this means adding 40 (50 * 0.8) steps of noise to the initial image and then denoising for 40 steps to get the newly generated image. - -```py -import torch -from diffusers import AutoPipelineForImage2Image -from diffusers.utils import make_image_grid, load_image - -pipeline = AutoPipelineForImage2Image.from_pretrained( - "stable-diffusion-v1-5/stable-diffusion-v1-5", torch_dtype=torch.float16, variant="fp16", use_safetensors=True -) -pipeline.enable_model_cpu_offload() -# remove following line if xFormers is not installed or you have PyTorch 2.0 or higher installed -pipeline.enable_xformers_memory_efficient_attention() - -# prepare image -url = "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/img2img-init.png" -init_image = load_image(url) - -prompt = "Astronaut in a jungle, cold color palette, muted colors, detailed, 8k" - -# pass prompt and image to pipeline -image = pipeline(prompt, image=init_image, strength=0.8).images[0] -make_image_grid([init_image, image], rows=1, cols=2) -``` - -
-
- -
strength = 0.4
-
-
- -
strength = 0.6
-
-
- -
strength = 1.0
-
-
- -### Guidance scale - -The `guidance_scale` parameter is used to control how closely aligned the generated image and text prompt are. A higher `guidance_scale` value means your generated image is more aligned with the prompt, while a lower `guidance_scale` value means your generated image has more space to deviate from the prompt. - -You can combine `guidance_scale` with `strength` for even more precise control over how expressive the model is. For example, combine a high `strength + guidance_scale` for maximum creativity or use a combination of low `strength` and low `guidance_scale` to generate an image that resembles the initial image but is not as strictly bound to the prompt. - -```py -import torch -from diffusers import AutoPipelineForImage2Image -from diffusers.utils import make_image_grid, load_image - -pipeline = AutoPipelineForImage2Image.from_pretrained( - "stable-diffusion-v1-5/stable-diffusion-v1-5", torch_dtype=torch.float16, variant="fp16", use_safetensors=True -) -pipeline.enable_model_cpu_offload() -# remove following line if xFormers is not installed or you have PyTorch 2.0 or higher installed -pipeline.enable_xformers_memory_efficient_attention() - -# prepare image -url = "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/img2img-init.png" -init_image = load_image(url) - -prompt = "Astronaut in a jungle, cold color palette, muted colors, detailed, 8k" - -# pass prompt and image to pipeline -image = pipeline(prompt, image=init_image, guidance_scale=8.0).images[0] -make_image_grid([init_image, image], rows=1, cols=2) -``` -
- -
guidance_scale = 0.1
+ +
input_image_1
- -
guidance_scale = 5.0
+ +
input_image_2
- -
guidance_scale = 10.0
-
-
- -### Negative prompt - -A negative prompt conditions the model to *not* include things in an image, and it can be used to improve image quality or modify an image. For example, you can improve image quality by including negative prompts like "poor details" or "blurry" to encourage the model to generate a higher quality image. Or you can modify an image by specifying things to exclude from an image. - -```py -import torch -from diffusers import AutoPipelineForImage2Image -from diffusers.utils import make_image_grid, load_image - -pipeline = AutoPipelineForImage2Image.from_pretrained( - "stabilityai/stable-diffusion-xl-refiner-1.0", torch_dtype=torch.float16, variant="fp16", use_safetensors=True -) -pipeline.enable_model_cpu_offload() -# remove following line if xFormers is not installed or you have PyTorch 2.0 or higher installed -pipeline.enable_xformers_memory_efficient_attention() - -# prepare image -url = "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/img2img-init.png" -init_image = load_image(url) - -prompt = "Astronaut in a jungle, cold color palette, muted colors, detailed, 8k" -negative_prompt = "ugly, deformed, disfigured, poor details, bad anatomy" - -# pass prompt and image to pipeline -image = pipeline(prompt, negative_prompt=negative_prompt, image=init_image).images[0] -make_image_grid([init_image, image], rows=1, cols=2) -``` - -
-
- -
negative_prompt = "ugly, deformed, disfigured, poor details, bad anatomy"
-
-
- -
negative_prompt = "jungle"
+ +
generated image
-## Chained image-to-image pipelines - -There are some other interesting ways you can use an image-to-image pipeline aside from just generating an image (although that is pretty cool too). You can take it a step further and chain it with other pipelines. - -### Text-to-image-to-image - -Chaining a text-to-image and image-to-image pipeline allows you to generate an image from text and use the generated image as the initial image for the image-to-image pipeline. This is useful if you want to generate an image entirely from scratch. For example, let's chain a Stable Diffusion and a Kandinsky model. - -Start by generating an image with the text-to-image pipeline: - -```py -from diffusers import AutoPipelineForText2Image, AutoPipelineForImage2Image -import torch -from diffusers.utils import make_image_grid - -pipeline = AutoPipelineForText2Image.from_pretrained( - "stable-diffusion-v1-5/stable-diffusion-v1-5", torch_dtype=torch.float16, variant="fp16", use_safetensors=True -) -pipeline.enable_model_cpu_offload() -# remove following line if xFormers is not installed or you have PyTorch 2.0 or higher installed -pipeline.enable_xformers_memory_efficient_attention() - -text2image = pipeline("Astronaut in a jungle, cold color palette, muted colors, detailed, 8k").images[0] -text2image -``` - -Now you can pass this generated image to the image-to-image pipeline: - -```py -pipeline = AutoPipelineForImage2Image.from_pretrained( - "kandinsky-community/kandinsky-2-2-decoder", torch_dtype=torch.float16, use_safetensors=True -) -pipeline.enable_model_cpu_offload() -# remove following line if xFormers is not installed or you have PyTorch 2.0 or higher installed -pipeline.enable_xformers_memory_efficient_attention() - -image2image = pipeline("Astronaut in a jungle, cold color palette, muted colors, detailed, 8k", image=text2image).images[0] -make_image_grid([text2image, image2image], rows=1, cols=2) -``` - -### Image-to-image-to-image - -You can also chain multiple image-to-image pipelines together to create more interesting images. This can be useful for iteratively performing style transfer on an image, generating short GIFs, restoring color to an image, or restoring missing areas of an image. - -Start by generating an image: - -```py -import torch -from diffusers import AutoPipelineForImage2Image -from diffusers.utils import make_image_grid, load_image - -pipeline = AutoPipelineForImage2Image.from_pretrained( - "stable-diffusion-v1-5/stable-diffusion-v1-5", torch_dtype=torch.float16, variant="fp16", use_safetensors=True -) -pipeline.enable_model_cpu_offload() -# remove following line if xFormers is not installed or you have PyTorch 2.0 or higher installed -pipeline.enable_xformers_memory_efficient_attention() - -# prepare image -url = "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/img2img-init.png" -init_image = load_image(url) - -prompt = "Astronaut in a jungle, cold color palette, muted colors, detailed, 8k" - -# pass prompt and image to pipeline -image = pipeline(prompt, image=init_image, output_type="latent").images[0] -``` - - - -It is important to specify `output_type="latent"` in the pipeline to keep all the outputs in latent space to avoid an unnecessary decode-encode step. This only works if the chained pipelines are using the same VAE. 
- - - -Pass the latent output from this pipeline to the next pipeline to generate an image in a [comic book art style](https://huggingface.co/ogkalu/Comic-Diffusion): - -```py -pipeline = AutoPipelineForImage2Image.from_pretrained( - "ogkalu/Comic-Diffusion", torch_dtype=torch.float16 -) -pipeline.enable_model_cpu_offload() -# remove following line if xFormers is not installed or you have PyTorch 2.0 or higher installed -pipeline.enable_xformers_memory_efficient_attention() - -# need to include the token "charliebo artstyle" in the prompt to use this checkpoint -image = pipeline("Astronaut in a jungle, charliebo artstyle", image=image, output_type="latent").images[0] -``` - -Repeat one more time to generate the final image in a [pixel art style](https://huggingface.co/kohbanye/pixel-art-style): - -```py -pipeline = AutoPipelineForImage2Image.from_pretrained( - "kohbanye/pixel-art-style", torch_dtype=torch.float16 -) -pipeline.enable_model_cpu_offload() -# remove following line if xFormers is not installed or you have PyTorch 2.0 or higher installed -pipeline.enable_xformers_memory_efficient_attention() - -# need to include the token "pixelartstyle" in the prompt to use this checkpoint -image = pipeline("Astronaut in a jungle, pixelartstyle", image=image).images[0] -make_image_grid([init_image, image], rows=1, cols=2) -``` - -### Image-to-upscaler-to-super-resolution - -Another way you can chain your image-to-image pipeline is with an upscaler and super-resolution pipeline to really increase the level of details in an image. - -Start with an image-to-image pipeline: - -```py -import torch -from diffusers import AutoPipelineForImage2Image -from diffusers.utils import make_image_grid, load_image - -pipeline = AutoPipelineForImage2Image.from_pretrained( - "stable-diffusion-v1-5/stable-diffusion-v1-5", torch_dtype=torch.float16, variant="fp16", use_safetensors=True -) -pipeline.enable_model_cpu_offload() -# remove following line if xFormers is not installed or you have PyTorch 2.0 or higher installed -pipeline.enable_xformers_memory_efficient_attention() - -# prepare image -url = "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/img2img-init.png" -init_image = load_image(url) - -prompt = "Astronaut in a jungle, cold color palette, muted colors, detailed, 8k" - -# pass prompt and image to pipeline -image_1 = pipeline(prompt, image=init_image, output_type="latent").images[0] -``` - - - -It is important to specify `output_type="latent"` in the pipeline to keep all the outputs in *latent* space to avoid an unnecessary decode-encode step. This only works if the chained pipelines are using the same VAE. 
- - - -Chain it to an upscaler pipeline to increase the image resolution: - -```py -from diffusers import StableDiffusionLatentUpscalePipeline - -upscaler = StableDiffusionLatentUpscalePipeline.from_pretrained( - "stabilityai/sd-x2-latent-upscaler", torch_dtype=torch.float16, variant="fp16", use_safetensors=True -) -upscaler.enable_model_cpu_offload() -upscaler.enable_xformers_memory_efficient_attention() - -image_2 = upscaler(prompt, image=image_1, output_type="latent").images[0] -``` - -Finally, chain it to a super-resolution pipeline to further enhance the resolution: - -```py -from diffusers import StableDiffusionUpscalePipeline - -super_res = StableDiffusionUpscalePipeline.from_pretrained( - "stabilityai/stable-diffusion-x4-upscaler", torch_dtype=torch.float16, variant="fp16", use_safetensors=True -) -super_res.enable_model_cpu_offload() -super_res.enable_xformers_memory_efficient_attention() - -image_3 = super_res(prompt, image=image_2).images[0] -make_image_grid([init_image, image_3.resize((512, 512))], rows=1, cols=2) -``` - -## Control image generation - -Trying to generate an image that looks exactly the way you want can be difficult, which is why controlled generation techniques and models are so useful. While you can use the `negative_prompt` to partially control image generation, there are more robust methods like prompt weighting and ControlNets. - -### Prompt weighting - -Prompt weighting allows you to scale the representation of each concept in a prompt. For example, in a prompt like "Astronaut in a jungle, cold color palette, muted colors, detailed, 8k", you can choose to increase or decrease the embeddings of "astronaut" and "jungle". The [Compel](https://github.com/damian0815/compel) library provides a simple syntax for adjusting prompt weights and generating the embeddings. You can learn how to create the embeddings in the [Prompt weighting](weighted_prompts) guide. - -[`AutoPipelineForImage2Image`] has a `prompt_embeds` (and `negative_prompt_embeds` if you're using a negative prompt) parameter where you can pass the embeddings which replaces the `prompt` parameter. - -```py -from diffusers import AutoPipelineForImage2Image -import torch - -pipeline = AutoPipelineForImage2Image.from_pretrained( - "stable-diffusion-v1-5/stable-diffusion-v1-5", torch_dtype=torch.float16, variant="fp16", use_safetensors=True -) -pipeline.enable_model_cpu_offload() -# remove following line if xFormers is not installed or you have PyTorch 2.0 or higher installed -pipeline.enable_xformers_memory_efficient_attention() - -image = pipeline(prompt_embeds=prompt_embeds, # generated from Compel - negative_prompt_embeds=negative_prompt_embeds, # generated from Compel - image=init_image, -).images[0] -``` - -### ControlNet - -ControlNets provide a more flexible and accurate way to control image generation because you can use an additional conditioning image. The conditioning image can be a canny image, depth map, image segmentation, and even scribbles! Whatever type of conditioning image you choose, the ControlNet generates an image that preserves the information in it. - -For example, let's condition an image with a depth map to keep the spatial information in the image. 
```py -from diffusers.utils import load_image, make_image_grid - -# prepare image -url = "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/img2img-init.png" -init_image = load_image(url) -init_image = init_image.resize((958, 960)) # resize to depth image dimensions -depth_image = load_image("https://huggingface.co/lllyasviel/control_v11f1p_sd15_depth/resolve/main/images/control.png") -make_image_grid([init_image, depth_image], rows=1, cols=2) -``` - -Load a ControlNet model conditioned on depth maps and the [`AutoPipelineForImage2Image`]: - -```py -from diffusers import ControlNetModel, AutoPipelineForImage2Image import torch +from diffusers import OmniGenPipeline +from diffusers.utils import load_image -controlnet = ControlNetModel.from_pretrained("lllyasviel/control_v11f1p_sd15_depth", torch_dtype=torch.float16, variant="fp16", use_safetensors=True) -pipeline = AutoPipelineForImage2Image.from_pretrained( - "stable-diffusion-v1-5/stable-diffusion-v1-5", controlnet=controlnet, torch_dtype=torch.float16, variant="fp16", use_safetensors=True +pipe = OmniGenPipeline.from_pretrained( + "Shitao/OmniGen-v1-diffusers", + torch_dtype=torch.bfloat16 ) -pipeline.enable_model_cpu_offload() -# remove following line if xFormers is not installed or you have PyTorch 2.0 or higher installed -pipeline.enable_xformers_memory_efficient_attention() -``` +pipe.to("cuda") -Now generate a new image conditioned on the depth map, initial image, and prompt: -```py -prompt = "Astronaut in a jungle, cold color palette, muted colors, detailed, 8k" -image_control_net = pipeline(prompt, image=init_image, control_image=depth_image).images[0] -make_image_grid([init_image, depth_image, image_control_net], rows=1, cols=3) +prompt="A woman is walking down the street, wearing a white long-sleeve blouse with lace details on the sleeves, paired with a blue pleated skirt. The woman is <|image_1|>. The long-sleeve blouse and a pleated skirt are <|image_2|>." +input_image_1 = load_image("/share/junjie/code/VISTA2/produce_data/laion_net/diffgpt/OmniGen/docs_img/emma.jpeg") +input_image_2 = load_image("/share/junjie/code/VISTA2/produce_data/laion_net/diffgpt/OmniGen/docs_img/dress.jpg") +input_images=[input_image_1, input_image_2] +image = pipe( + prompt=prompt, + input_images=input_images, + height=1024, + width=1024, + guidance_scale=2.5, + img_guidance_scale=1.6, + generator=torch.Generator(device="cpu").manual_seed(666)).images[0] ```
- -
initial image
+ +
person image
- -
depth image
+ +
clothe image
- -
ControlNet image
+ +
generated image
-Let's apply a new [style](https://huggingface.co/nitrosocke/elden-ring-diffusion) to the image generated from the ControlNet by chaining it with an image-to-image pipeline: -```py -pipeline = AutoPipelineForImage2Image.from_pretrained( - "nitrosocke/elden-ring-diffusion", torch_dtype=torch.float16, -) -pipeline.enable_model_cpu_offload() -# remove following line if xFormers is not installed or you have PyTorch 2.0 or higher installed -pipeline.enable_xformers_memory_efficient_attention() - -prompt = "elden ring style astronaut in a jungle" # include the token "elden ring style" in the prompt -negative_prompt = "ugly, deformed, disfigured, poor details, bad anatomy" - -image_elden_ring = pipeline(prompt, negative_prompt=negative_prompt, image=image_control_net, strength=0.45, guidance_scale=10.5).images[0] -make_image_grid([init_image, depth_image, image_control_net, image_elden_ring], rows=2, cols=2) -``` - -
- -
- -## Optimize - -Running diffusion models is computationally expensive and intensive, but with a few optimization tricks, it is entirely possible to run them on consumer and free-tier GPUs. For example, you can use a more memory-efficient form of attention such as PyTorch 2.0's [scaled-dot product attention](../optimization/torch2.0#scaled-dot-product-attention) or [xFormers](../optimization/xformers) (you can use one or the other, but there's no need to use both). You can also offload the model to the GPU while the other pipeline components wait on the CPU. - -```diff -+ pipeline.enable_model_cpu_offload() -+ pipeline.enable_xformers_memory_efficient_attention() -``` - -With [`torch.compile`](../optimization/torch2.0#torchcompile), you can boost your inference speed even more by wrapping your UNet with it: +The output image is a [`PIL.Image`](https://pillow.readthedocs.io/en/stable/reference/Image.html?highlight=image#the-image-class) object that can be saved: ```py -pipeline.unet = torch.compile(pipeline.unet, mode="reduce-overhead", fullgraph=True) -``` - -To learn more, take a look at the [Reduce memory usage](../optimization/memory) and [Torch 2.0](../optimization/torch2.0) guides. +image.save("generated_image.png") +``` \ No newline at end of file diff --git a/docs/source/en/using-diffusers/omnigen.md b/docs/source/en/using-diffusers/omnigen.md index 22a14bd95efe..c535f4358155 100644 --- a/docs/source/en/using-diffusers/omnigen.md +++ b/docs/source/en/using-diffusers/omnigen.md @@ -320,9 +320,10 @@ Here are some guidelines to help you reduce computational costs when input multi - With enabling cpu offloading, memory usage is `28 GB` - `offload_transformer_block=True`: - - 17G + - offload transformer block to reduce memory usage + - When enabled, memory usage is under `25 GB` - `pipe.enable_sequential_cpu_offload()`: - - 11G - + - significantly reduce memory usage at the cost of slow inference + - When enabled, memory usage is under `11 GB` diff --git a/src/diffusers/models/transformers/transformer_omnigen.py b/src/diffusers/models/transformers/transformer_omnigen.py index 882065faae3d..3a1072a0d349 100644 --- a/src/diffusers/models/transformers/transformer_omnigen.py +++ b/src/diffusers/models/transformers/transformer_omnigen.py @@ -12,13 +12,12 @@ # See the License for the specific language governing permissions and # limitations under the License. 
+from dataclasses import dataclass from typing import Any, Dict, Optional, Tuple, Union, List -from dataclasses import dataclass import torch import torch.utils.checkpoint from torch import nn -from transformers.cache_utils import DynamicCache from transformers import Phi3Model, Phi3Config from transformers.cache_utils import Cache, DynamicCache from transformers.modeling_outputs import BaseModelOutputWithPast @@ -36,7 +35,6 @@ logger = logging.get_logger(__name__) # pylint: disable=invalid-name - @dataclass class OmniGen2DModelOutput(Transformer2DModelOutput): """ @@ -74,7 +72,7 @@ def evict_previous_layer(self, layer_idx: int): prev_layer_idx = layer_idx - 1 for name, param in self.layers[prev_layer_idx].named_parameters(): param.data = param.data.to("cpu", non_blocking=True) - + def get_offload_layer(self, layer_idx: int, device: torch.device): # init stream if not hasattr(self, "prefetch_stream"): @@ -83,7 +81,7 @@ def get_offload_layer(self, layer_idx: int, device: torch.device): # delete previous layer torch.cuda.current_stream().synchronize() self.evict_previous_layer(layer_idx) - + # make sure the current layer is ready torch.cuda.synchronize(self.prefetch_stream) @@ -273,7 +271,6 @@ def unpatchify(self, x, h, w): imgs = x.reshape(shape=(x.shape[0], c, h, w)) return imgs - @property # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.attn_processors def attn_processors(self) -> Dict[str, AttentionProcessor]: @@ -337,12 +334,12 @@ def fn_recursive_attn_processor(name: str, module: torch.nn.Module, processor): def _set_gradient_checkpointing(self, module, value=False): if hasattr(module, "gradient_checkpointing"): module.gradient_checkpointing = value - - def get_multimodal_embeddings(self, - input_ids: torch.Tensor, - input_img_latents: List[torch.Tensor], - input_image_sizes: Dict, - ): + + def get_multimodal_embeddings(self, + input_ids: torch.Tensor, + input_img_latents: List[torch.Tensor], + input_image_sizes: Dict, + ): """ get the multi-modal conditional embeddings Args: @@ -353,7 +350,7 @@ def get_multimodal_embeddings(self, Returns: torch.Tensor """ - input_img_latents = [x.to(self.dtype) for x in input_img_latents] + input_img_latents = [x.to(self.dtype) for x in input_img_latents] condition_tokens = None if input_ids is not None: condition_tokens = self.llm.embed_tokens(input_ids) @@ -384,6 +381,41 @@ def forward(self, attention_kwargs: Optional[Dict[str, Any]] = None, return_dict: bool = True, ): + """ + The [`OmniGenTransformer2DModel`] forward method. + + Args: + hidden_states (`torch.FloatTensor` of shape `(batch size, channel, height, width)`): + Input `hidden_states`. + timestep (`torch.LongTensor`): + Used to indicate denoising step. 
+ input_ids (`torch.LongTensor`): + token ids + input_img_latents (`torch.FloatTensor`): + encoded image latents by VAE + input_image_sizes (`dict`): + the indices of the input_img_latents in the input_ids + attention_mask (`torch.FloatTensor`): + mask for self-attention + position_ids (`torch.LongTensor`): + id to represent position + past_key_values (`transformers.cache_utils.Cache`): + previous key and value states + offload_transformer_block (`bool`, *optional*, defaults to `True`): + offload transformer block to cpu + attention_kwargs: (`dict`, *optional*): + A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under + `self.processor` in + [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py). + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`OmniGen2DModelOutput`] instead of a plain + tuple. + + Returns: + If `return_dict` is True, an [`OmniGen2DModelOutput`] is returned, otherwise a + `tuple` where the first element is the sample tensor. + + """ if attention_kwargs is not None: attention_kwargs = attention_kwargs.copy() diff --git a/src/diffusers/pipelines/omnigen/kvcache_omnigen.py b/src/diffusers/pipelines/omnigen/kvcache_omnigen.py index 7f02588ce405..4bf32ae6ae74 100644 --- a/src/diffusers/pipelines/omnigen/kvcache_omnigen.py +++ b/src/diffusers/pipelines/omnigen/kvcache_omnigen.py @@ -1,20 +1,32 @@ -from tqdm import tqdm +# Copyright 2024 OmniGen team and The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + from typing import Optional, Dict, Any, Tuple, List -import gc import torch -from transformers.cache_utils import Cache, DynamicCache, OffloadedCache - +from transformers.cache_utils import DynamicCache class OmniGenCache(DynamicCache): - def __init__(self, - num_tokens_for_img: int, - offload_kv_cache: bool=False) -> None: + def __init__(self, + num_tokens_for_img: int, + offload_kv_cache: bool = False) -> None: if not torch.cuda.is_available(): # print("No avaliable GPU, offload_kv_cache wiil be set to False, which will result in large memory usage and time cost when input multiple images!!!") # offload_kv_cache = False - raise RuntimeError("OffloadedCache can only be used with a GPU. If there is no GPU, you need to set use_kv_cache=False, which will result in longer inference time!") + raise RuntimeError( + "OffloadedCache can only be used with a GPU. 
If there is no GPU, you need to set use_kv_cache=False, which will result in longer inference time!") super().__init__() self.original_device = [] self.prefetch_stream = torch.cuda.Stream() @@ -30,19 +42,17 @@ def prefetch_layer(self, layer_idx: int): self.key_cache[layer_idx] = self.key_cache[layer_idx].to(device, non_blocking=True) self.value_cache[layer_idx] = self.value_cache[layer_idx].to(device, non_blocking=True) - def evict_previous_layer(self, layer_idx: int): "Moves the previous layer cache to the CPU" if len(self) > 2: # We do it on the default stream so it occurs after all earlier computations on these tensors are done - if layer_idx == 0: + if layer_idx == 0: prev_layer_idx = -1 else: prev_layer_idx = (layer_idx - 1) % len(self) self.key_cache[prev_layer_idx] = self.key_cache[prev_layer_idx].to("cpu", non_blocking=True) self.value_cache[prev_layer_idx] = self.value_cache[prev_layer_idx].to("cpu", non_blocking=True) - def __getitem__(self, layer_idx: int) -> List[Tuple[torch.Tensor]]: "Gets the cache for this layer to the device. Prefetches the next and evicts the previous layer." if layer_idx < len(self): @@ -56,7 +66,7 @@ def __getitem__(self, layer_idx: int) -> List[Tuple[torch.Tensor]]: torch.cuda.synchronize(self.prefetch_stream) key_tensor = self.key_cache[layer_idx] value_tensor = self.value_cache[layer_idx] - + # Prefetch the next layer self.prefetch_layer((layer_idx + 1) % len(self)) else: @@ -65,13 +75,13 @@ def __getitem__(self, layer_idx: int) -> List[Tuple[torch.Tensor]]: return (key_tensor, value_tensor) else: raise KeyError(f"Cache only has {len(self)} layers, attempted to access layer with index {layer_idx}") - + def update( - self, - key_states: torch.Tensor, - value_states: torch.Tensor, - layer_idx: int, - cache_kwargs: Optional[Dict[str, Any]] = None, + self, + key_states: torch.Tensor, + value_states: torch.Tensor, + layer_idx: int, + cache_kwargs: Optional[Dict[str, Any]] = None, ) -> Tuple[torch.Tensor, torch.Tensor]: """ Updates the cache with the new `key_states` and `value_states` for the layer `layer_idx`. @@ -92,13 +102,13 @@ def update( raise ValueError("OffloadedCache does not support model usage where layers are skipped. 
Use DynamicCache.") elif len(self.key_cache) == layer_idx: # only cache the states for condition tokens - key_states = key_states[..., :-(self.num_tokens_for_img+1), :] - value_states = value_states[..., :-(self.num_tokens_for_img+1), :] + key_states = key_states[..., :-(self.num_tokens_for_img + 1), :] + value_states = value_states[..., :-(self.num_tokens_for_img + 1), :] - # Update the number of seen tokens + # Update the number of seen tokens if layer_idx == 0: self._seen_tokens += key_states.shape[-2] - + self.key_cache.append(key_states) self.value_cache.append(value_states) self.original_device.append(key_states.device) diff --git a/src/diffusers/pipelines/omnigen/pipeline_omnigen.py b/src/diffusers/pipelines/omnigen/pipeline_omnigen.py index 902453aae0f6..603ef39476e8 100644 --- a/src/diffusers/pipelines/omnigen/pipeline_omnigen.py +++ b/src/diffusers/pipelines/omnigen/pipeline_omnigen.py @@ -32,17 +32,15 @@ ) from ...utils.torch_utils import randn_tensor from ..pipeline_utils import DiffusionPipeline, ImagePipelineOutput -from .processor_omnigen import OmniGenMultiModalProcessor from .kvcache_omnigen import OmniGenCache +from .processor_omnigen import OmniGenMultiModalProcessor if is_torch_xla_available(): - import torch_xla.core.xla_model as xm XLA_AVAILABLE = True else: XLA_AVAILABLE = False - logger = logging.get_logger(__name__) # pylint: disable=invalid-name EXAMPLE_DOC_STRING = """ @@ -62,15 +60,14 @@ """ - # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps def retrieve_timesteps( - scheduler, - num_inference_steps: Optional[int] = None, - device: Optional[Union[str, torch.device]] = None, - timesteps: Optional[List[int]] = None, - sigmas: Optional[List[float]] = None, - **kwargs, + scheduler, + num_inference_steps: Optional[int] = None, + device: Optional[Union[str, torch.device]] = None, + timesteps: Optional[List[int]] = None, + sigmas: Optional[List[float]] = None, + **kwargs, ): r""" Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. 
Handles @@ -150,11 +147,11 @@ class OmniGenPipeline( _callback_tensor_inputs = ["latents", "input_images_latents"] def __init__( - self, - transformer: OmniGenTransformer2DModel, - scheduler: FlowMatchEulerDiscreteScheduler, - vae: AutoencoderKL, - tokenizer: LlamaTokenizer, + self, + transformer: OmniGenTransformer2DModel, + scheduler: FlowMatchEulerDiscreteScheduler, + vae: AutoencoderKL, + tokenizer: LlamaTokenizer, ): super().__init__() @@ -199,17 +196,16 @@ def encod_input_iamges( input_img_latents.append(img) return input_img_latents - def check_inputs( - self, - prompt, - input_images, - height, - width, - use_kv_cache, - offload_kv_cache, - callback_on_step_end_tensor_inputs=None, - max_sequence_length=None, + self, + prompt, + input_images, + height, + width, + use_kv_cache, + offload_kv_cache, + callback_on_step_end_tensor_inputs=None, + max_sequence_length=None, ): if input_images is not None: @@ -219,7 +215,7 @@ def check_inputs( ) for i in range(len(input_images)): if input_images[i] is not None: - if not all(f"<|image_{k+1}|>" in prompt[i] for k in range(len(input_images[i]))): + if not all(f"<|image_{k + 1}|>" in prompt[i] for k in range(len(input_images[i]))): raise ValueError( f"prompt `{prompt[i]}` doesn't have enough placeholders for the input images `{input_images[i]}`" ) @@ -228,15 +224,15 @@ def check_inputs( logger.warning( f"`height` and `width` have to be divisible by {self.vae_scale_factor * 2} but are {height} and {width}. Dimensions will be resized accordingly" ) - + if use_kv_cache and offload_kv_cache: if not torch.cuda.is_available(): raise ValueError( - f"Don't fine avaliable GPUs. `offload_kv_cache` can't be used when there is no GPU. please set it to False: `use_kv_cache=False, offload_kv_cache=False`" - ) + f"Don't fine avaliable GPUs. `offload_kv_cache` can't be used when there is no GPU. 
please set it to False: `use_kv_cache=False, offload_kv_cache=False`" + ) if callback_on_step_end_tensor_inputs is not None and not all( - k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs + k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs ): raise ValueError( f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}" @@ -286,15 +282,15 @@ def disable_vae_tiling(self): self.vae.disable_tiling() def prepare_latents( - self, - batch_size, - num_channels_latents, - height, - width, - dtype, - device, - generator, - latents=None, + self, + batch_size, + num_channels_latents, + height, + width, + dtype, + device, + generator, + latents=None, ): if latents is not None: return latents.to(device=device, dtype=dtype) @@ -327,7 +323,7 @@ def num_timesteps(self): @property def interrupt(self): return self._interrupt - + def enable_transformer_block_cpu_offload(self, device: Union[torch.device, str] = "cuda"): torch_device = torch.device(device) for name, param in self.transformer.named_parameters(): @@ -343,29 +339,30 @@ def enable_transformer_block_cpu_offload(self, device: Union[torch.device, str] @torch.no_grad() @replace_example_docstring(EXAMPLE_DOC_STRING) def __call__( - self, - prompt: Union[str, List[str]], - input_images: Optional[Union[List[str], List[PIL.Image.Image], List[List[str]], List[List[PIL.Image.Image]]]] = None, - height: Optional[int] = None, - width: Optional[int] = None, - num_inference_steps: int = 50, - max_input_image_size: int = 1024, - timesteps: List[int] = None, - guidance_scale: float = 2.5, - img_guidance_scale: float = 1.6, - use_kv_cache: bool = True, - offload_kv_cache: bool = True, - offload_transformer_block: bool = False, - use_input_image_size_as_output: bool = False, - num_images_per_prompt: Optional[int] = 1, - generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, - latents: Optional[torch.FloatTensor] = None, - output_type: Optional[str] = "pil", - return_dict: bool = True, - attention_kwargs: Optional[Dict[str, Any]] = None, - callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None, - callback_on_step_end_tensor_inputs: List[str] = ["latents"], - max_sequence_length: int = 120000, + self, + prompt: Union[str, List[str]], + input_images: Optional[ + Union[List[str], List[PIL.Image.Image], List[List[str]], List[List[PIL.Image.Image]]]] = None, + height: Optional[int] = None, + width: Optional[int] = None, + num_inference_steps: int = 50, + max_input_image_size: int = 1024, + timesteps: List[int] = None, + guidance_scale: float = 2.5, + img_guidance_scale: float = 1.6, + use_kv_cache: bool = True, + offload_kv_cache: bool = True, + offload_transformer_block: bool = False, + use_input_image_size_as_output: bool = False, + num_images_per_prompt: Optional[int] = 1, + generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, + latents: Optional[torch.FloatTensor] = None, + output_type: Optional[str] = "pil", + return_dict: bool = True, + attention_kwargs: Optional[Dict[str, Any]] = None, + callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None, + callback_on_step_end_tensor_inputs: List[str] = ["latents"], + max_sequence_length: int = 120000, ): r""" Function invoked when calling the pipeline for generation. 
@@ -444,11 +441,11 @@ def __call__( height = height or self.default_sample_size * self.vae_scale_factor width = width or self.default_sample_size * self.vae_scale_factor num_cfg = 2 if input_images is not None else 1 - use_img_cfg = True if input_images is not None else False + use_img_cfg = True if input_images is not None else False if isinstance(prompt, str): prompt = [prompt] input_images = [input_images] - + # using Float32 for the VAE doesn't take up much memory but can prevent potential black image outputs. self.vae.to(torch.float32) @@ -475,16 +472,15 @@ def __call__( batch_size = len(prompt) device = self._execution_device - # 3. process multi-modal instructions if max_input_image_size != self.multimodal_processor.max_image_size: self.multimodal_processor = OmniGenMultiModalProcessor(self.tokenizer, max_image_size=max_input_image_size) processed_data = self.multimodal_processor(prompt, - input_images, - height=height, - width=width, - use_img_cfg=use_img_cfg, - use_input_image_size_as_output=use_input_image_size_as_output) + input_images, + height=height, + width=width, + use_img_cfg=use_img_cfg, + use_input_image_size_as_output=use_input_image_size_as_output) processed_data['input_ids'] = processed_data['input_ids'].to(device) processed_data['attention_mask'] = processed_data['attention_mask'].to(device) processed_data['position_ids'] = processed_data['position_ids'].to(device) @@ -493,7 +489,7 @@ def __call__( input_img_latents = self.encod_input_iamges(processed_data['input_pixel_values'], device=device) # 5. Prepare timesteps - sigmas = np.linspace(1, 0, num_inference_steps+1)[:num_inference_steps] + sigmas = np.linspace(1, 0, num_inference_steps + 1)[:num_inference_steps] timesteps, num_inference_steps = retrieve_timesteps( self.scheduler, num_inference_steps, device, timesteps, sigmas=sigmas ) @@ -522,7 +518,7 @@ def __call__( with self.progress_bar(total=num_inference_steps) as progress_bar: for i, t in enumerate(timesteps): # expand the latents if we are doing classifier free guidance - latent_model_input = torch.cat([latents] * (num_cfg+1)) + latent_model_input = torch.cat([latents] * (num_cfg + 1)) # broadcast to batch dimension in a way that's compatible with ONNX/Core ML timestep = t.expand(latent_model_input.shape[0]) @@ -537,16 +533,20 @@ def __call__( position_ids=processed_data['position_ids'], attention_kwargs=attention_kwargs, past_key_values=cache, - offload_transformer_block=self.offload_transformer_block if hasattr(self, 'offload_transformer_block') else offload_transformer_block, + offload_transformer_block=self.offload_transformer_block if hasattr(self, + 'offload_transformer_block') else offload_transformer_block, return_dict=False, ) - + # if use kv cache, don't need attention mask and position ids of condition tokens for next step if use_kv_cache: if processed_data['input_ids'] is not None: processed_data['input_ids'] = None - processed_data['attention_mask'] = processed_data['attention_mask'][..., -(num_tokens_for_output_img + 1):, :] # +1 is for the timestep token - processed_data['position_ids'] = processed_data['position_ids'][:, -(num_tokens_for_output_img + 1):] + processed_data['attention_mask'] = processed_data['attention_mask'][..., + -(num_tokens_for_output_img + 1):, + :] # +1 is for the timestep token + processed_data['position_ids'] = processed_data['position_ids'][:, + -(num_tokens_for_output_img + 1):] if num_cfg == 2: cond, uncond, img_cond = torch.split(noise_pred, len(noise_pred) // 3, dim=0) @@ -581,4 +581,3 @@ def __call__( return 
(image,) return ImagePipelineOutput(images=image) - diff --git a/src/diffusers/pipelines/omnigen/processor_omnigen.py b/src/diffusers/pipelines/omnigen/processor_omnigen.py index 545c9b001c7d..af52b20b55db 100644 --- a/src/diffusers/pipelines/omnigen/processor_omnigen.py +++ b/src/diffusers/pipelines/omnigen/processor_omnigen.py @@ -1,3 +1,17 @@ +# Copyright 2024 OmniGen team and The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + import re from typing import Dict, List @@ -7,7 +21,6 @@ from torchvision import transforms - def crop_image(pil_image, max_image_size): """ Crop the image so that its height and width does not exceed `max_image_size`, @@ -291,4 +304,3 @@ def __call__(self, features): "input_image_sizes": all_image_sizes, } return data - From 4fef9c8e2f841454739e66764ce6cf2bd041b19d Mon Sep 17 00:00:00 2001 From: shitao <2906698981@qq.com> Date: Sun, 8 Dec 2024 16:28:12 +0800 Subject: [PATCH 13/55] reformat --- scripts/convert_omnigen_to_diffusers.py | 214 +++++++++--------- src/diffusers/models/embeddings.py | 4 +- .../pipelines/omnigen/pipeline_omnigen.py | 6 +- .../scheduling_flow_match_euler_discrete.py | 3 +- test.py | 59 ----- 5 files changed, 113 insertions(+), 173 deletions(-) delete mode 100644 test.py diff --git a/scripts/convert_omnigen_to_diffusers.py b/scripts/convert_omnigen_to_diffusers.py index 8c9bc8bdb457..4e8b284d2fb7 100644 --- a/scripts/convert_omnigen_to_diffusers.py +++ b/scripts/convert_omnigen_to_diffusers.py @@ -1,11 +1,10 @@ import argparse import os -os.environ['HF_HUB_CACHE'] = '/share/shitao/downloaded_models2' import torch -from safetensors.torch import load_file -from transformers import AutoModel, AutoTokenizer, AutoConfig from huggingface_hub import snapshot_download +from safetensors.torch import load_file +from transformers import AutoTokenizer from diffusers import AutoencoderKL, FlowMatchEulerDiscreteScheduler, OmniGenTransformer2DModel, OmniGenPipeline @@ -17,9 +16,9 @@ def main(args): print("Model not found, downloading...") cache_folder = os.getenv('HF_HUB_CACHE') args.origin_ckpt_path = snapshot_download(repo_id=args.origin_ckpt_path, - cache_dir=cache_folder, - ignore_patterns=['flax_model.msgpack', 'rust_model.ot', 'tf_model.h5', - 'model.pt']) + cache_dir=cache_folder, + ignore_patterns=['flax_model.msgpack', 'rust_model.ot', 'tf_model.h5', + 'model.pt']) print(f"Downloaded model to {args.origin_ckpt_path}") ckpt = os.path.join(args.origin_ckpt_path, 'model.safetensors') @@ -48,7 +47,7 @@ def main(args): # transformer_config = AutoConfig.from_pretrained(args.origin_ckpt_path) # print(type(transformer_config.__dict__)) # print(transformer_config.__dict__) - + transformer_config = { "_name_or_path": "Phi-3-vision-128k-instruct", "architectures": [ @@ -70,104 +69,104 @@ def main(args): "rms_norm_eps": 1e-05, "rope_scaling": { "long_factor": [ - 1.0299999713897705, - 1.0499999523162842, - 1.0499999523162842, - 1.0799999237060547, - 1.2299998998641968, - 1.2299998998641968, - 1.2999999523162842, - 
1.4499999284744263, - 1.5999999046325684, - 1.6499998569488525, - 1.8999998569488525, - 2.859999895095825, - 3.68999981880188, - 5.419999599456787, - 5.489999771118164, - 5.489999771118164, - 9.09000015258789, - 11.579999923706055, - 15.65999984741211, - 15.769999504089355, - 15.789999961853027, - 18.360000610351562, - 21.989999771118164, - 23.079999923706055, - 30.009998321533203, - 32.35000228881836, - 32.590003967285156, - 35.56000518798828, - 39.95000457763672, - 53.840003967285156, - 56.20000457763672, - 57.95000457763672, - 59.29000473022461, - 59.77000427246094, - 59.920005798339844, - 61.190006256103516, - 61.96000671386719, - 62.50000762939453, - 63.3700065612793, - 63.48000717163086, - 63.48000717163086, - 63.66000747680664, - 63.850006103515625, - 64.08000946044922, - 64.760009765625, - 64.80001068115234, - 64.81001281738281, - 64.81001281738281 + 1.0299999713897705, + 1.0499999523162842, + 1.0499999523162842, + 1.0799999237060547, + 1.2299998998641968, + 1.2299998998641968, + 1.2999999523162842, + 1.4499999284744263, + 1.5999999046325684, + 1.6499998569488525, + 1.8999998569488525, + 2.859999895095825, + 3.68999981880188, + 5.419999599456787, + 5.489999771118164, + 5.489999771118164, + 9.09000015258789, + 11.579999923706055, + 15.65999984741211, + 15.769999504089355, + 15.789999961853027, + 18.360000610351562, + 21.989999771118164, + 23.079999923706055, + 30.009998321533203, + 32.35000228881836, + 32.590003967285156, + 35.56000518798828, + 39.95000457763672, + 53.840003967285156, + 56.20000457763672, + 57.95000457763672, + 59.29000473022461, + 59.77000427246094, + 59.920005798339844, + 61.190006256103516, + 61.96000671386719, + 62.50000762939453, + 63.3700065612793, + 63.48000717163086, + 63.48000717163086, + 63.66000747680664, + 63.850006103515625, + 64.08000946044922, + 64.760009765625, + 64.80001068115234, + 64.81001281738281, + 64.81001281738281 ], "short_factor": [ - 1.05, - 1.05, - 1.05, - 1.1, - 1.1, - 1.1, - 1.2500000000000002, - 1.2500000000000002, - 1.4000000000000004, - 1.4500000000000004, - 1.5500000000000005, - 1.8500000000000008, - 1.9000000000000008, - 2.000000000000001, - 2.000000000000001, - 2.000000000000001, - 2.000000000000001, - 2.000000000000001, - 2.000000000000001, - 2.000000000000001, - 2.000000000000001, - 2.000000000000001, - 2.000000000000001, - 2.000000000000001, - 2.000000000000001, - 2.000000000000001, - 2.000000000000001, - 2.000000000000001, - 2.000000000000001, - 2.000000000000001, - 2.000000000000001, - 2.000000000000001, - 2.1000000000000005, - 2.1000000000000005, - 2.2, - 2.3499999999999996, - 2.3499999999999996, - 2.3499999999999996, - 2.3499999999999996, - 2.3999999999999995, - 2.3999999999999995, - 2.6499999999999986, - 2.6999999999999984, - 2.8999999999999977, - 2.9499999999999975, - 3.049999999999997, - 3.049999999999997, - 3.049999999999997 + 1.05, + 1.05, + 1.05, + 1.1, + 1.1, + 1.1, + 1.2500000000000002, + 1.2500000000000002, + 1.4000000000000004, + 1.4500000000000004, + 1.5500000000000005, + 1.8500000000000008, + 1.9000000000000008, + 2.000000000000001, + 2.000000000000001, + 2.000000000000001, + 2.000000000000001, + 2.000000000000001, + 2.000000000000001, + 2.000000000000001, + 2.000000000000001, + 2.000000000000001, + 2.000000000000001, + 2.000000000000001, + 2.000000000000001, + 2.000000000000001, + 2.000000000000001, + 2.000000000000001, + 2.000000000000001, + 2.000000000000001, + 2.000000000000001, + 2.000000000000001, + 2.1000000000000005, + 2.1000000000000005, + 2.2, + 2.3499999999999996, + 2.3499999999999996, + 
2.3499999999999996, + 2.3499999999999996, + 2.3999999999999995, + 2.3999999999999995, + 2.6499999999999986, + 2.6999999999999984, + 2.8999999999999977, + 2.9499999999999975, + 3.049999999999997, + 3.049999999999997, + 3.049999999999997 ], "type": "su" }, @@ -179,7 +178,7 @@ def main(args): "use_cache": True, "vocab_size": 32064, "_attn_implementation": "sdpa" - } + } transformer = OmniGenTransformer2DModel( transformer_config=transformer_config, patch_size=2, @@ -198,7 +197,6 @@ def main(args): tokenizer = AutoTokenizer.from_pretrained(args.origin_ckpt_path) - pipeline = OmniGenPipeline( tokenizer=tokenizer, transformer=transformer, vae=vae, scheduler=scheduler ) @@ -209,10 +207,12 @@ def main(args): parser = argparse.ArgumentParser() parser.add_argument( - "--origin_ckpt_path", default="Shitao/OmniGen-v1", type=str, required=False, help="Path to the checkpoint to convert." + "--origin_ckpt_path", default="Shitao/OmniGen-v1", type=str, required=False, + help="Path to the checkpoint to convert." ) - parser.add_argument("--dump_path", default="/share/shitao/repos/OmniGen-v1-diffusers", type=str, required=False, help="Path to the output pipeline.") + parser.add_argument("--dump_path", default="OmniGen-v1-diffusers", type=str, required=False, + help="Path to the output pipeline.") args = parser.parse_args() main(args) diff --git a/src/diffusers/models/embeddings.py b/src/diffusers/models/embeddings.py index 720a48f3f747..412af10b834c 100644 --- a/src/diffusers/models/embeddings.py +++ b/src/diffusers/models/embeddings.py @@ -358,8 +358,8 @@ def forward(self, ): """ Args: - latent: - is_input_image: + latent: encoded image latents + is_input_image: use input_image_proj or output_image_proj padding_latent: When sizes of target images are inconsistent, use `padding_latent` to maintain consistent sequence length. Returns: torch.Tensor diff --git a/src/diffusers/pipelines/omnigen/pipeline_omnigen.py b/src/diffusers/pipelines/omnigen/pipeline_omnigen.py index 603ef39476e8..44fed3490843 100644 --- a/src/diffusers/pipelines/omnigen/pipeline_omnigen.py +++ b/src/diffusers/pipelines/omnigen/pipeline_omnigen.py @@ -49,13 +49,13 @@ >>> import torch >>> from diffusers import OmniGenPipeline - >>> pipe = OmniGenPipeline.from_pretrained("****", torch_dtype=torch.bfloat16) + >>> pipe = OmniGenPipeline.from_pretrained("Shitao/OmniGen-v1-diffusers", torch_dtype=torch.bfloat16) >>> pipe.to("cuda") >>> prompt = "A cat holding a sign that says hello world" >>> # Depending on the variant being used, the pipeline call will slightly vary. >>> # Refer to the pipeline documentation for more details. 
- >>> image = pipe(prompt, num_inference_steps=4, guidance_scale=0.0).images[0] - >>> image.save("flux.png") + >>> image = pipe(prompt, num_inference_steps=50, guidance_scale=2.5).images[0] + >>> image.save("t2i.png") ``` """ diff --git a/src/diffusers/schedulers/scheduling_flow_match_euler_discrete.py b/src/diffusers/schedulers/scheduling_flow_match_euler_discrete.py index 7e01160b8626..56fe71929d13 100644 --- a/src/diffusers/schedulers/scheduling_flow_match_euler_discrete.py +++ b/src/diffusers/schedulers/scheduling_flow_match_euler_discrete.py @@ -212,7 +212,7 @@ def set_timesteps( else: sigmas = torch.cat([sigmas, torch.zeros(1, device=sigmas.device)]) - self.timesteps = timesteps.to(device=device) + self.timesteps = timesteps.to(device=device) self.sigmas = sigmas self._step_index = None self._begin_index = None @@ -300,7 +300,6 @@ def step( sigma = self.sigmas[self.step_index] sigma_next = self.sigmas[self.step_index + 1] - prev_sample = sample + (sigma_next - sigma) * model_output # Cast sample back to model compatible dtype diff --git a/test.py b/test.py deleted file mode 100644 index b27d99a1066b..000000000000 --- a/test.py +++ /dev/null @@ -1,59 +0,0 @@ -import os -os.environ['HF_HUB_CACHE'] = '/share/shitao/downloaded_models2' - -# from huggingface_hub import snapshot_download - -# from diffusers.models import OmniGenTransformer2DModel -# from transformers import Phi3Model, Phi3Config - - -# from safetensors.torch import load_file - -# model_name = "Shitao/OmniGen-v1" -# config = Phi3Config.from_pretrained("Shitao/OmniGen-v1") -# model = OmniGenTransformer2DModel(transformer_config=config) -# cache_folder = os.getenv('HF_HUB_CACHE') -# model_name = snapshot_download(repo_id=model_name, -# cache_dir=cache_folder, -# ignore_patterns=['flax_model.msgpack', 'rust_model.ot', 'tf_model.h5']) -# print(model_name) -# model_path = os.path.join(model_name, 'model.safetensors') -# ckpt = load_file(model_path, 'cpu') - - -# mapping_dict = { -# "pos_embed": "patch_embedding.pos_embed", -# "x_embedder.proj.weight": "patch_embedding.output_image_proj.weight", -# "x_embedder.proj.bias": "patch_embedding.output_image_proj.bias", -# "input_x_embedder.proj.weight": "patch_embedding.input_image_proj.weight", -# "input_x_embedder.proj.bias": "patch_embedding.input_image_proj.bias", -# "final_layer.adaLN_modulation.1.weight": "norm_out.linear.weight", -# "final_layer.adaLN_modulation.1.bias": "norm_out.linear.bias", -# "final_layer.linear.weight": "proj_out.weight", -# "final_layer.linear.bias": "proj_out.bias", - -# } - -# new_ckpt = {} -# for k, v in ckpt.items(): -# # new_ckpt[k] = v -# if k in mapping_dict: -# new_ckpt[mapping_dict[k]] = v -# else: -# new_ckpt[k] = v - - - -# model.load_state_dict(new_ckpt) - - -from tests.pipelines.omnigen.test_pipeline_omnigen import OmniGenPipelineFastTests, OmniGenPipelineSlowTests - -test1 = OmniGenPipelineFastTests() -test1.test_inference() - -test2 = OmniGenPipelineSlowTests() -test2.test_omnigen_inference() - - - From f2fc182c89dcefc21fe8e65ea731112203867313 Mon Sep 17 00:00:00 2001 From: shitao <2906698981@qq.com> Date: Sun, 8 Dec 2024 16:29:38 +0800 Subject: [PATCH 14/55] reformat --- src/diffusers/schedulers/scheduling_flow_match_euler_discrete.py | 1 + 1 file changed, 1 insertion(+) diff --git a/src/diffusers/schedulers/scheduling_flow_match_euler_discrete.py b/src/diffusers/schedulers/scheduling_flow_match_euler_discrete.py index 56fe71929d13..d4a970720f8e 100644 --- a/src/diffusers/schedulers/scheduling_flow_match_euler_discrete.py +++ 
b/src/diffusers/schedulers/scheduling_flow_match_euler_discrete.py @@ -300,6 +300,7 @@ def step( sigma = self.sigmas[self.step_index] sigma_next = self.sigmas[self.step_index + 1] + prev_sample = sample + (sigma_next - sigma) * model_output # Cast sample back to model compatible dtype From 5f3148d801e6d46ce05c42a965fd78ae1c60f350 Mon Sep 17 00:00:00 2001 From: shitao <2906698981@qq.com> Date: Sun, 8 Dec 2024 16:31:52 +0800 Subject: [PATCH 15/55] reformat --- src/diffusers/pipelines/lumina/pipeline_lumina.py | 2 +- .../schedulers/scheduling_flow_match_euler_discrete.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/src/diffusers/pipelines/lumina/pipeline_lumina.py b/src/diffusers/pipelines/lumina/pipeline_lumina.py index 1db3d58808d2..018f2e8bf1bc 100644 --- a/src/diffusers/pipelines/lumina/pipeline_lumina.py +++ b/src/diffusers/pipelines/lumina/pipeline_lumina.py @@ -777,7 +777,7 @@ def __call__( # 4. Prepare timesteps timesteps, num_inference_steps = retrieve_timesteps( - self.scheduler, num_inference_steps, device, timesteps, + self.scheduler, num_inference_steps, device, timesteps, sigmas ) # 5. Prepare latents. diff --git a/src/diffusers/schedulers/scheduling_flow_match_euler_discrete.py b/src/diffusers/schedulers/scheduling_flow_match_euler_discrete.py index d4a970720f8e..c1096dbe0c29 100644 --- a/src/diffusers/schedulers/scheduling_flow_match_euler_discrete.py +++ b/src/diffusers/schedulers/scheduling_flow_match_euler_discrete.py @@ -300,7 +300,7 @@ def step( sigma = self.sigmas[self.step_index] sigma_next = self.sigmas[self.step_index + 1] - + prev_sample = sample + (sigma_next - sigma) * model_output # Cast sample back to model compatible dtype From 178d377622654de02c55d11f6cda73949ca06f8d Mon Sep 17 00:00:00 2001 From: shitao <2906698981@qq.com> Date: Sun, 8 Dec 2024 17:09:31 +0800 Subject: [PATCH 16/55] update docs --- docs/source/en/using-diffusers/multimodal2img.md | 4 +++- docs/source/en/using-diffusers/omnigen.md | 1 + 2 files changed, 4 insertions(+), 1 deletion(-) diff --git a/docs/source/en/using-diffusers/multimodal2img.md b/docs/source/en/using-diffusers/multimodal2img.md index 10567d49ae19..1aabb99d5879 100644 --- a/docs/source/en/using-diffusers/multimodal2img.md +++ b/docs/source/en/using-diffusers/multimodal2img.md @@ -16,7 +16,9 @@ specific language governing permissions and limitations under the License. -Multimodal instructions mean you can input any sequence of mixed text and images to guide image generation. You can input multiple images and use prompts to describe the desired output. This approach is more flexible than using only text or images. +Multimodal instructions mean you can input arbitrarily interleaved text and image inputs as conditions to guide image generation. +You can input multiple images and use prompts to describe the desired output. +This approach is more flexible than using only text or images. ## Examples diff --git a/docs/source/en/using-diffusers/omnigen.md b/docs/source/en/using-diffusers/omnigen.md index c535f4358155..833199641644 100644 --- a/docs/source/en/using-diffusers/omnigen.md +++ b/docs/source/en/using-diffusers/omnigen.md @@ -274,6 +274,7 @@ image = pipe( guidance_scale=2.5, img_guidance_scale=1.6, generator=torch.Generator(device="cpu").manual_seed(666)).images[0] +image ```
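A minimal usage sketch of the `OmniGenPipeline` added in this series, assembled from the example docstring and the docs patched above. The checkpoint id `Shitao/OmniGen-v1-diffusers` and the guidance settings come from the patched docstrings; the file name `person.png` and the editing prompt are placeholders invented for illustration, not assets shipped with the patches.

```python
import torch

from diffusers import OmniGenPipeline

pipe = OmniGenPipeline.from_pretrained("Shitao/OmniGen-v1-diffusers", torch_dtype=torch.bfloat16)
pipe.to("cuda")

# Text-to-image: a plain prompt, using the defaults from the patched example docstring.
image = pipe(
    "A cat holding a sign that says hello world",
    num_inference_steps=50,
    guidance_scale=2.5,
).images[0]
image.save("t2i.png")

# Multimodal instruction: each entry of `input_images` must be referenced by an
# `<|image_k|>` placeholder in the prompt, as enforced by `check_inputs`.
# "person.png" is a placeholder path for a local input image.
edited = pipe(
    prompt="Put a party hat on the person in <|image_1|>.",
    input_images=["person.png"],
    guidance_scale=2.5,
    img_guidance_scale=1.6,
    generator=torch.Generator(device="cpu").manual_seed(666),
).images[0]
edited.save("edited.png")
```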
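The `use_kv_cache` path spread across `kvcache_omnigen.py` and `pipeline_omnigen.py` rests on two slicing rules: `OmniGenCache.update` stores only the condition tokens (the output-image tokens and the timestep token are recomputed at every denoising step), and from the second step on the pipeline keeps only the last `num_tokens_for_output_img + 1` rows of the attention mask and position ids. A toy sketch of that bookkeeping, with made-up sizes that do not come from the model:

```python
import torch

batch, heads, head_dim = 1, 2, 8          # made-up sizes for illustration
num_condition_tokens = 5                  # prompt + input-image tokens (these get cached)
num_tokens_for_output_img = 4             # latent tokens of the image being denoised
seq_len = num_condition_tokens + num_tokens_for_output_img + 1  # +1 timestep token

key_states = torch.randn(batch, heads, seq_len, head_dim)
value_states = torch.randn(batch, heads, seq_len, head_dim)

# OmniGenCache.update: only the condition tokens are kept, mirroring
# `key_states[..., :-(self.num_tokens_for_img + 1), :]` in the patch.
cached_keys = key_states[..., : -(num_tokens_for_output_img + 1), :]
cached_values = value_states[..., : -(num_tokens_for_output_img + 1), :]
assert cached_keys.shape[-2] == num_condition_tokens

# OmniGenPipeline.__call__, `use_kv_cache` branch: after the first step the condition
# part of the mask/position ids is dropped, since those tokens are served from the
# cache ("+1 is for the timestep token").
attention_mask = torch.ones(batch, seq_len, seq_len)
position_ids = torch.arange(seq_len).unsqueeze(0)
attention_mask = attention_mask[..., -(num_tokens_for_output_img + 1):, :]
position_ids = position_ids[:, -(num_tokens_for_output_img + 1):]

print(cached_keys.shape)      # torch.Size([1, 2, 5, 8])
print(attention_mask.shape)   # torch.Size([1, 5, 10])
print(position_ids.shape)     # torch.Size([1, 5])
```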
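On the scheduler side, the pipeline builds its sigma schedule as `np.linspace(1, 0, num_inference_steps + 1)[:num_inference_steps]`, dropping the trailing zero because the `FlowMatchEulerDiscreteScheduler.set_timesteps` hunk above already appends a final zero sigma. A quick check of what that produces, assuming a 4-step run chosen only to keep the output short:

```python
import numpy as np

num_inference_steps = 4  # assumed small value for illustration
sigmas = np.linspace(1, 0, num_inference_steps + 1)[:num_inference_steps]
print(sigmas)  # [1.   0.75 0.5  0.25] -- the scheduler appends the final 0.0 itself
```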
From 08c05f9f46590a45670d2d240bae866ab488603e Mon Sep 17 00:00:00 2001 From: shitao <2906698981@qq.com> Date: Sun, 8 Dec 2024 17:10:52 +0800 Subject: [PATCH 17/55] update docs --- docs/source/en/api/pipelines/omnigen.md | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/docs/source/en/api/pipelines/omnigen.md b/docs/source/en/api/pipelines/omnigen.md index b7ea5488156a..52bcf8d2c3b2 100644 --- a/docs/source/en/api/pipelines/omnigen.md +++ b/docs/source/en/api/pipelines/omnigen.md @@ -36,7 +36,9 @@ extra intermediate steps, greatly simplifying the image generation workflow. 3) learning in a unified format, OmniGen effectively transfers knowledge across different tasks, manages unseen tasks and domains, and exhibits novel capabilities. We also explore -the model’s reasoning capabilities and potential applications of the chain-of-thought mechanism. This work represents the first attempt at a general-purpose image generation model, and we will release our resources at https: +the model’s reasoning capabilities and potential applications of the chain-of-thought mechanism. +This work represents the first attempt at a general-purpose image generation model, +and we will release our resources at https: //github.com/VectorSpaceLab/OmniGen to foster future advancements.* From 286990d02accdbaed2df7ebb434c276a46834c7e Mon Sep 17 00:00:00 2001 From: shitao <2906698981@qq.com> Date: Sun, 8 Dec 2024 19:17:15 +0800 Subject: [PATCH 18/55] make style --- .../geodiff_molecule_conformation.ipynb | 7230 +++++++++-------- examples/research_projects/gligen/demo.ipynb | 13 +- scripts/convert_omnigen_to_diffusers.py | 2 +- src/diffusers/models/embeddings.py | 2 +- .../transformers/transformer_omnigen.py | 6 +- .../pipelines/omnigen/kvcache_omnigen.py | 3 +- .../pipelines/omnigen/pipeline_omnigen.py | 5 +- .../pipelines/omnigen/processor_omnigen.py | 6 +- .../omnigen/test_pipeline_omnigen.py | 4 +- 9 files changed, 3639 insertions(+), 3632 deletions(-) diff --git a/examples/research_projects/geodiff/geodiff_molecule_conformation.ipynb b/examples/research_projects/geodiff/geodiff_molecule_conformation.ipynb index bde093802a5d..03f58f1f2f63 100644 --- a/examples/research_projects/geodiff/geodiff_molecule_conformation.ipynb +++ b/examples/research_projects/geodiff/geodiff_molecule_conformation.ipynb @@ -1,3652 +1,3660 @@ { - "cells": [ - { - "cell_type": "markdown", - "metadata": { - "id": "F88mignPnalS" - }, - "source": [ - "# Introduction\n", - "\n", - "This colab is design to run the pretrained models from [GeoDiff](https://github.com/MinkaiXu/GeoDiff).\n", - "The visualization code is inspired by this PyMol [colab](https://colab.research.google.com/gist/iwatobipen/2ec7faeafe5974501e69fcc98c122922/pymol.ipynb#scrollTo=Hm4kY7CaZSlw).\n", - "\n", - "The goal is to generate physically accurate molecules. Given the input of a molecule graph (atom and bond structures with their connectivity -- in the form of a 2d graph). 
What we want to generate is a stable 3d structure of the molecule.\n", - "\n", - "This colab uses GEOM datasets that have multiple 3d targets per configuration, which provide more compelling targets for generative methods.\n", - "\n", - "> Colab made by [natolambert](https://twitter.com/natolambert).\n", - "\n", - "![diffusers_library](https://github.com/huggingface/diffusers/raw/main/docs/source/imgs/diffusers_library.jpg)\n" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "7cnwXMocnuzB" - }, - "source": [ - "## Installations\n", - "\n" - ] - }, - { - "cell_type": "markdown", - "source": [ - "### Install Conda" - ], - "metadata": { - "id": "ff9SxWnaNId9" - } - }, - { - "cell_type": "markdown", - "metadata": { - "id": "1g_6zOabItDk" - }, - "source": [ - "Here we check the `cuda` version of colab. When this was built, the version was always 11.1, which impacts some installation decisions below." - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "colab": { - "base_uri": "https://localhost:8080/" - }, - "id": "K0ofXobG5Y-X", - "outputId": "572c3d25-6f19-4c1e-83f5-a1d084a3207f" - }, - "outputs": [ - { - "output_type": "stream", - "name": "stdout", - "text": [ - "nvcc: NVIDIA (R) Cuda compiler driver\n", - "Copyright (c) 2005-2021 NVIDIA Corporation\n", - "Built on Sun_Feb_14_21:12:58_PST_2021\n", - "Cuda compilation tools, release 11.2, V11.2.152\n", - "Build cuda_11.2.r11.2/compiler.29618528_0\n" - ] - } - ], - "source": [ - "!nvcc --version" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "VfthW90vI0nw" - }, - "source": [ - "Install Conda for some more complex dependencies for geometric networks." - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "id": "2WNFzSnbiE0k", - "colab": { - "base_uri": "https://localhost:8080/" - }, - "outputId": "690d0d4d-9d0a-4ead-c6dc-086f113f532f" - }, - "outputs": [ - { - "output_type": "stream", - "name": "stdout", - "text": [ - "\u001b[33mWARNING: Running pip as the 'root' user can result in broken permissions and conflicting behaviour with the system package manager. It is recommended to use a virtual environment instead: https://pip.pypa.io/warnings/venv\u001b[0m\u001b[33m\n", - "\u001b[0m" - ] - } - ], - "source": [ - "!pip install -q condacolab" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "NUsbWYCUI7Km" - }, - "source": [ - "Setup Conda" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "colab": { - "base_uri": "https://localhost:8080/" - }, - "id": "FZelreINdmd0", - "outputId": "635f0cb8-0af4-499f-e0a4-b3790cb12e9f" - }, - "outputs": [ - { - "output_type": "stream", - "name": "stdout", - "text": [ - "✨🍰✨ Everything looks OK!\n" - ] - } - ], - "source": [ - "import condacolab\n", - "condacolab.install()" - ] + "cells": [ + { + "cell_type": "markdown", + "metadata": { + "id": "F88mignPnalS" + }, + "source": [ + "# Introduction\n", + "\n", + "This colab is design to run the pretrained models from [GeoDiff](https://github.com/MinkaiXu/GeoDiff).\n", + "The visualization code is inspired by this PyMol [colab](https://colab.research.google.com/gist/iwatobipen/2ec7faeafe5974501e69fcc98c122922/pymol.ipynb#scrollTo=Hm4kY7CaZSlw).\n", + "\n", + "The goal is to generate physically accurate molecules. Given the input of a molecule graph (atom and bond structures with their connectivity -- in the form of a 2d graph). 
What we want to generate is a stable 3d structure of the molecule.\n", + "\n", + "This colab uses GEOM datasets that have multiple 3d targets per configuration, which provide more compelling targets for generative methods.\n", + "\n", + "> Colab made by [natolambert](https://twitter.com/natolambert).\n", + "\n", + "![diffusers_library](https://github.com/huggingface/diffusers/raw/main/docs/source/imgs/diffusers_library.jpg)\n" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "7cnwXMocnuzB" + }, + "source": [ + "## Installations\n", + "\n" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "ff9SxWnaNId9" + }, + "source": [ + "### Install Conda" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "1g_6zOabItDk" + }, + "source": [ + "Here we check the `cuda` version of colab. When this was built, the version was always 11.1, which impacts some installation decisions below." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/" + }, + "id": "K0ofXobG5Y-X", + "outputId": "572c3d25-6f19-4c1e-83f5-a1d084a3207f" + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "nvcc: NVIDIA (R) Cuda compiler driver\n", + "Copyright (c) 2005-2021 NVIDIA Corporation\n", + "Built on Sun_Feb_14_21:12:58_PST_2021\n", + "Cuda compilation tools, release 11.2, V11.2.152\n", + "Build cuda_11.2.r11.2/compiler.29618528_0\n" + ] + } + ], + "source": [ + "!nvcc --version" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "VfthW90vI0nw" + }, + "source": [ + "Install Conda for some more complex dependencies for geometric networks." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/" + }, + "id": "2WNFzSnbiE0k", + "outputId": "690d0d4d-9d0a-4ead-c6dc-086f113f532f" + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "\u001b[33mWARNING: Running pip as the 'root' user can result in broken permissions and conflicting behaviour with the system package manager. 
It is recommended to use a virtual environment instead: https://pip.pypa.io/warnings/venv\u001b[0m\u001b[33m\n", + "\u001b[0m" + ] + } + ], + "source": [ + "!pip install -q condacolab" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "NUsbWYCUI7Km" + }, + "source": [ + "Setup Conda" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/" + }, + "id": "FZelreINdmd0", + "outputId": "635f0cb8-0af4-499f-e0a4-b3790cb12e9f" + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "✨🍰✨ Everything looks OK!\n" + ] + } + ], + "source": [ + "import condacolab\n", + "\n", + "\n", + "condacolab.install()" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "JzDHaPU7I9Sn" + }, + "source": [ + "Install pytorch requirements (this takes a few minutes, go grab yourself a coffee 🤗)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/" + }, + "id": "JMxRjHhL7w8V", + "outputId": "6ed511b3-9262-49e8-b340-08e76b05ebd8" + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Collecting package metadata (current_repodata.json): - \b\b\\ \b\b| \b\b/ \b\b- \b\b\\ \b\b| \b\b/ \b\b- \b\b\\ \b\b| \b\b/ \b\b- \b\b\\ \b\b| \b\b/ \b\b- \b\b\\ \b\b| \b\b/ \b\b- \b\b\\ \b\b| \b\b/ \b\b- \b\b\\ \b\b| \b\b/ \b\b- \b\b\\ \b\b| \b\b/ \b\bdone\n", + "Solving environment: \\ \b\b| \b\b/ \b\b- \b\b\\ \b\b| \b\b/ \b\b- \b\b\\ \b\b| \b\b/ \b\b- \b\b\\ \b\b| \b\b/ \b\b- \b\b\\ \b\b| \b\b/ \b\b- \b\b\\ \b\b| \b\b/ \b\b- \b\b\\ \b\b| \b\b/ \b\b- \b\b\\ \b\b| \b\b/ \b\b- \b\b\\ \b\bdone\n", + "\n", + "## Package Plan ##\n", + "\n", + " environment location: /usr/local\n", + "\n", + " added / updated specs:\n", + " - cudatoolkit=11.1\n", + " - pytorch\n", + " - torchaudio\n", + " - torchvision\n", + "\n", + "\n", + "The following packages will be downloaded:\n", + "\n", + " package | build\n", + " ---------------------------|-----------------\n", + " conda-22.9.0 | py37h89c1867_1 960 KB conda-forge\n", + " ------------------------------------------------------------\n", + " Total: 960 KB\n", + "\n", + "The following packages will be UPDATED:\n", + "\n", + " conda 4.14.0-py37h89c1867_0 --> 22.9.0-py37h89c1867_1\n", + "\n", + "\n", + "\n", + "Downloading and Extracting Packages\n", + "conda-22.9.0 | 960 KB | : 100% 1.0/1 [00:00<00:00, 4.15it/s]\n", + "Preparing transaction: / \b\bdone\n", + "Verifying transaction: \\ \b\bdone\n", + "Executing transaction: / \b\bdone\n", + "Retrieving notices: ...working... done\n" + ] + } + ], + "source": [ + "!conda install pytorch torchvision torchaudio cudatoolkit=11.1 -c pytorch-lts -c nvidia\n", + "# !conda install pytorch==1.8.0 torchvision==0.9.0 torchaudio==0.8.0 cudatoolkit=11.1 -c pytorch -c conda-forge" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "QDS6FPZ0Tu5b" + }, + "source": [ + "Need to remove a pathspec for colab that specifies the incorrect cuda version." 
+ ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/" + }, + "id": "dq1lxR10TtrR", + "outputId": "ed9c5a71-b449-418f-abb7-072b74e7f6c8" + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "rm: cannot remove '/usr/local/conda-meta/pinned': No such file or directory\n" + ] + } + ], + "source": [ + "!rm /usr/local/conda-meta/pinned" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "Z1L3DdZOJB30" + }, + "source": [ + "Install torch geometric (used in the model later)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/" + }, + "id": "D5ukfCOWfjzK", + "outputId": "8437485a-5aa6-4d53-8f7f-23517ac1ace6" + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Collecting package metadata (current_repodata.json): - \b\b\\ \b\b| \b\b/ \b\b- \b\b\\ \b\b| \b\b/ \b\b- \b\b\\ \b\b| \b\b/ \b\b- \b\b\\ \b\b| \b\b/ \b\b- \b\b\\ \b\b| \b\b/ \b\b- \b\b\\ \b\b| \b\b/ \b\b- \b\b\\ \b\b| \b\b/ \b\b- \b\bdone\n", + "Solving environment: | \b\b/ \b\b- \b\b\\ \b\b| \b\b/ \b\b- \b\b\\ \b\b| \b\b/ \b\b- \b\b\\ \b\b| \b\b/ \b\b- \b\b\\ \b\b| \b\b/ \b\b- \b\b\\ \b\b| \b\b/ \b\b- \b\b\\ \b\b| \b\b/ \b\b- \b\b\\ \b\b| \b\b/ \b\b- \b\b\\ \b\b| \b\b/ \b\b- \b\b\\ \b\b| \b\b/ \b\b- \b\b\\ \b\b| \b\b/ \b\b- \b\b\\ \b\b| \b\b/ \b\b- \b\b\\ \b\b| \b\b/ \b\b- \b\b\\ \b\b| \b\b/ \b\b- \b\b\\ \b\b| \b\b/ \b\b- \b\b\\ \b\b| \b\b/ \b\b- \b\b\\ \b\b| \b\b/ \b\b- \b\b\\ \b\b| \b\b/ \b\b- \b\b\\ \b\b| \b\b/ \b\b- \b\b\\ \b\b| \b\b/ \b\b- \b\b\\ \b\b| \b\b/ \b\b- \b\b\\ \b\b| \b\b/ \b\b- \b\b\\ \b\b| \b\b/ \b\b- \b\b\\ \b\b| \b\b/ \b\b- \b\b\\ \b\b| \b\b/ \b\b- \b\b\\ \b\b| \b\b/ \b\b- \b\b\\ \b\b| \b\b/ \b\b- \b\b\\ \b\b| \b\b/ \b\b- \b\b\\ \b\b| \b\b/ \b\b- \b\b\\ \b\b| \b\b/ \b\b- \b\b\\ \b\b| \b\b/ \b\b- \b\b\\ \b\b| \b\b/ \b\b- \b\b\\ \b\b| \b\b/ \b\b- \b\b\\ \b\bdone\n", + "\n", + "## Package Plan ##\n", + "\n", + " environment location: /usr/local\n", + "\n", + " added / updated specs:\n", + " - pytorch-geometric=1.7.2\n", + "\n", + "\n", + "The following packages will be downloaded:\n", + "\n", + " package | build\n", + " ---------------------------|-----------------\n", + " decorator-4.4.2 | py_0 11 KB conda-forge\n", + " googledrivedownloader-0.4 | pyhd3deb0d_1 7 KB conda-forge\n", + " jinja2-3.1.2 | pyhd8ed1ab_1 99 KB conda-forge\n", + " joblib-1.2.0 | pyhd8ed1ab_0 205 KB conda-forge\n", + " markupsafe-2.1.1 | py37h540881e_1 22 KB conda-forge\n", + " networkx-2.5.1 | pyhd8ed1ab_0 1.2 MB conda-forge\n", + " pandas-1.2.3 | py37hdc94413_0 11.8 MB conda-forge\n", + " pyparsing-3.0.9 | pyhd8ed1ab_0 79 KB conda-forge\n", + " python-dateutil-2.8.2 | pyhd8ed1ab_0 240 KB conda-forge\n", + " python-louvain-0.15 | pyhd8ed1ab_1 13 KB conda-forge\n", + " pytorch-cluster-1.5.9 |py37_torch_1.8.0_cu111 1.2 MB rusty1s\n", + " pytorch-geometric-1.7.2 |py37_torch_1.8.0_cu111 445 KB rusty1s\n", + " pytorch-scatter-2.0.8 |py37_torch_1.8.0_cu111 6.1 MB rusty1s\n", + " pytorch-sparse-0.6.12 |py37_torch_1.8.0_cu111 2.9 MB rusty1s\n", + " pytorch-spline-conv-1.2.1 |py37_torch_1.8.0_cu111 736 KB rusty1s\n", + " pytz-2022.4 | pyhd8ed1ab_0 232 KB conda-forge\n", + " scikit-learn-1.0.2 | py37hf9e9bfc_0 7.8 MB conda-forge\n", + " scipy-1.7.3 | py37hf2a6cf1_0 21.8 MB conda-forge\n", + " setuptools-59.8.0 | py37h89c1867_1 1.0 MB conda-forge\n", + " threadpoolctl-3.1.0 | pyh8a188c0_0 18 KB conda-forge\n", + " 
------------------------------------------------------------\n", + " Total: 55.9 MB\n", + "\n", + "The following NEW packages will be INSTALLED:\n", + "\n", + " decorator conda-forge/noarch::decorator-4.4.2-py_0 None\n", + " googledrivedownlo~ conda-forge/noarch::googledrivedownloader-0.4-pyhd3deb0d_1 None\n", + " jinja2 conda-forge/noarch::jinja2-3.1.2-pyhd8ed1ab_1 None\n", + " joblib conda-forge/noarch::joblib-1.2.0-pyhd8ed1ab_0 None\n", + " markupsafe conda-forge/linux-64::markupsafe-2.1.1-py37h540881e_1 None\n", + " networkx conda-forge/noarch::networkx-2.5.1-pyhd8ed1ab_0 None\n", + " pandas conda-forge/linux-64::pandas-1.2.3-py37hdc94413_0 None\n", + " pyparsing conda-forge/noarch::pyparsing-3.0.9-pyhd8ed1ab_0 None\n", + " python-dateutil conda-forge/noarch::python-dateutil-2.8.2-pyhd8ed1ab_0 None\n", + " python-louvain conda-forge/noarch::python-louvain-0.15-pyhd8ed1ab_1 None\n", + " pytorch-cluster rusty1s/linux-64::pytorch-cluster-1.5.9-py37_torch_1.8.0_cu111 None\n", + " pytorch-geometric rusty1s/linux-64::pytorch-geometric-1.7.2-py37_torch_1.8.0_cu111 None\n", + " pytorch-scatter rusty1s/linux-64::pytorch-scatter-2.0.8-py37_torch_1.8.0_cu111 None\n", + " pytorch-sparse rusty1s/linux-64::pytorch-sparse-0.6.12-py37_torch_1.8.0_cu111 None\n", + " pytorch-spline-co~ rusty1s/linux-64::pytorch-spline-conv-1.2.1-py37_torch_1.8.0_cu111 None\n", + " pytz conda-forge/noarch::pytz-2022.4-pyhd8ed1ab_0 None\n", + " scikit-learn conda-forge/linux-64::scikit-learn-1.0.2-py37hf9e9bfc_0 None\n", + " scipy conda-forge/linux-64::scipy-1.7.3-py37hf2a6cf1_0 None\n", + " threadpoolctl conda-forge/noarch::threadpoolctl-3.1.0-pyh8a188c0_0 None\n", + "\n", + "The following packages will be DOWNGRADED:\n", + "\n", + " setuptools 65.3.0-py37h89c1867_0 --> 59.8.0-py37h89c1867_1 None\n", + "\n", + "\n", + "\n", + "Downloading and Extracting Packages\n", + "scikit-learn-1.0.2 | 7.8 MB | : 100% 1.0/1 [00:01<00:00, 1.37s/it] \n", + "pytorch-scatter-2.0. | 6.1 MB | : 100% 1.0/1 [00:06<00:00, 6.18s/it]\n", + "pytorch-geometric-1. | 445 KB | : 100% 1.0/1 [00:02<00:00, 2.53s/it]\n", + "scipy-1.7.3 | 21.8 MB | : 100% 1.0/1 [00:03<00:00, 3.06s/it]\n", + "python-dateutil-2.8. | 240 KB | : 100% 1.0/1 [00:00<00:00, 21.48it/s]\n", + "pytorch-spline-conv- | 736 KB | : 100% 1.0/1 [00:01<00:00, 1.00s/it]\n", + "pytorch-sparse-0.6.1 | 2.9 MB | : 100% 1.0/1 [00:07<00:00, 7.51s/it]\n", + "pyparsing-3.0.9 | 79 KB | : 100% 1.0/1 [00:00<00:00, 26.32it/s]\n", + "pytorch-cluster-1.5. 
| 1.2 MB | : 100% 1.0/1 [00:02<00:00, 2.78s/it]\n", + "jinja2-3.1.2 | 99 KB | : 100% 1.0/1 [00:00<00:00, 20.28it/s]\n", + "decorator-4.4.2 | 11 KB | : 100% 1.0/1 [00:00<00:00, 21.57it/s]\n", + "joblib-1.2.0 | 205 KB | : 100% 1.0/1 [00:00<00:00, 15.04it/s]\n", + "pytz-2022.4 | 232 KB | : 100% 1.0/1 [00:00<00:00, 10.21it/s]\n", + "python-louvain-0.15 | 13 KB | : 100% 1.0/1 [00:00<00:00, 3.34it/s]\n", + "googledrivedownloade | 7 KB | : 100% 1.0/1 [00:00<00:00, 3.33it/s]\n", + "threadpoolctl-3.1.0 | 18 KB | : 100% 1.0/1 [00:00<00:00, 29.40it/s]\n", + "markupsafe-2.1.1 | 22 KB | : 100% 1.0/1 [00:00<00:00, 28.62it/s]\n", + "pandas-1.2.3 | 11.8 MB | : 100% 1.0/1 [00:02<00:00, 2.08s/it] \n", + "networkx-2.5.1 | 1.2 MB | : 100% 1.0/1 [00:01<00:00, 1.39s/it]\n", + "setuptools-59.8.0 | 1.0 MB | : 100% 1.0/1 [00:00<00:00, 4.25it/s]\n", + "Preparing transaction: / \b\b- \b\b\\ \b\bdone\n", + "Verifying transaction: / \b\b- \b\b\\ \b\b| \b\b/ \b\b- \b\b\\ \b\b| \b\b/ \b\b- \b\b\\ \b\b| \b\b/ \b\b- \b\b\\ \b\bdone\n", + "Executing transaction: / \b\b- \b\b\\ \b\b| \b\b/ \b\b- \b\b\\ \b\b| \b\b/ \b\b- \b\b\\ \b\b| \b\b/ \b\b- \b\b\\ \b\b| \b\b/ \b\b- \b\b\\ \b\b| \b\b/ \b\b- \b\bdone\n", + "Retrieving notices: ...working... done\n" + ] + } + ], + "source": [ + "!conda install -c rusty1s pytorch-geometric=1.7.2" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "ppxv6Mdkalbc" + }, + "source": [ + "### Install Diffusers" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/" + }, + "id": "mgQA_XN-XGY2", + "outputId": "85392615-b6a4-4052-9d2a-79604be62c94" + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "/content\n", + "Cloning into 'diffusers'...\n", + "remote: Enumerating objects: 9298, done.\u001b[K\n", + "remote: Counting objects: 100% (40/40), done.\u001b[K\n", + "remote: Compressing objects: 100% (23/23), done.\u001b[K\n", + "remote: Total 9298 (delta 17), reused 23 (delta 11), pack-reused 9258\u001b[K\n", + "Receiving objects: 100% (9298/9298), 7.38 MiB | 5.28 MiB/s, done.\n", + "Resolving deltas: 100% (6168/6168), done.\n", + " Installing build dependencies ... \u001b[?25l\u001b[?25hdone\n", + " Getting requirements to build wheel ... \u001b[?25l\u001b[?25hdone\n", + " Preparing metadata (pyproject.toml) ... \u001b[?25l\u001b[?25hdone\n", + "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m757.0/757.0 kB\u001b[0m \u001b[31m52.8 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", + "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m163.5/163.5 kB\u001b[0m \u001b[31m21.9 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", + "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m40.8/40.8 kB\u001b[0m \u001b[31m5.5 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", + "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m596.3/596.3 kB\u001b[0m \u001b[31m51.7 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", + "\u001b[?25h Building wheel for diffusers (pyproject.toml) ... \u001b[?25l\u001b[?25hdone\n", + "\u001b[33mWARNING: Running pip as the 'root' user can result in broken permissions and conflicting behaviour with the system package manager. 
It is recommended to use a virtual environment instead: https://pip.pypa.io/warnings/venv\u001b[0m\u001b[33m\n", + "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m432.7/432.7 kB\u001b[0m \u001b[31m36.8 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", + "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m5.3/5.3 MB\u001b[0m \u001b[31m90.9 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", + "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m35.3/35.3 MB\u001b[0m \u001b[31m39.7 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", + "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m115.1/115.1 kB\u001b[0m \u001b[31m16.3 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", + "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m948.0/948.0 kB\u001b[0m \u001b[31m63.6 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", + "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m212.2/212.2 kB\u001b[0m \u001b[31m21.5 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", + "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m95.8/95.8 kB\u001b[0m \u001b[31m12.8 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", + "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m140.8/140.8 kB\u001b[0m \u001b[31m18.8 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", + "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m7.6/7.6 MB\u001b[0m \u001b[31m104.3 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", + "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m148.0/148.0 kB\u001b[0m \u001b[31m20.8 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", + "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m231.3/231.3 kB\u001b[0m \u001b[31m30.0 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", + "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m94.8/94.8 kB\u001b[0m \u001b[31m14.0 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", + "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m58.8/58.8 kB\u001b[0m \u001b[31m8.0 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", + "\u001b[?25h\u001b[33mWARNING: Running pip as the 'root' user can result in broken permissions and conflicting behaviour with the system package manager. 
It is recommended to use a virtual environment instead: https://pip.pypa.io/warnings/venv\u001b[0m\u001b[33m\n", + "\u001b[0m" + ] + } + ], + "source": [ + "%cd /content\n", + "\n", + "# install latest HF diffusers (will update to the release once added)\n", + "!git clone https://github.com/huggingface/diffusers.git\n", + "!pip install -q /content/diffusers\n", + "\n", + "# dependencies for diffusers\n", + "!pip install -q datasets transformers" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "LZO6AJKuJKO8" + }, + "source": [ + "Check that torch is installed correctly and utilizing the GPU in the colab" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/", + "height": 53 }, + "id": "gZt7BNi1e1PA", + "outputId": "a0e1832c-9c02-49aa-cff8-1339e6cdc889" + }, + "outputs": [ { - "cell_type": "markdown", - "metadata": { - "id": "JzDHaPU7I9Sn" - }, - "source": [ - "Install pytorch requirements (this takes a few minutes, go grab yourself a coffee 🤗)" - ] + "name": "stdout", + "output_type": "stream", + "text": [ + "True\n" + ] }, { - "cell_type": "code", - "execution_count": null, - "metadata": { - "colab": { - "base_uri": "https://localhost:8080/" - }, - "id": "JMxRjHhL7w8V", - "outputId": "6ed511b3-9262-49e8-b340-08e76b05ebd8" + "data": { + "application/vnd.google.colaboratory.intrinsic+json": { + "type": "string" }, - "outputs": [ - { - "output_type": "stream", - "name": "stdout", - "text": [ - "Collecting package metadata (current_repodata.json): - \b\b\\ \b\b| \b\b/ \b\b- \b\b\\ \b\b| \b\b/ \b\b- \b\b\\ \b\b| \b\b/ \b\b- \b\b\\ \b\b| \b\b/ \b\b- \b\b\\ \b\b| \b\b/ \b\b- \b\b\\ \b\b| \b\b/ \b\b- \b\b\\ \b\b| \b\b/ \b\b- \b\b\\ \b\b| \b\b/ \b\bdone\n", - "Solving environment: \\ \b\b| \b\b/ \b\b- \b\b\\ \b\b| \b\b/ \b\b- \b\b\\ \b\b| \b\b/ \b\b- \b\b\\ \b\b| \b\b/ \b\b- \b\b\\ \b\b| \b\b/ \b\b- \b\b\\ \b\b| \b\b/ \b\b- \b\b\\ \b\b| \b\b/ \b\b- \b\b\\ \b\b| \b\b/ \b\b- \b\b\\ \b\bdone\n", - "\n", - "## Package Plan ##\n", - "\n", - " environment location: /usr/local\n", - "\n", - " added / updated specs:\n", - " - cudatoolkit=11.1\n", - " - pytorch\n", - " - torchaudio\n", - " - torchvision\n", - "\n", - "\n", - "The following packages will be downloaded:\n", - "\n", - " package | build\n", - " ---------------------------|-----------------\n", - " conda-22.9.0 | py37h89c1867_1 960 KB conda-forge\n", - " ------------------------------------------------------------\n", - " Total: 960 KB\n", - "\n", - "The following packages will be UPDATED:\n", - "\n", - " conda 4.14.0-py37h89c1867_0 --> 22.9.0-py37h89c1867_1\n", - "\n", - "\n", - "\n", - "Downloading and Extracting Packages\n", - "conda-22.9.0 | 960 KB | : 100% 1.0/1 [00:00<00:00, 4.15it/s]\n", - "Preparing transaction: / \b\bdone\n", - "Verifying transaction: \\ \b\bdone\n", - "Executing transaction: / \b\bdone\n", - "Retrieving notices: ...working... done\n" - ] - } - ], - "source": [ - "!conda install pytorch torchvision torchaudio cudatoolkit=11.1 -c pytorch-lts -c nvidia\n", - "# !conda install pytorch==1.8.0 torchvision==0.9.0 torchaudio==0.8.0 cudatoolkit=11.1 -c pytorch -c conda-forge" + "text/plain": [ + "'1.8.2'" ] - }, - { - "cell_type": "markdown", - "source": [ - "Need to remove a pathspec for colab that specifies the incorrect cuda version." 
- ], - "metadata": { - "id": "QDS6FPZ0Tu5b" + }, + "execution_count": 8, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "import torch\n", + "\n", + "\n", + "print(torch.cuda.is_available())\n", + "torch.__version__" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "KLE7CqlfJNUO" + }, + "source": [ + "### Install Chemistry-specific Dependencies\n", + "\n", + "Install RDKit, a tool for working with and visualizing chemsitry in python (you use this to visualize the generate models later)." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/" + }, + "id": "0CPv_NvehRz3", + "outputId": "6ee0ae4e-4511-4816-de29-22b1c21d49bc" + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/\n", + "Collecting rdkit\n", + " Downloading rdkit-2022.3.5-cp37-cp37m-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (36.8 MB)\n", + "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m36.8/36.8 MB\u001b[0m \u001b[31m34.6 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", + "\u001b[?25hRequirement already satisfied: Pillow in /usr/local/lib/python3.7/site-packages (from rdkit) (9.2.0)\n", + "Requirement already satisfied: numpy in /usr/local/lib/python3.7/site-packages (from rdkit) (1.21.6)\n", + "Installing collected packages: rdkit\n", + "Successfully installed rdkit-2022.3.5\n", + "\u001b[33mWARNING: Running pip as the 'root' user can result in broken permissions and conflicting behaviour with the system package manager. It is recommended to use a virtual environment instead: https://pip.pypa.io/warnings/venv\u001b[0m\u001b[33m\n", + "\u001b[0m" + ] + } + ], + "source": [ + "!pip install rdkit" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "88GaDbDPxJ5I" + }, + "source": [ + "### Get viewer from nglview\n", + "\n", + "The model you will use outputs a position matrix tensor. This pytorch geometric data object will have many features (positions, known features, edge features -- all tensors).\n", + "The data we give to the model will also have a rdmol object (which can extract features to geometric if needed).\n", + "The rdmol in this object is a source of ground truth for the generated molecules.\n", + "\n", + "You will use one rendering function from nglviewer later!\n", + "\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/", + "height": 1000 + }, + "id": "jcl8GCS2mz6t", + "outputId": "99b5cc40-67bb-4d8e-faa0-47d7cb33e98f" + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/\n", + "Collecting nglview\n", + " Downloading nglview-3.0.3.tar.gz (5.7 MB)\n", + "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m5.7/5.7 MB\u001b[0m \u001b[31m91.2 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", + "\u001b[?25h Installing build dependencies ... \u001b[?25l\u001b[?25hdone\n", + " Getting requirements to build wheel ... \u001b[?25l\u001b[?25hdone\n", + " Preparing metadata (pyproject.toml) ... 
\u001b[?25l\u001b[?25hdone\n", + "Requirement already satisfied: numpy in /usr/local/lib/python3.7/site-packages (from nglview) (1.21.6)\n", + "Collecting jupyterlab-widgets\n", + " Downloading jupyterlab_widgets-3.0.3-py3-none-any.whl (384 kB)\n", + "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m384.1/384.1 kB\u001b[0m \u001b[31m40.6 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", + "\u001b[?25hCollecting ipywidgets>=7\n", + " Downloading ipywidgets-8.0.2-py3-none-any.whl (134 kB)\n", + "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m134.4/134.4 kB\u001b[0m \u001b[31m21.2 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", + "\u001b[?25hCollecting widgetsnbextension~=4.0\n", + " Downloading widgetsnbextension-4.0.3-py3-none-any.whl (2.0 MB)\n", + "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m2.0/2.0 MB\u001b[0m \u001b[31m84.4 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", + "\u001b[?25hCollecting ipython>=6.1.0\n", + " Downloading ipython-7.34.0-py3-none-any.whl (793 kB)\n", + "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m793.8/793.8 kB\u001b[0m \u001b[31m60.7 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", + "\u001b[?25hCollecting ipykernel>=4.5.1\n", + " Downloading ipykernel-6.16.0-py3-none-any.whl (138 kB)\n", + "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m138.4/138.4 kB\u001b[0m \u001b[31m20.9 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", + "\u001b[?25hCollecting traitlets>=4.3.1\n", + " Downloading traitlets-5.4.0-py3-none-any.whl (107 kB)\n", + "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m107.1/107.1 kB\u001b[0m \u001b[31m17.0 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", + "\u001b[?25hRequirement already satisfied: packaging in /usr/local/lib/python3.7/site-packages (from ipykernel>=4.5.1->ipywidgets>=7->nglview) (21.3)\n", + "Collecting pyzmq>=17\n", + " Downloading pyzmq-24.0.1-cp37-cp37m-manylinux_2_5_x86_64.manylinux1_x86_64.whl (1.1 MB)\n", + "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m1.1/1.1 MB\u001b[0m \u001b[31m68.4 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", + "\u001b[?25hCollecting matplotlib-inline>=0.1\n", + " Downloading matplotlib_inline-0.1.6-py3-none-any.whl (9.4 kB)\n", + "Collecting tornado>=6.1\n", + " Downloading tornado-6.2-cp37-abi3-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl (423 kB)\n", + "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m424.0/424.0 kB\u001b[0m \u001b[31m41.2 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", + "\u001b[?25hCollecting nest-asyncio\n", + " Downloading nest_asyncio-1.5.6-py3-none-any.whl (5.2 kB)\n", + "Collecting debugpy>=1.0\n", + " Downloading debugpy-1.6.3-cp37-cp37m-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_12_x86_64.manylinux2010_x86_64.whl (1.8 MB)\n", + "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m1.8/1.8 MB\u001b[0m \u001b[31m83.4 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", + "\u001b[?25hCollecting psutil\n", + " Downloading psutil-5.9.2-cp37-cp37m-manylinux_2_12_x86_64.manylinux2010_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl (281 kB)\n", + "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m281.3/281.3 kB\u001b[0m \u001b[31m33.1 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", + "\u001b[?25hCollecting 
jupyter-client>=6.1.12\n", + " Downloading jupyter_client-7.4.2-py3-none-any.whl (132 kB)\n", + "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m132.2/132.2 kB\u001b[0m \u001b[31m19.7 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", + "\u001b[?25hCollecting pickleshare\n", + " Downloading pickleshare-0.7.5-py2.py3-none-any.whl (6.9 kB)\n", + "Requirement already satisfied: setuptools>=18.5 in /usr/local/lib/python3.7/site-packages (from ipython>=6.1.0->ipywidgets>=7->nglview) (59.8.0)\n", + "Collecting backcall\n", + " Downloading backcall-0.2.0-py2.py3-none-any.whl (11 kB)\n", + "Collecting pexpect>4.3\n", + " Downloading pexpect-4.8.0-py2.py3-none-any.whl (59 kB)\n", + "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m59.0/59.0 kB\u001b[0m \u001b[31m7.3 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", + "\u001b[?25hCollecting pygments\n", + " Downloading Pygments-2.13.0-py3-none-any.whl (1.1 MB)\n", + "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m1.1/1.1 MB\u001b[0m \u001b[31m70.9 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", + "\u001b[?25hCollecting jedi>=0.16\n", + " Downloading jedi-0.18.1-py2.py3-none-any.whl (1.6 MB)\n", + "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m1.6/1.6 MB\u001b[0m \u001b[31m83.5 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", + "\u001b[?25hCollecting prompt-toolkit!=3.0.0,!=3.0.1,<3.1.0,>=2.0.0\n", + " Downloading prompt_toolkit-3.0.31-py3-none-any.whl (382 kB)\n", + "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m382.3/382.3 kB\u001b[0m \u001b[31m40.0 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", + "\u001b[?25hRequirement already satisfied: decorator in /usr/local/lib/python3.7/site-packages (from ipython>=6.1.0->ipywidgets>=7->nglview) (4.4.2)\n", + "Collecting parso<0.9.0,>=0.8.0\n", + " Downloading parso-0.8.3-py2.py3-none-any.whl (100 kB)\n", + "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m100.8/100.8 kB\u001b[0m \u001b[31m14.7 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", + "\u001b[?25hRequirement already satisfied: python-dateutil>=2.8.2 in /usr/local/lib/python3.7/site-packages (from jupyter-client>=6.1.12->ipykernel>=4.5.1->ipywidgets>=7->nglview) (2.8.2)\n", + "Collecting entrypoints\n", + " Downloading entrypoints-0.4-py3-none-any.whl (5.3 kB)\n", + "Collecting jupyter-core>=4.9.2\n", + " Downloading jupyter_core-4.11.1-py3-none-any.whl (88 kB)\n", + "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m88.4/88.4 kB\u001b[0m \u001b[31m14.0 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", + "\u001b[?25hCollecting ptyprocess>=0.5\n", + " Downloading ptyprocess-0.7.0-py2.py3-none-any.whl (13 kB)\n", + "Collecting wcwidth\n", + " Downloading wcwidth-0.2.5-py2.py3-none-any.whl (30 kB)\n", + "Requirement already satisfied: pyparsing!=3.0.5,>=2.0.2 in /usr/local/lib/python3.7/site-packages (from packaging->ipykernel>=4.5.1->ipywidgets>=7->nglview) (3.0.9)\n", + "Requirement already satisfied: six>=1.5 in /usr/local/lib/python3.7/site-packages (from python-dateutil>=2.8.2->jupyter-client>=6.1.12->ipykernel>=4.5.1->ipywidgets>=7->nglview) (1.16.0)\n", + "Building wheels for collected packages: nglview\n", + " Building wheel for nglview (pyproject.toml) ... 
\u001b[?25l\u001b[?25hdone\n", + " Created wheel for nglview: filename=nglview-3.0.3-py3-none-any.whl size=8057538 sha256=b7e1071bb91822e48515bf27f4e6b197c6e85e06b90912b3439edc8be1e29514\n", + " Stored in directory: /root/.cache/pip/wheels/01/0c/49/c6f79d8edba8fe89752bf20de2d99040bfa57db0548975c5d5\n", + "Successfully built nglview\n", + "Installing collected packages: wcwidth, ptyprocess, pickleshare, backcall, widgetsnbextension, traitlets, tornado, pyzmq, pygments, psutil, prompt-toolkit, pexpect, parso, nest-asyncio, jupyterlab-widgets, entrypoints, debugpy, matplotlib-inline, jupyter-core, jedi, jupyter-client, ipython, ipykernel, ipywidgets, nglview\n", + "Successfully installed backcall-0.2.0 debugpy-1.6.3 entrypoints-0.4 ipykernel-6.16.0 ipython-7.34.0 ipywidgets-8.0.2 jedi-0.18.1 jupyter-client-7.4.2 jupyter-core-4.11.1 jupyterlab-widgets-3.0.3 matplotlib-inline-0.1.6 nest-asyncio-1.5.6 nglview-3.0.3 parso-0.8.3 pexpect-4.8.0 pickleshare-0.7.5 prompt-toolkit-3.0.31 psutil-5.9.2 ptyprocess-0.7.0 pygments-2.13.0 pyzmq-24.0.1 tornado-6.2 traitlets-5.4.0 wcwidth-0.2.5 widgetsnbextension-4.0.3\n", + "\u001b[33mWARNING: Running pip as the 'root' user can result in broken permissions and conflicting behaviour with the system package manager. It is recommended to use a virtual environment instead: https://pip.pypa.io/warnings/venv\u001b[0m\u001b[33m\n", + "\u001b[0m" + ] + }, + { + "data": { + "application/vnd.colab-display-data+json": { + "pip_warning": { + "packages": [ + "pexpect", + "pickleshare", + "wcwidth" + ] + } } - }, - { - "cell_type": "code", - "source": [ - "!rm /usr/local/conda-meta/pinned" - ], - "metadata": { - "id": "dq1lxR10TtrR", - "colab": { - "base_uri": "https://localhost:8080/" - }, - "outputId": "ed9c5a71-b449-418f-abb7-072b74e7f6c8" - }, - "execution_count": null, - "outputs": [ - { - "output_type": "stream", - "name": "stdout", - "text": [ - "rm: cannot remove '/usr/local/conda-meta/pinned': No such file or directory\n" - ] - } - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "Z1L3DdZOJB30" - }, - "source": [ - "Install torch geometric (used in the model later)" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "id": "D5ukfCOWfjzK", - "colab": { - "base_uri": "https://localhost:8080/" - }, - "outputId": "8437485a-5aa6-4d53-8f7f-23517ac1ace6" - }, - "outputs": [ - { - "output_type": "stream", - "name": "stdout", - "text": [ - "Collecting package metadata (current_repodata.json): - \b\b\\ \b\b| \b\b/ \b\b- \b\b\\ \b\b| \b\b/ \b\b- \b\b\\ \b\b| \b\b/ \b\b- \b\b\\ \b\b| \b\b/ \b\b- \b\b\\ \b\b| \b\b/ \b\b- \b\b\\ \b\b| \b\b/ \b\b- \b\b\\ \b\b| \b\b/ \b\b- \b\bdone\n", - "Solving environment: | \b\b/ \b\b- \b\b\\ \b\b| \b\b/ \b\b- \b\b\\ \b\b| \b\b/ \b\b- \b\b\\ \b\b| \b\b/ \b\b- \b\b\\ \b\b| \b\b/ \b\b- \b\b\\ \b\b| \b\b/ \b\b- \b\b\\ \b\b| \b\b/ \b\b- \b\b\\ \b\b| \b\b/ \b\b- \b\b\\ \b\b| \b\b/ \b\b- \b\b\\ \b\b| \b\b/ \b\b- \b\b\\ \b\b| \b\b/ \b\b- \b\b\\ \b\b| \b\b/ \b\b- \b\b\\ \b\b| \b\b/ \b\b- \b\b\\ \b\b| \b\b/ \b\b- \b\b\\ \b\b| \b\b/ \b\b- \b\b\\ \b\b| \b\b/ \b\b- \b\b\\ \b\b| \b\b/ \b\b- \b\b\\ \b\b| \b\b/ \b\b- \b\b\\ \b\b| \b\b/ \b\b- \b\b\\ \b\b| \b\b/ \b\b- \b\b\\ \b\b| \b\b/ \b\b- \b\b\\ \b\b| \b\b/ \b\b- \b\b\\ \b\b| \b\b/ \b\b- \b\b\\ \b\b| \b\b/ \b\b- \b\b\\ \b\b| \b\b/ \b\b- \b\b\\ \b\b| \b\b/ \b\b- \b\b\\ \b\b| \b\b/ \b\b- \b\b\\ \b\b| \b\b/ \b\b- \b\b\\ \b\b| \b\b/ \b\b- \b\b\\ \b\b| \b\b/ \b\b- \b\b\\ \b\b| \b\b/ \b\b- \b\b\\ \b\b| \b\b/ \b\b- \b\b\\ \b\b| \b\b/ \b\b- \b\b\\ \b\bdone\n", - "\n", - 
"## Package Plan ##\n", - "\n", - " environment location: /usr/local\n", - "\n", - " added / updated specs:\n", - " - pytorch-geometric=1.7.2\n", - "\n", - "\n", - "The following packages will be downloaded:\n", - "\n", - " package | build\n", - " ---------------------------|-----------------\n", - " decorator-4.4.2 | py_0 11 KB conda-forge\n", - " googledrivedownloader-0.4 | pyhd3deb0d_1 7 KB conda-forge\n", - " jinja2-3.1.2 | pyhd8ed1ab_1 99 KB conda-forge\n", - " joblib-1.2.0 | pyhd8ed1ab_0 205 KB conda-forge\n", - " markupsafe-2.1.1 | py37h540881e_1 22 KB conda-forge\n", - " networkx-2.5.1 | pyhd8ed1ab_0 1.2 MB conda-forge\n", - " pandas-1.2.3 | py37hdc94413_0 11.8 MB conda-forge\n", - " pyparsing-3.0.9 | pyhd8ed1ab_0 79 KB conda-forge\n", - " python-dateutil-2.8.2 | pyhd8ed1ab_0 240 KB conda-forge\n", - " python-louvain-0.15 | pyhd8ed1ab_1 13 KB conda-forge\n", - " pytorch-cluster-1.5.9 |py37_torch_1.8.0_cu111 1.2 MB rusty1s\n", - " pytorch-geometric-1.7.2 |py37_torch_1.8.0_cu111 445 KB rusty1s\n", - " pytorch-scatter-2.0.8 |py37_torch_1.8.0_cu111 6.1 MB rusty1s\n", - " pytorch-sparse-0.6.12 |py37_torch_1.8.0_cu111 2.9 MB rusty1s\n", - " pytorch-spline-conv-1.2.1 |py37_torch_1.8.0_cu111 736 KB rusty1s\n", - " pytz-2022.4 | pyhd8ed1ab_0 232 KB conda-forge\n", - " scikit-learn-1.0.2 | py37hf9e9bfc_0 7.8 MB conda-forge\n", - " scipy-1.7.3 | py37hf2a6cf1_0 21.8 MB conda-forge\n", - " setuptools-59.8.0 | py37h89c1867_1 1.0 MB conda-forge\n", - " threadpoolctl-3.1.0 | pyh8a188c0_0 18 KB conda-forge\n", - " ------------------------------------------------------------\n", - " Total: 55.9 MB\n", - "\n", - "The following NEW packages will be INSTALLED:\n", - "\n", - " decorator conda-forge/noarch::decorator-4.4.2-py_0 None\n", - " googledrivedownlo~ conda-forge/noarch::googledrivedownloader-0.4-pyhd3deb0d_1 None\n", - " jinja2 conda-forge/noarch::jinja2-3.1.2-pyhd8ed1ab_1 None\n", - " joblib conda-forge/noarch::joblib-1.2.0-pyhd8ed1ab_0 None\n", - " markupsafe conda-forge/linux-64::markupsafe-2.1.1-py37h540881e_1 None\n", - " networkx conda-forge/noarch::networkx-2.5.1-pyhd8ed1ab_0 None\n", - " pandas conda-forge/linux-64::pandas-1.2.3-py37hdc94413_0 None\n", - " pyparsing conda-forge/noarch::pyparsing-3.0.9-pyhd8ed1ab_0 None\n", - " python-dateutil conda-forge/noarch::python-dateutil-2.8.2-pyhd8ed1ab_0 None\n", - " python-louvain conda-forge/noarch::python-louvain-0.15-pyhd8ed1ab_1 None\n", - " pytorch-cluster rusty1s/linux-64::pytorch-cluster-1.5.9-py37_torch_1.8.0_cu111 None\n", - " pytorch-geometric rusty1s/linux-64::pytorch-geometric-1.7.2-py37_torch_1.8.0_cu111 None\n", - " pytorch-scatter rusty1s/linux-64::pytorch-scatter-2.0.8-py37_torch_1.8.0_cu111 None\n", - " pytorch-sparse rusty1s/linux-64::pytorch-sparse-0.6.12-py37_torch_1.8.0_cu111 None\n", - " pytorch-spline-co~ rusty1s/linux-64::pytorch-spline-conv-1.2.1-py37_torch_1.8.0_cu111 None\n", - " pytz conda-forge/noarch::pytz-2022.4-pyhd8ed1ab_0 None\n", - " scikit-learn conda-forge/linux-64::scikit-learn-1.0.2-py37hf9e9bfc_0 None\n", - " scipy conda-forge/linux-64::scipy-1.7.3-py37hf2a6cf1_0 None\n", - " threadpoolctl conda-forge/noarch::threadpoolctl-3.1.0-pyh8a188c0_0 None\n", - "\n", - "The following packages will be DOWNGRADED:\n", - "\n", - " setuptools 65.3.0-py37h89c1867_0 --> 59.8.0-py37h89c1867_1 None\n", - "\n", - "\n", - "\n", - "Downloading and Extracting Packages\n", - "scikit-learn-1.0.2 | 7.8 MB | : 100% 1.0/1 [00:01<00:00, 1.37s/it] \n", - "pytorch-scatter-2.0. 
| 6.1 MB | : 100% 1.0/1 [00:06<00:00, 6.18s/it]\n", - "pytorch-geometric-1. | 445 KB | : 100% 1.0/1 [00:02<00:00, 2.53s/it]\n", - "scipy-1.7.3 | 21.8 MB | : 100% 1.0/1 [00:03<00:00, 3.06s/it]\n", - "python-dateutil-2.8. | 240 KB | : 100% 1.0/1 [00:00<00:00, 21.48it/s]\n", - "pytorch-spline-conv- | 736 KB | : 100% 1.0/1 [00:01<00:00, 1.00s/it]\n", - "pytorch-sparse-0.6.1 | 2.9 MB | : 100% 1.0/1 [00:07<00:00, 7.51s/it]\n", - "pyparsing-3.0.9 | 79 KB | : 100% 1.0/1 [00:00<00:00, 26.32it/s]\n", - "pytorch-cluster-1.5. | 1.2 MB | : 100% 1.0/1 [00:02<00:00, 2.78s/it]\n", - "jinja2-3.1.2 | 99 KB | : 100% 1.0/1 [00:00<00:00, 20.28it/s]\n", - "decorator-4.4.2 | 11 KB | : 100% 1.0/1 [00:00<00:00, 21.57it/s]\n", - "joblib-1.2.0 | 205 KB | : 100% 1.0/1 [00:00<00:00, 15.04it/s]\n", - "pytz-2022.4 | 232 KB | : 100% 1.0/1 [00:00<00:00, 10.21it/s]\n", - "python-louvain-0.15 | 13 KB | : 100% 1.0/1 [00:00<00:00, 3.34it/s]\n", - "googledrivedownloade | 7 KB | : 100% 1.0/1 [00:00<00:00, 3.33it/s]\n", - "threadpoolctl-3.1.0 | 18 KB | : 100% 1.0/1 [00:00<00:00, 29.40it/s]\n", - "markupsafe-2.1.1 | 22 KB | : 100% 1.0/1 [00:00<00:00, 28.62it/s]\n", - "pandas-1.2.3 | 11.8 MB | : 100% 1.0/1 [00:02<00:00, 2.08s/it] \n", - "networkx-2.5.1 | 1.2 MB | : 100% 1.0/1 [00:01<00:00, 1.39s/it]\n", - "setuptools-59.8.0 | 1.0 MB | : 100% 1.0/1 [00:00<00:00, 4.25it/s]\n", - "Preparing transaction: / \b\b- \b\b\\ \b\bdone\n", - "Verifying transaction: / \b\b- \b\b\\ \b\b| \b\b/ \b\b- \b\b\\ \b\b| \b\b/ \b\b- \b\b\\ \b\b| \b\b/ \b\b- \b\b\\ \b\bdone\n", - "Executing transaction: / \b\b- \b\b\\ \b\b| \b\b/ \b\b- \b\b\\ \b\b| \b\b/ \b\b- \b\b\\ \b\b| \b\b/ \b\b- \b\b\\ \b\b| \b\b/ \b\b- \b\b\\ \b\b| \b\b/ \b\b- \b\bdone\n", - "Retrieving notices: ...working... done\n" - ] - } - ], - "source": [ - "!conda install -c rusty1s pytorch-geometric=1.7.2" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "ppxv6Mdkalbc" - }, - "source": [ - "### Install Diffusers" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "id": "mgQA_XN-XGY2", - "colab": { - "base_uri": "https://localhost:8080/" - }, - "outputId": "85392615-b6a4-4052-9d2a-79604be62c94" + }, + "metadata": {}, + "output_type": "display_data" + } + ], + "source": [ + "!pip install nglview" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "8t8_e_uVLdKB" + }, + "source": [ + "## Create a diffusion model" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "G0rMncVtNSqU" + }, + "source": [ + "### Model class(es)" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "L5FEXz5oXkzt" + }, + "source": [ + "Imports" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "-3-P4w5sXkRU" + }, + "outputs": [], + "source": [ + "# Model adapted from GeoDiff https://github.com/MinkaiXu/GeoDiff\n", + "# Model inspired by https://github.com/DeepGraphLearning/torchdrug/tree/master/torchdrug/models\n", + "from dataclasses import dataclass\n", + "from typing import Callable, Tuple, Union\n", + "\n", + "import numpy as np\n", + "import torch\n", + "import torch.nn.functional as F\n", + "from torch import Tensor, nn\n", + "from torch.nn import Embedding, Linear, Module, ModuleList, Sequential\n", + "from torch_geometric.nn import MessagePassing, radius, radius_graph\n", + "from torch_geometric.typing import Adj, OptPairTensor, OptTensor, Size\n", + "from torch_geometric.utils import dense_to_sparse, to_dense_adj\n", + "from torch_scatter import scatter_add\n", + "from torch_sparse import 
SparseTensor, coalesce\n", + "\n", + "from diffusers.configuration_utils import ConfigMixin, register_to_config\n", + "from diffusers.modeling_utils import ModelMixin\n", + "from diffusers.utils import BaseOutput\n" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "EzJQXPN_XrMX" + }, + "source": [ + "Helper classes" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "oR1Y56QiLY90" + }, + "outputs": [], + "source": [ + "@dataclass\n", + "class MoleculeGNNOutput(BaseOutput):\n", + " \"\"\"\n", + " Args:\n", + " sample (`torch.Tensor` of shape `(batch_size, num_channels, height, width)`):\n", + " Hidden states output. Output of last layer of model.\n", + " \"\"\"\n", + "\n", + " sample: torch.Tensor\n", + "\n", + "\n", + "class MultiLayerPerceptron(nn.Module):\n", + " \"\"\"\n", + " Multi-layer Perceptron. Note there is no activation or dropout in the last layer.\n", + " Args:\n", + " input_dim (int): input dimension\n", + " hidden_dim (list of int): hidden dimensions\n", + " activation (str or function, optional): activation function\n", + " dropout (float, optional): dropout rate\n", + " \"\"\"\n", + "\n", + " def __init__(self, input_dim, hidden_dims, activation=\"relu\", dropout=0):\n", + " super(MultiLayerPerceptron, self).__init__()\n", + "\n", + " self.dims = [input_dim] + hidden_dims\n", + " if isinstance(activation, str):\n", + " self.activation = getattr(F, activation)\n", + " else:\n", + " print(f\"Warning, activation passed {activation} is not string and ignored\")\n", + " self.activation = None\n", + " if dropout > 0:\n", + " self.dropout = nn.Dropout(dropout)\n", + " else:\n", + " self.dropout = None\n", + "\n", + " self.layers = nn.ModuleList()\n", + " for i in range(len(self.dims) - 1):\n", + " self.layers.append(nn.Linear(self.dims[i], self.dims[i + 1]))\n", + "\n", + " def forward(self, x):\n", + " \"\"\"\"\"\"\n", + " for i, layer in enumerate(self.layers):\n", + " x = layer(x)\n", + " if i < len(self.layers) - 1:\n", + " if self.activation:\n", + " x = self.activation(x)\n", + " if self.dropout:\n", + " x = self.dropout(x)\n", + " return x\n", + "\n", + "\n", + "class ShiftedSoftplus(torch.nn.Module):\n", + " def __init__(self):\n", + " super(ShiftedSoftplus, self).__init__()\n", + " self.shift = torch.log(torch.tensor(2.0)).item()\n", + "\n", + " def forward(self, x):\n", + " return F.softplus(x) - self.shift\n", + "\n", + "\n", + "class CFConv(MessagePassing):\n", + " def __init__(self, in_channels, out_channels, num_filters, mlp, cutoff, smooth):\n", + " super(CFConv, self).__init__(aggr=\"add\")\n", + " self.lin1 = Linear(in_channels, num_filters, bias=False)\n", + " self.lin2 = Linear(num_filters, out_channels)\n", + " self.nn = mlp\n", + " self.cutoff = cutoff\n", + " self.smooth = smooth\n", + "\n", + " self.reset_parameters()\n", + "\n", + " def reset_parameters(self):\n", + " torch.nn.init.xavier_uniform_(self.lin1.weight)\n", + " torch.nn.init.xavier_uniform_(self.lin2.weight)\n", + " self.lin2.bias.data.fill_(0)\n", + "\n", + " def forward(self, x, edge_index, edge_length, edge_attr):\n", + " if self.smooth:\n", + " C = 0.5 * (torch.cos(edge_length * np.pi / self.cutoff) + 1.0)\n", + " C = C * (edge_length <= self.cutoff) * (edge_length >= 0.0) # Modification: cutoff\n", + " else:\n", + " C = (edge_length <= self.cutoff).float()\n", + " W = self.nn(edge_attr) * C.view(-1, 1)\n", + "\n", + " x = self.lin1(x)\n", + " x = self.propagate(edge_index, x=x, W=W)\n", + " x = self.lin2(x)\n", + " return x\n", + "\n", + " 
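+    "    # Note (descriptive comment, added for clarity): message() is invoked by propagate() above;\n",
+    "    # each neighbor feature x_j is scaled by the edge-wise filter W = nn(edge_attr) * C computed\n",
+    "    # in forward(), and the scaled messages are then summed per node (aggr=\"add\").\n",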
def message(self, x_j: torch.Tensor, W) -> torch.Tensor:\n", + " return x_j * W\n", + "\n", + "\n", + "class InteractionBlock(torch.nn.Module):\n", + " def __init__(self, hidden_channels, num_gaussians, num_filters, cutoff, smooth):\n", + " super(InteractionBlock, self).__init__()\n", + " mlp = Sequential(\n", + " Linear(num_gaussians, num_filters),\n", + " ShiftedSoftplus(),\n", + " Linear(num_filters, num_filters),\n", + " )\n", + " self.conv = CFConv(hidden_channels, hidden_channels, num_filters, mlp, cutoff, smooth)\n", + " self.act = ShiftedSoftplus()\n", + " self.lin = Linear(hidden_channels, hidden_channels)\n", + "\n", + " def forward(self, x, edge_index, edge_length, edge_attr):\n", + " x = self.conv(x, edge_index, edge_length, edge_attr)\n", + " x = self.act(x)\n", + " x = self.lin(x)\n", + " return x\n", + "\n", + "\n", + "class SchNetEncoder(Module):\n", + " def __init__(\n", + " self, hidden_channels=128, num_filters=128, num_interactions=6, edge_channels=100, cutoff=10.0, smooth=False\n", + " ):\n", + " super().__init__()\n", + "\n", + " self.hidden_channels = hidden_channels\n", + " self.num_filters = num_filters\n", + " self.num_interactions = num_interactions\n", + " self.cutoff = cutoff\n", + "\n", + " self.embedding = Embedding(100, hidden_channels, max_norm=10.0)\n", + "\n", + " self.interactions = ModuleList()\n", + " for _ in range(num_interactions):\n", + " block = InteractionBlock(hidden_channels, edge_channels, num_filters, cutoff, smooth)\n", + " self.interactions.append(block)\n", + "\n", + " def forward(self, z, edge_index, edge_length, edge_attr, embed_node=True):\n", + " if embed_node:\n", + " assert z.dim() == 1 and z.dtype == torch.long\n", + " h = self.embedding(z)\n", + " else:\n", + " h = z\n", + " for interaction in self.interactions:\n", + " h = h + interaction(h, edge_index, edge_length, edge_attr)\n", + "\n", + " return h\n", + "\n", + "\n", + "class GINEConv(MessagePassing):\n", + " \"\"\"\n", + " Custom class of the graph isomorphism operator from the \"How Powerful are Graph Neural Networks?\n", + " https://arxiv.org/abs/1810.00826 paper. 
Note that this implementation has the added option of a custom activation.\n", + " \"\"\"\n", + "\n", + " def __init__(self, mlp: Callable, eps: float = 0.0, train_eps: bool = False, activation=\"softplus\", **kwargs):\n", + " super(GINEConv, self).__init__(aggr=\"add\", **kwargs)\n", + " self.nn = mlp\n", + " self.initial_eps = eps\n", + "\n", + " if isinstance(activation, str):\n", + " self.activation = getattr(F, activation)\n", + " else:\n", + " self.activation = None\n", + "\n", + " if train_eps:\n", + " self.eps = torch.nn.Parameter(torch.Tensor([eps]))\n", + " else:\n", + " self.register_buffer(\"eps\", torch.Tensor([eps]))\n", + "\n", + " def forward(\n", + " self, x: Union[Tensor, OptPairTensor], edge_index: Adj, edge_attr: OptTensor = None, size: Size = None\n", + " ) -> torch.Tensor:\n", + " \"\"\"\"\"\"\n", + " if isinstance(x, torch.Tensor):\n", + " x: OptPairTensor = (x, x)\n", + "\n", + " # Node and edge feature dimensionalites need to match.\n", + " if isinstance(edge_index, torch.Tensor):\n", + " assert edge_attr is not None\n", + " assert x[0].size(-1) == edge_attr.size(-1)\n", + " elif isinstance(edge_index, SparseTensor):\n", + " assert x[0].size(-1) == edge_index.size(-1)\n", + "\n", + " # propagate_type: (x: OptPairTensor, edge_attr: OptTensor)\n", + " out = self.propagate(edge_index, x=x, edge_attr=edge_attr, size=size)\n", + "\n", + " x_r = x[1]\n", + " if x_r is not None:\n", + " out += (1 + self.eps) * x_r\n", + "\n", + " return self.nn(out)\n", + "\n", + " def message(self, x_j: torch.Tensor, edge_attr: torch.Tensor) -> torch.Tensor:\n", + " if self.activation:\n", + " return self.activation(x_j + edge_attr)\n", + " else:\n", + " return x_j + edge_attr\n", + "\n", + " def __repr__(self):\n", + " return \"{}(nn={})\".format(self.__class__.__name__, self.nn)\n", + "\n", + "\n", + "class GINEncoder(torch.nn.Module):\n", + " def __init__(self, hidden_dim, num_convs=3, activation=\"relu\", short_cut=True, concat_hidden=False):\n", + " super().__init__()\n", + "\n", + " self.hidden_dim = hidden_dim\n", + " self.num_convs = num_convs\n", + " self.short_cut = short_cut\n", + " self.concat_hidden = concat_hidden\n", + " self.node_emb = nn.Embedding(100, hidden_dim)\n", + "\n", + " if isinstance(activation, str):\n", + " self.activation = getattr(F, activation)\n", + " else:\n", + " self.activation = None\n", + "\n", + " self.convs = nn.ModuleList()\n", + " for i in range(self.num_convs):\n", + " self.convs.append(\n", + " GINEConv(\n", + " MultiLayerPerceptron(hidden_dim, [hidden_dim, hidden_dim], activation=activation),\n", + " activation=activation,\n", + " )\n", + " )\n", + "\n", + " def forward(self, z, edge_index, edge_attr):\n", + " \"\"\"\n", + " Input:\n", + " data: (torch_geometric.data.Data): batched graph edge_index: bond indices of the original graph (num_node,\n", + " hidden) edge_attr: edge feature tensor with shape (num_edge, hidden)\n", + " Output:\n", + " node_feature: graph feature\n", + " \"\"\"\n", + "\n", + " node_attr = self.node_emb(z) # (num_node, hidden)\n", + "\n", + " hiddens = []\n", + " conv_input = node_attr # (num_node, hidden)\n", + "\n", + " for conv_idx, conv in enumerate(self.convs):\n", + " hidden = conv(conv_input, edge_index, edge_attr)\n", + " if conv_idx < len(self.convs) - 1 and self.activation is not None:\n", + " hidden = self.activation(hidden)\n", + " assert hidden.shape == conv_input.shape\n", + " if self.short_cut and hidden.shape == conv_input.shape:\n", + " hidden += conv_input\n", + "\n", + " hiddens.append(hidden)\n", + " 
conv_input = hidden\n", + "\n", + " if self.concat_hidden:\n", + " node_feature = torch.cat(hiddens, dim=-1)\n", + " else:\n", + " node_feature = hiddens[-1]\n", + "\n", + " return node_feature\n", + "\n", + "\n", + "class MLPEdgeEncoder(Module):\n", + " def __init__(self, hidden_dim=100, activation=\"relu\"):\n", + " super().__init__()\n", + " self.hidden_dim = hidden_dim\n", + " self.bond_emb = Embedding(100, embedding_dim=self.hidden_dim)\n", + " self.mlp = MultiLayerPerceptron(1, [self.hidden_dim, self.hidden_dim], activation=activation)\n", + "\n", + " @property\n", + " def out_channels(self):\n", + " return self.hidden_dim\n", + "\n", + " def forward(self, edge_length, edge_type):\n", + " \"\"\"\n", + " Input:\n", + " edge_length: The length of edges, shape=(E, 1). edge_type: The type pf edges, shape=(E,)\n", + " Returns:\n", + " edge_attr: The representation of edges. (E, 2 * num_gaussians)\n", + " \"\"\"\n", + " d_emb = self.mlp(edge_length) # (num_edge, hidden_dim)\n", + " edge_attr = self.bond_emb(edge_type) # (num_edge, hidden_dim)\n", + " return d_emb * edge_attr # (num_edge, hidden)\n", + "\n", + "\n", + "def assemble_atom_pair_feature(node_attr, edge_index, edge_attr):\n", + " h_row, h_col = node_attr[edge_index[0]], node_attr[edge_index[1]]\n", + " h_pair = torch.cat([h_row * h_col, edge_attr], dim=-1) # (E, 2H)\n", + " return h_pair\n", + "\n", + "\n", + "def _extend_graph_order(num_nodes, edge_index, edge_type, order=3):\n", + " \"\"\"\n", + " Args:\n", + " num_nodes: Number of atoms.\n", + " edge_index: Bond indices of the original graph.\n", + " edge_type: Bond types of the original graph.\n", + " order: Extension order.\n", + " Returns:\n", + " new_edge_index: Extended edge indices. new_edge_type: Extended edge types.\n", + " \"\"\"\n", + "\n", + " def binarize(x):\n", + " return torch.where(x > 0, torch.ones_like(x), torch.zeros_like(x))\n", + "\n", + " def get_higher_order_adj_matrix(adj, order):\n", + " \"\"\"\n", + " Args:\n", + " adj: (N, N)\n", + " type_mat: (N, N)\n", + " Returns:\n", + " Following attributes will be updated:\n", + " - edge_index\n", + " - edge_type\n", + " Following attributes will be added to the data object:\n", + " - bond_edge_index: Original edge_index.\n", + " \"\"\"\n", + " adj_mats = [\n", + " torch.eye(adj.size(0), dtype=torch.long, device=adj.device),\n", + " binarize(adj + torch.eye(adj.size(0), dtype=torch.long, device=adj.device)),\n", + " ]\n", + "\n", + " for i in range(2, order + 1):\n", + " adj_mats.append(binarize(adj_mats[i - 1] @ adj_mats[1]))\n", + " order_mat = torch.zeros_like(adj)\n", + "\n", + " for i in range(1, order + 1):\n", + " order_mat += (adj_mats[i] - adj_mats[i - 1]) * i\n", + "\n", + " return order_mat\n", + "\n", + " num_types = 22\n", + " # given from len(BOND_TYPES), where BOND_TYPES = {t: i for i, t in enumerate(BT.names.values())}\n", + " # from rdkit.Chem.rdchem import BondType as BT\n", + " N = num_nodes\n", + " adj = to_dense_adj(edge_index).squeeze(0)\n", + " adj_order = get_higher_order_adj_matrix(adj, order) # (N, N)\n", + "\n", + " type_mat = to_dense_adj(edge_index, edge_attr=edge_type).squeeze(0) # (N, N)\n", + " type_highorder = torch.where(adj_order > 1, num_types + adj_order - 1, torch.zeros_like(adj_order))\n", + " assert (type_mat * type_highorder == 0).all()\n", + " type_new = type_mat + type_highorder\n", + "\n", + " new_edge_index, new_edge_type = dense_to_sparse(type_new)\n", + " _, edge_order = dense_to_sparse(adj_order)\n", + "\n", + " # data.bond_edge_index = data.edge_index # Save 
original edges\n", + " new_edge_index, new_edge_type = coalesce(new_edge_index, new_edge_type.long(), N, N) # modify data\n", + "\n", + " return new_edge_index, new_edge_type\n", + "\n", + "\n", + "def _extend_to_radius_graph(pos, edge_index, edge_type, cutoff, batch, unspecified_type_number=0, is_sidechain=None):\n", + " assert edge_type.dim() == 1\n", + " N = pos.size(0)\n", + "\n", + " bgraph_adj = torch.sparse.LongTensor(edge_index, edge_type, torch.Size([N, N]))\n", + "\n", + " if is_sidechain is None:\n", + " rgraph_edge_index = radius_graph(pos, r=cutoff, batch=batch) # (2, E_r)\n", + " else:\n", + " # fetch sidechain and its batch index\n", + " is_sidechain = is_sidechain.bool()\n", + " dummy_index = torch.arange(pos.size(0), device=pos.device)\n", + " sidechain_pos = pos[is_sidechain]\n", + " sidechain_index = dummy_index[is_sidechain]\n", + " sidechain_batch = batch[is_sidechain]\n", + "\n", + " assign_index = radius(x=pos, y=sidechain_pos, r=cutoff, batch_x=batch, batch_y=sidechain_batch)\n", + " r_edge_index_x = assign_index[1]\n", + " r_edge_index_y = assign_index[0]\n", + " r_edge_index_y = sidechain_index[r_edge_index_y]\n", + "\n", + " rgraph_edge_index1 = torch.stack((r_edge_index_x, r_edge_index_y)) # (2, E)\n", + " rgraph_edge_index2 = torch.stack((r_edge_index_y, r_edge_index_x)) # (2, E)\n", + " rgraph_edge_index = torch.cat((rgraph_edge_index1, rgraph_edge_index2), dim=-1) # (2, 2E)\n", + " # delete self loop\n", + " rgraph_edge_index = rgraph_edge_index[:, (rgraph_edge_index[0] != rgraph_edge_index[1])]\n", + "\n", + " rgraph_adj = torch.sparse.LongTensor(\n", + " rgraph_edge_index,\n", + " torch.ones(rgraph_edge_index.size(1)).long().to(pos.device) * unspecified_type_number,\n", + " torch.Size([N, N]),\n", + " )\n", + "\n", + " composed_adj = (bgraph_adj + rgraph_adj).coalesce() # Sparse (N, N, T)\n", + "\n", + " new_edge_index = composed_adj.indices()\n", + " new_edge_type = composed_adj.values().long()\n", + "\n", + " return new_edge_index, new_edge_type\n", + "\n", + "\n", + "def extend_graph_order_radius(\n", + " num_nodes,\n", + " pos,\n", + " edge_index,\n", + " edge_type,\n", + " batch,\n", + " order=3,\n", + " cutoff=10.0,\n", + " extend_order=True,\n", + " extend_radius=True,\n", + " is_sidechain=None,\n", + "):\n", + " if extend_order:\n", + " edge_index, edge_type = _extend_graph_order(\n", + " num_nodes=num_nodes, edge_index=edge_index, edge_type=edge_type, order=order\n", + " )\n", + "\n", + " if extend_radius:\n", + " edge_index, edge_type = _extend_to_radius_graph(\n", + " pos=pos, edge_index=edge_index, edge_type=edge_type, cutoff=cutoff, batch=batch, is_sidechain=is_sidechain\n", + " )\n", + "\n", + " return edge_index, edge_type\n", + "\n", + "\n", + "def get_distance(pos, edge_index):\n", + " return (pos[edge_index[0]] - pos[edge_index[1]]).norm(dim=-1)\n", + "\n", + "\n", + "def graph_field_network(score_d, pos, edge_index, edge_length):\n", + " \"\"\"\n", + " Transformation to make the epsilon predicted from the diffusion model roto-translational equivariant. 
See equations\n", + " 5-7 of the GeoDiff Paper https://arxiv.org/pdf/2203.02923.pdf\n", + " \"\"\"\n", + " N = pos.size(0)\n", + " dd_dr = (1.0 / edge_length) * (pos[edge_index[0]] - pos[edge_index[1]]) # (E, 3)\n", + " score_pos = scatter_add(dd_dr * score_d, edge_index[0], dim=0, dim_size=N) + scatter_add(\n", + " -dd_dr * score_d, edge_index[1], dim=0, dim_size=N\n", + " ) # (N, 3)\n", + " return score_pos\n", + "\n", + "\n", + "def clip_norm(vec, limit, p=2):\n", + " norm = torch.norm(vec, dim=-1, p=2, keepdim=True)\n", + " denom = torch.where(norm > limit, limit / norm, torch.ones_like(norm))\n", + " return vec * denom\n", + "\n", + "\n", + "def is_local_edge(edge_type):\n", + " return edge_type > 0\n" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "QWrHJFcYXyUB" + }, + "source": [ + "Main model class!" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "MCeZA1qQXzoK" + }, + "outputs": [], + "source": [ + "class MoleculeGNN(ModelMixin, ConfigMixin):\n", + " @register_to_config\n", + " def __init__(\n", + " self,\n", + " hidden_dim=128,\n", + " num_convs=6,\n", + " num_convs_local=4,\n", + " cutoff=10.0,\n", + " mlp_act=\"relu\",\n", + " edge_order=3,\n", + " edge_encoder=\"mlp\",\n", + " smooth_conv=True,\n", + " ):\n", + " super().__init__()\n", + " self.cutoff = cutoff\n", + " self.edge_encoder = edge_encoder\n", + " self.edge_order = edge_order\n", + "\n", + " \"\"\"\n", + " edge_encoder: Takes both edge type and edge length as input and outputs a vector [Note]: node embedding is done\n", + " in SchNetEncoder\n", + " \"\"\"\n", + " self.edge_encoder_global = MLPEdgeEncoder(hidden_dim, mlp_act) # get_edge_encoder(config)\n", + " self.edge_encoder_local = MLPEdgeEncoder(hidden_dim, mlp_act) # get_edge_encoder(config)\n", + "\n", + " \"\"\"\n", + " The graph neural network that extracts node-wise features.\n", + " \"\"\"\n", + " self.encoder_global = SchNetEncoder(\n", + " hidden_channels=hidden_dim,\n", + " num_filters=hidden_dim,\n", + " num_interactions=num_convs,\n", + " edge_channels=self.edge_encoder_global.out_channels,\n", + " cutoff=cutoff,\n", + " smooth=smooth_conv,\n", + " )\n", + " self.encoder_local = GINEncoder(\n", + " hidden_dim=hidden_dim,\n", + " num_convs=num_convs_local,\n", + " )\n", + "\n", + " \"\"\"\n", + " `output_mlp` takes a mixture of two nodewise features and edge features as input and outputs\n", + " gradients w.r.t. 
edge_length (out_dim = 1).\n", + " \"\"\"\n", + " self.grad_global_dist_mlp = MultiLayerPerceptron(\n", + " 2 * hidden_dim, [hidden_dim, hidden_dim // 2, 1], activation=mlp_act\n", + " )\n", + "\n", + " self.grad_local_dist_mlp = MultiLayerPerceptron(\n", + " 2 * hidden_dim, [hidden_dim, hidden_dim // 2, 1], activation=mlp_act\n", + " )\n", + "\n", + " \"\"\"\n", + " Incorporate parameters together\n", + " \"\"\"\n", + " self.model_global = nn.ModuleList([self.edge_encoder_global, self.encoder_global, self.grad_global_dist_mlp])\n", + " self.model_local = nn.ModuleList([self.edge_encoder_local, self.encoder_local, self.grad_local_dist_mlp])\n", + "\n", + " def _forward(\n", + " self,\n", + " atom_type,\n", + " pos,\n", + " bond_index,\n", + " bond_type,\n", + " batch,\n", + " time_step, # NOTE, model trained without timestep performed best\n", + " edge_index=None,\n", + " edge_type=None,\n", + " edge_length=None,\n", + " return_edges=False,\n", + " extend_order=True,\n", + " extend_radius=True,\n", + " is_sidechain=None,\n", + " ):\n", + " \"\"\"\n", + " Args:\n", + " atom_type: Types of atoms, (N, ).\n", + " bond_index: Indices of bonds (not extended, not radius-graph), (2, E).\n", + " bond_type: Bond types, (E, ).\n", + " batch: Node index to graph index, (N, ).\n", + " \"\"\"\n", + " N = atom_type.size(0)\n", + " if edge_index is None or edge_type is None or edge_length is None:\n", + " edge_index, edge_type = extend_graph_order_radius(\n", + " num_nodes=N,\n", + " pos=pos,\n", + " edge_index=bond_index,\n", + " edge_type=bond_type,\n", + " batch=batch,\n", + " order=self.edge_order,\n", + " cutoff=self.cutoff,\n", + " extend_order=extend_order,\n", + " extend_radius=extend_radius,\n", + " is_sidechain=is_sidechain,\n", + " )\n", + " edge_length = get_distance(pos, edge_index).unsqueeze(-1) # (E, 1)\n", + " local_edge_mask = is_local_edge(edge_type) # (E, )\n", + "\n", + " # with the parameterization of NCSNv2\n", + " # DDPM loss implicit handle the noise variance scale conditioning\n", + " sigma_edge = torch.ones(size=(edge_index.size(1), 1), device=pos.device) # (E, 1)\n", + "\n", + " # Encoding global\n", + " edge_attr_global = self.edge_encoder_global(edge_length=edge_length, edge_type=edge_type) # Embed edges\n", + "\n", + " # Global\n", + " node_attr_global = self.encoder_global(\n", + " z=atom_type,\n", + " edge_index=edge_index,\n", + " edge_length=edge_length,\n", + " edge_attr=edge_attr_global,\n", + " )\n", + " # Assemble pairwise features\n", + " h_pair_global = assemble_atom_pair_feature(\n", + " node_attr=node_attr_global,\n", + " edge_index=edge_index,\n", + " edge_attr=edge_attr_global,\n", + " ) # (E_global, 2H)\n", + " # Invariant features of edges (radius graph, global)\n", + " edge_inv_global = self.grad_global_dist_mlp(h_pair_global) * (1.0 / sigma_edge) # (E_global, 1)\n", + "\n", + " # Encoding local\n", + " edge_attr_local = self.edge_encoder_global(edge_length=edge_length, edge_type=edge_type) # Embed edges\n", + " # edge_attr += temb_edge\n", + "\n", + " # Local\n", + " node_attr_local = self.encoder_local(\n", + " z=atom_type,\n", + " edge_index=edge_index[:, local_edge_mask],\n", + " edge_attr=edge_attr_local[local_edge_mask],\n", + " )\n", + " # Assemble pairwise features\n", + " h_pair_local = assemble_atom_pair_feature(\n", + " node_attr=node_attr_local,\n", + " edge_index=edge_index[:, local_edge_mask],\n", + " edge_attr=edge_attr_local[local_edge_mask],\n", + " ) # (E_local, 2H)\n", + "\n", + " # Invariant features of edges (bond graph, local)\n", + " if 
isinstance(sigma_edge, torch.Tensor):\n", + " edge_inv_local = self.grad_local_dist_mlp(h_pair_local) * (\n", + " 1.0 / sigma_edge[local_edge_mask]\n", + " ) # (E_local, 1)\n", + " else:\n", + " edge_inv_local = self.grad_local_dist_mlp(h_pair_local) * (1.0 / sigma_edge) # (E_local, 1)\n", + "\n", + " if return_edges:\n", + " return edge_inv_global, edge_inv_local, edge_index, edge_type, edge_length, local_edge_mask\n", + " else:\n", + " return edge_inv_global, edge_inv_local\n", + "\n", + " def forward(\n", + " self,\n", + " sample,\n", + " timestep: Union[torch.Tensor, float, int],\n", + " return_dict: bool = True,\n", + " sigma=1.0,\n", + " global_start_sigma=0.5,\n", + " w_global=1.0,\n", + " extend_order=False,\n", + " extend_radius=True,\n", + " clip_local=None,\n", + " clip_global=1000.0,\n", + " ) -> Union[MoleculeGNNOutput, Tuple]:\n", + " r\"\"\"\n", + " Args:\n", + " sample: packed torch geometric object\n", + " timestep (`torch.Tensor` or `float` or `int): TODO verify type and shape (batch) timesteps\n", + " return_dict (`bool`, *optional*, defaults to `True`):\n", + " Whether or not to return a [`~models.molecule_gnn.MoleculeGNNOutput`] instead of a plain tuple.\n", + " Returns:\n", + " [`~models.molecule_gnn.MoleculeGNNOutput`] or `tuple`: [`~models.molecule_gnn.MoleculeGNNOutput`] if\n", + " `return_dict` is True, otherwise a `tuple`. When returning a tuple, the first element is the sample tensor.\n", + " \"\"\"\n", + "\n", + " # unpack sample\n", + " atom_type = sample.atom_type\n", + " bond_index = sample.edge_index\n", + " bond_type = sample.edge_type\n", + " num_graphs = sample.num_graphs\n", + " pos = sample.pos\n", + "\n", + " timesteps = torch.full(size=(num_graphs,), fill_value=timestep, dtype=torch.long, device=pos.device)\n", + "\n", + " edge_inv_global, edge_inv_local, edge_index, edge_type, edge_length, local_edge_mask = self._forward(\n", + " atom_type=atom_type,\n", + " pos=sample.pos,\n", + " bond_index=bond_index,\n", + " bond_type=bond_type,\n", + " batch=sample.batch,\n", + " time_step=timesteps,\n", + " return_edges=True,\n", + " extend_order=extend_order,\n", + " extend_radius=extend_radius,\n", + " ) # (E_global, 1), (E_local, 1)\n", + "\n", + " # Important equation in the paper for equivariant features - eqns 5-7 of GeoDiff\n", + " node_eq_local = graph_field_network(\n", + " edge_inv_local, pos, edge_index[:, local_edge_mask], edge_length[local_edge_mask]\n", + " )\n", + " if clip_local is not None:\n", + " node_eq_local = clip_norm(node_eq_local, limit=clip_local)\n", + "\n", + " # Global\n", + " if sigma < global_start_sigma:\n", + " edge_inv_global = edge_inv_global * (1 - local_edge_mask.view(-1, 1).float())\n", + " node_eq_global = graph_field_network(edge_inv_global, pos, edge_index, edge_length)\n", + " node_eq_global = clip_norm(node_eq_global, limit=clip_global)\n", + " else:\n", + " node_eq_global = 0\n", + "\n", + " # Sum\n", + " eps_pos = node_eq_local + node_eq_global * w_global\n", + "\n", + " if not return_dict:\n", + " return (-eps_pos,)\n", + "\n", + " return MoleculeGNNOutput(sample=torch.Tensor(-eps_pos).to(pos.device))" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "CCIrPYSJj9wd" + }, + "source": [ + "### Load pretrained model" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "YdrAr6Ch--Ab" + }, + "source": [ + "#### Load a model\n", + "The model used is a design an\n", + "equivariant convolutional layer, named graph field network (GFN).\n", + "\n", + "The warning about `betas` and `alphas` can be 
ignored, those were moved to the scheduler." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/", + "height": 172, + "referenced_widgets": [ + "d90f304e9560472eacfbdd11e46765eb", + "1c6246f15b654f4daa11c9bcf997b78c", + "c2321b3bff6f490ca12040a20308f555", + "b7feb522161f4cf4b7cc7c1a078ff12d", + "e2d368556e494ae7ae4e2e992af2cd4f", + "bbef741e76ec41b7ab7187b487a383df", + "561f742d418d4721b0670cc8dd62e22c", + "872915dd1bb84f538c44e26badabafdd", + "d022575f1fa2446d891650897f187b4d", + "fdc393f3468c432aa0ada05e238a5436", + "2c9362906e4b40189f16d14aa9a348da", + "6010fc8daa7a44d5aec4b830ec2ebaa1", + "7e0bb1b8d65249d3974200686b193be2", + "ba98aa6d6a884e4ab8bbb5dfb5e4cf7a", + "6526646be5ed415c84d1245b040e629b", + "24d31fc3576e43dd9f8301d2ef3a37ab", + "2918bfaadc8d4b1a9832522c40dfefb8", + "a4bfdca35cc54dae8812720f1b276a08", + "e4901541199b45c6a18824627692fc39", + "f915cf874246446595206221e900b2fe", + "a9e388f22a9742aaaf538e22575c9433", + "42f6c3db29d7484ba6b4f73590abd2f4" + ] + }, + "id": "DyCo0nsqjbml", + "outputId": "d6bce9d5-c51e-43a4-e680-e1e81bdfaf45" + }, + "outputs": [ + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "d90f304e9560472eacfbdd11e46765eb", + "version_major": 2, + "version_minor": 0 }, - "outputs": [ - { - "output_type": "stream", - "name": "stdout", - "text": [ - "/content\n", - "Cloning into 'diffusers'...\n", - "remote: Enumerating objects: 9298, done.\u001b[K\n", - "remote: Counting objects: 100% (40/40), done.\u001b[K\n", - "remote: Compressing objects: 100% (23/23), done.\u001b[K\n", - "remote: Total 9298 (delta 17), reused 23 (delta 11), pack-reused 9258\u001b[K\n", - "Receiving objects: 100% (9298/9298), 7.38 MiB | 5.28 MiB/s, done.\n", - "Resolving deltas: 100% (6168/6168), done.\n", - " Installing build dependencies ... \u001b[?25l\u001b[?25hdone\n", - " Getting requirements to build wheel ... \u001b[?25l\u001b[?25hdone\n", - " Preparing metadata (pyproject.toml) ... \u001b[?25l\u001b[?25hdone\n", - "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m757.0/757.0 kB\u001b[0m \u001b[31m52.8 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", - "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m163.5/163.5 kB\u001b[0m \u001b[31m21.9 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", - "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m40.8/40.8 kB\u001b[0m \u001b[31m5.5 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", - "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m596.3/596.3 kB\u001b[0m \u001b[31m51.7 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", - "\u001b[?25h Building wheel for diffusers (pyproject.toml) ... \u001b[?25l\u001b[?25hdone\n", - "\u001b[33mWARNING: Running pip as the 'root' user can result in broken permissions and conflicting behaviour with the system package manager. 
It is recommended to use a virtual environment instead: https://pip.pypa.io/warnings/venv\u001b[0m\u001b[33m\n", - "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m432.7/432.7 kB\u001b[0m \u001b[31m36.8 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", - "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m5.3/5.3 MB\u001b[0m \u001b[31m90.9 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", - "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m35.3/35.3 MB\u001b[0m \u001b[31m39.7 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", - "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m115.1/115.1 kB\u001b[0m \u001b[31m16.3 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", - "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m948.0/948.0 kB\u001b[0m \u001b[31m63.6 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", - "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m212.2/212.2 kB\u001b[0m \u001b[31m21.5 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", - "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m95.8/95.8 kB\u001b[0m \u001b[31m12.8 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", - "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m140.8/140.8 kB\u001b[0m \u001b[31m18.8 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", - "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m7.6/7.6 MB\u001b[0m \u001b[31m104.3 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", - "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m148.0/148.0 kB\u001b[0m \u001b[31m20.8 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", - "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m231.3/231.3 kB\u001b[0m \u001b[31m30.0 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", - "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m94.8/94.8 kB\u001b[0m \u001b[31m14.0 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", - "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m58.8/58.8 kB\u001b[0m \u001b[31m8.0 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", - "\u001b[?25h\u001b[33mWARNING: Running pip as the 'root' user can result in broken permissions and conflicting behaviour with the system package manager. It is recommended to use a virtual environment instead: https://pip.pypa.io/warnings/venv\u001b[0m\u001b[33m\n", - "\u001b[0m" - ] - } - ], - "source": [ - "%cd /content\n", - "\n", - "# install latest HF diffusers (will update to the release once added)\n", - "!git clone https://github.com/huggingface/diffusers.git\n", - "!pip install -q /content/diffusers\n", - "\n", - "# dependencies for diffusers\n", - "!pip install -q datasets transformers" + "text/plain": [ + "Downloading: 0%| | 0.00/3.27M [00:00] 124.78K 180KB/s in 0.7s \n", + "\n", + "2022-10-12 18:32:20 (180 KB/s) - ‘molecules.pkl’ saved [127774/127774]\n", + "\n" + ] + } + ], + "source": [ + "import torch\n", + "\n", + "\n", + "!wget https://huggingface.co/datasets/fusing/geodiff-example-data/resolve/main/data/molecules.pkl\n", + "dataset = torch.load('/content/molecules.pkl')" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "QZcmy1EvKQRk" + }, + "source": [ + "Print out one entry of the dataset, it contains molecular formulas, atom types, positions, and more." 
+ ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/" }, + "id": "JVjz6iH_H6Eh", + "outputId": "898cb0cf-a0b3-411b-fd4c-bea1fbfd17fe" + }, + "outputs": [ { - "cell_type": "code", - "execution_count": null, - "metadata": { - "id": "gZt7BNi1e1PA", - "colab": { - "base_uri": "https://localhost:8080/", - "height": 53 - }, - "outputId": "a0e1832c-9c02-49aa-cff8-1339e6cdc889" - }, - "outputs": [ - { - "output_type": "stream", - "name": "stdout", - "text": [ - "True\n" - ] - }, - { - "output_type": "execute_result", - "data": { - "text/plain": [ - "'1.8.2'" - ], - "application/vnd.google.colaboratory.intrinsic+json": { - "type": "string" - } - }, - "metadata": {}, - "execution_count": 8 - } - ], - "source": [ - "import torch\n", - "print(torch.cuda.is_available())\n", - "torch.__version__" + "data": { + "text/plain": [ + "Data(atom_type=[51], bond_edge_index=[2, 108], edge_index=[2, 598], edge_order=[598], edge_type=[598], idx=[1], is_bond=[598], num_nodes_per_graph=[1], num_pos_ref=[1], nx=, pos=[51, 3], pos_ref=[255, 3], rdmol=, smiles=\"CC1CCCN(C(=O)C2CCN(S(=O)(=O)c3cccc4nonc34)CC2)C1\")" ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "KLE7CqlfJNUO" - }, - "source": [ - "### Install Chemistry-specific Dependencies\n", - "\n", - "Install RDKit, a tool for working with and visualizing chemsitry in python (you use this to visualize the generate models later)." + }, + "execution_count": 20, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "dataset[0]" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "vHNiZAUxNgoy" + }, + "source": [ + "## Run the diffusion process" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "jZ1KZrxKqENg" + }, + "source": [ + "#### Helper Functions" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "s240tYueqKKf" + }, + "outputs": [], + "source": [ + "import copy\n", + "import os\n", + "\n", + "from torch_geometric.data import Batch, Data\n", + "from torch_scatter import scatter_mean\n", + "from tqdm import tqdm\n", + "\n", + "\n", + "def repeat_data(data: Data, num_repeat) -> Batch:\n", + " datas = [copy.deepcopy(data) for i in range(num_repeat)]\n", + " return Batch.from_data_list(datas)\n", + "\n", + "def repeat_batch(batch: Batch, num_repeat) -> Batch:\n", + " datas = batch.to_data_list()\n", + " new_data = []\n", + " for i in range(num_repeat):\n", + " new_data += copy.deepcopy(datas)\n", + " return Batch.from_data_list(new_data)" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "AMnQTk0eqT7Z" + }, + "source": [ + "#### Constants" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "WYGkzqgzrHmF" + }, + "outputs": [], + "source": [ + "num_samples = 1 # solutions per molecule\n", + "num_molecules = 3\n", + "\n", + "DEVICE = 'cuda'\n", + "sampling_type = 'ddpm_noisy' #'' # paper also uses \"generalize\" and \"ld\"\n", + "# constants for inference\n", + "w_global = 0.5 #0,.3 for qm9\n", + "global_start_sigma = 0.5\n", + "eta = 1.0\n", + "clip_local = None\n", + "clip_pos = None\n", + "\n", + "# constands for data handling\n", + "save_traj = False\n", + "save_data = False\n", + "output_dir = '/content/'" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "-xD5bJ3SqM7t" + }, + "source": [ + "#### Generate samples!\n", + "Note that the 3d representation of a molecule is referred to as the **conformation**" + ] + }, + { + 
"cell_type": "code", + "execution_count": null, + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/" + }, + "id": "x9xuLUNg26z1", + "outputId": "236d2a60-09ed-4c4d-97c1-6e3c0f2d26c4" + }, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "/usr/local/lib/python3.7/dist-packages/ipykernel_launcher.py:4: UserWarning: To copy construct from a tensor, it is recommended to use sourceTensor.clone().detach() or sourceTensor.clone().detach().requires_grad_(True), rather than torch.tensor(sourceTensor).\n", + " after removing the cwd from sys.path.\n", + "100%|██████████| 5/5 [00:55<00:00, 11.06s/it]\n" + ] + } + ], + "source": [ + "results = []\n", + "\n", + "# define sigmas\n", + "sigmas = torch.tensor(1.0 - scheduler.alphas_cumprod).sqrt() / torch.tensor(scheduler.alphas_cumprod).sqrt()\n", + "sigmas = sigmas.to(DEVICE)\n", + "\n", + "for count, data in enumerate(tqdm(dataset)):\n", + " num_samples = max(data.pos_ref.size(0) // data.num_nodes, 1)\n", + "\n", + " data_input = data.clone()\n", + " data_input['pos_ref'] = None\n", + " batch = repeat_data(data_input, num_samples).to(DEVICE)\n", + "\n", + " # initial configuration\n", + " pos_init = torch.randn(batch.num_nodes, 3).to(DEVICE)\n", + "\n", + " # for logging animation of denoising\n", + " pos_traj = []\n", + " with torch.no_grad():\n", + "\n", + " # scale initial sample\n", + " pos = pos_init * sigmas[-1]\n", + " for t in scheduler.timesteps:\n", + " batch.pos = pos\n", + "\n", + " # generate geometry with model, then filter it\n", + " epsilon = model.forward(batch, t, sigma=sigmas[t], return_dict=False)[0]\n", + "\n", + " # Update\n", + " reconstructed_pos = scheduler.step(epsilon, t, pos)[\"prev_sample\"].to(DEVICE)\n", + "\n", + " pos = reconstructed_pos\n", + "\n", + " if torch.isnan(pos).any():\n", + " print(\"NaN detected. Please restart.\")\n", + " raise FloatingPointError()\n", + "\n", + " # recenter graph of positions for next iteration\n", + " pos = pos - scatter_mean(pos, batch.batch, dim=0)[batch.batch]\n", + "\n", + " # optional clipping\n", + " if clip_pos is not None:\n", + " pos = torch.clamp(pos, min=-clip_pos, max=clip_pos)\n", + " pos_traj.append(pos.clone().cpu())\n", + "\n", + " pos_gen = pos.cpu()\n", + " if save_traj:\n", + " pos_gen_traj = pos_traj.cpu()\n", + " data.pos_gen = torch.stack(pos_gen_traj)\n", + " else:\n", + " data.pos_gen = pos_gen\n", + " results.append(data)\n", + "\n", + "\n", + "if save_data:\n", + " save_path = os.path.join(output_dir, 'samples_all.pkl')\n", + "\n", + " with open(save_path, 'wb') as f:\n", + " pickle.dump(results, f)" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "fSApwSaZNndW" + }, + "source": [ + "## Render the results!" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "d47Zxo2OKdgZ" + }, + "source": [ + "This function allows us to render 3d in colab." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "e9Cd0kCAv9b8" + }, + "outputs": [], + "source": [ + "from google.colab import output\n", + "\n", + "\n", + "output.enable_custom_widget_manager()" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "RjaVuR15NqzF" + }, + "source": [ + "### Helper functions" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "28rBYa9NKhlz" + }, + "source": [ + "Here is a helper function for copying the generated tensors into a format used by RDKit & NGLViewer." 
+ ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "LKdKdwxcyTQ6" + }, + "outputs": [], + "source": [ + "from copy import deepcopy\n", + "\n", + "\n", + "def set_rdmol_positions(rdkit_mol, pos):\n", + " \"\"\"\n", + " Args:\n", + " rdkit_mol: An `rdkit.Chem.rdchem.Mol` object.\n", + " pos: (N_atoms, 3)\n", + " \"\"\"\n", + " mol = deepcopy(rdkit_mol)\n", + " set_rdmol_positions_(mol, pos)\n", + " return mol\n", + "\n", + "def set_rdmol_positions_(mol, pos):\n", + " \"\"\"\n", + " Args:\n", + " rdkit_mol: An `rdkit.Chem.rdchem.Mol` object.\n", + " pos: (N_atoms, 3)\n", + " \"\"\"\n", + " for i in range(pos.shape[0]):\n", + " mol.GetConformer(0).SetAtomPosition(i, pos[i].tolist())\n", + " return mol\n" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "NuE10hcpKmzK" + }, + "source": [ + "Process the generated data to make it easy to view." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/" + }, + "id": "KieVE1vc0_Vs", + "outputId": "6faa185d-b1bc-47e8-be18-30d1e557e7c8" + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "collect 5 generated molecules in `mols`\n" + ] + } + ], + "source": [ + "# the model can generate multiple conformations per 2d geometry\n", + "num_gen = results[0]['pos_gen'].shape[0]\n", + "\n", + "# init storage objects\n", + "mols_gen = []\n", + "mols_orig = []\n", + "for to_process in results:\n", + "\n", + " # store the reference 3d position\n", + " to_process['pos_ref'] = to_process['pos_ref'].reshape(-1, to_process['rdmol'].GetNumAtoms(), 3)\n", + "\n", + " # store the generated 3d position\n", + " to_process['pos_gen'] = to_process['pos_gen'].reshape(-1, to_process['rdmol'].GetNumAtoms(), 3)\n", + "\n", + " # copy data to new object\n", + " new_mol = set_rdmol_positions(to_process.rdmol, to_process['pos_gen'][0])\n", + "\n", + " # append results\n", + " mols_gen.append(new_mol)\n", + " mols_orig.append(to_process.rdmol)\n", + "\n", + "print(f\"collect {len(mols_gen)} generated molecules in `mols`\")" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "tin89JwMKp4v" + }, + "source": [ + "Import tools to visualize the 2d chemical diagram of the molecule." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "yqV6gllSZn38" + }, + "outputs": [], + "source": [ + "from IPython.display import SVG, display\n", + "from rdkit import Chem\n", + "from rdkit.Chem.Draw import rdMolDraw2D as MD2" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "TFNKmGddVoOk" + }, + "source": [ + "Select molecule to visualize" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "KzuwLlrrVaGc" + }, + "outputs": [], + "source": [ + "idx = 0\n", + "assert idx < len(results), \"selected molecule that was not generated\"" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "hkb8w0_SNtU8" + }, + "source": [ + "### Viewing" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "I3R4QBQeKttN" + }, + "source": [ + "This 2D rendering is the equivalent of the **input to the model**!" 
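+    "\n",
+    "The cell below draws `dataset[0]`; if you picked a different `idx` above, swap in `dataset[idx]['smiles']` when building the RDKit molecule, e.g.:\n",
+    "\n",
+    "```python\n",
+    "mc = Chem.MolFromSmiles(dataset[idx]['smiles'])\n",
+    "```"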
+ ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/", + "height": 321 + }, + "id": "gkQRWjraaKex", + "outputId": "9c3d1a91-a51d-475d-9e34-2be2459abc47" + }, + "outputs": [ + { + "data": { + "image/svg+xml": "\n\n \n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n", + "text/plain": [ + "" ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "id": "0CPv_NvehRz3", - "colab": { - "base_uri": "https://localhost:8080/" - }, - "outputId": "6ee0ae4e-4511-4816-de29-22b1c21d49bc" + }, + "metadata": {}, + "output_type": "display_data" + } + ], + "source": [ + "mc = Chem.MolFromSmiles(dataset[0]['smiles'])\n", + "molSize=(450,300)\n", + "drawer = MD2.MolDraw2DSVG(molSize[0],molSize[1])\n", + "drawer.DrawMolecule(mc)\n", + "drawer.FinishDrawing()\n", + "svg = drawer.GetDrawingText()\n", + "display(SVG(svg.replace('svg:','')))" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "z4FDMYMxKw2I" + }, + "source": [ + "Generate the 3d molecule!" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/", + "height": 17, + "referenced_widgets": [ + "695ab5bbf30a4ab19df1f9f33469f314", + "eac6a8dcdc9d4335a2e51031793ead29" + ] + }, + "id": "aT1Bkb8YxJfV", + "outputId": "b98870ae-049d-4386-b676-166e9526bda2" + }, + "outputs": [ + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "695ab5bbf30a4ab19df1f9f33469f314", + "version_major": 2, + "version_minor": 0 }, - "outputs": [ - { - "output_type": "stream", - "name": "stdout", - "text": [ - "Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/\n", - "Collecting rdkit\n", - " Downloading rdkit-2022.3.5-cp37-cp37m-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (36.8 MB)\n", - "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m36.8/36.8 MB\u001b[0m \u001b[31m34.6 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", - "\u001b[?25hRequirement already satisfied: Pillow in /usr/local/lib/python3.7/site-packages (from rdkit) (9.2.0)\n", - "Requirement already satisfied: numpy in /usr/local/lib/python3.7/site-packages (from rdkit) (1.21.6)\n", - "Installing collected packages: rdkit\n", - "Successfully installed rdkit-2022.3.5\n", - "\u001b[33mWARNING: Running pip as the 'root' user can result in broken permissions and conflicting behaviour with the system package manager. 
It is recommended to use a virtual environment instead: https://pip.pypa.io/warnings/venv\u001b[0m\u001b[33m\n", - "\u001b[0m" - ] + "text/plain": [] + }, + "metadata": { + "application/vnd.jupyter.widget-view+json": { + "colab": { + "custom_widget_manager": { + "url": "https://ssl.gstatic.com/colaboratory-static/widgets/colab-cdn-widget-manager/d2e234f7cc04bf79/manager.min.js" } - ], - "source": [ - "!pip install rdkit" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "88GaDbDPxJ5I" + } + } + }, + "output_type": "display_data" + } + ], + "source": [ + "from nglview import show_rdkit as show" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/", + "height": 337, + "referenced_widgets": [ + "be446195da2b4ff2aec21ec5ff963a54", + "c6596896148b4a8a9c57963b67c7782f", + "2489b5e5648541fbbdceadb05632a050", + "01e0ba4e5da04914b4652b8d58565d7b", + "c30e6c2f3e2a44dbbb3d63bd519acaa4", + "f31c6e40e9b2466a9064a2669933ecd5", + "19308ccac642498ab8b58462e3f1b0bb", + "4a081cdc2ec3421ca79dd933b7e2b0c4", + "e5c0d75eb5e1447abd560c8f2c6017e1", + "5146907ef6764654ad7d598baebc8b58", + "144ec959b7604a2cabb5ca46ae5e5379", + "abce2a80e6304df3899109c6d6cac199", + "65195cb7a4134f4887e9dd19f3676462" + ] + }, + "id": "pxtq8I-I18C-", + "outputId": "72ed63ac-d2ec-4f5c-a0b1-4e7c1840a4e7" + }, + "outputs": [ + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "be446195da2b4ff2aec21ec5ff963a54", + "version_major": 2, + "version_minor": 0 }, - "source": [ - "### Get viewer from nglview\n", - "\n", - "The model you will use outputs a position matrix tensor. This pytorch geometric data object will have many features (positions, known features, edge features -- all tensors).\n", - "The data we give to the model will also have a rdmol object (which can extract features to geometric if needed).\n", - "The rdmol in this object is a source of ground truth for the generated molecules.\n", - "\n", - "You will use one rendering function from nglviewer later!\n", - "\n" + "text/plain": [ + "NGLWidget()" ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "id": "jcl8GCS2mz6t", - "colab": { - "base_uri": "https://localhost:8080/", - "height": 1000 - }, - "outputId": "99b5cc40-67bb-4d8e-faa0-47d7cb33e98f" - }, - "outputs": [ - { - "output_type": "stream", - "name": "stdout", - "text": [ - "Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/\n", - "Collecting nglview\n", - " Downloading nglview-3.0.3.tar.gz (5.7 MB)\n", - "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m5.7/5.7 MB\u001b[0m \u001b[31m91.2 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", - "\u001b[?25h Installing build dependencies ... \u001b[?25l\u001b[?25hdone\n", - " Getting requirements to build wheel ... \u001b[?25l\u001b[?25hdone\n", - " Preparing metadata (pyproject.toml) ... 
\u001b[?25l\u001b[?25hdone\n", - "Requirement already satisfied: numpy in /usr/local/lib/python3.7/site-packages (from nglview) (1.21.6)\n", - "Collecting jupyterlab-widgets\n", - " Downloading jupyterlab_widgets-3.0.3-py3-none-any.whl (384 kB)\n", - "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m384.1/384.1 kB\u001b[0m \u001b[31m40.6 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", - "\u001b[?25hCollecting ipywidgets>=7\n", - " Downloading ipywidgets-8.0.2-py3-none-any.whl (134 kB)\n", - "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m134.4/134.4 kB\u001b[0m \u001b[31m21.2 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", - "\u001b[?25hCollecting widgetsnbextension~=4.0\n", - " Downloading widgetsnbextension-4.0.3-py3-none-any.whl (2.0 MB)\n", - "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m2.0/2.0 MB\u001b[0m \u001b[31m84.4 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", - "\u001b[?25hCollecting ipython>=6.1.0\n", - " Downloading ipython-7.34.0-py3-none-any.whl (793 kB)\n", - "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m793.8/793.8 kB\u001b[0m \u001b[31m60.7 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", - "\u001b[?25hCollecting ipykernel>=4.5.1\n", - " Downloading ipykernel-6.16.0-py3-none-any.whl (138 kB)\n", - "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m138.4/138.4 kB\u001b[0m \u001b[31m20.9 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", - "\u001b[?25hCollecting traitlets>=4.3.1\n", - " Downloading traitlets-5.4.0-py3-none-any.whl (107 kB)\n", - "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m107.1/107.1 kB\u001b[0m \u001b[31m17.0 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", - "\u001b[?25hRequirement already satisfied: packaging in /usr/local/lib/python3.7/site-packages (from ipykernel>=4.5.1->ipywidgets>=7->nglview) (21.3)\n", - "Collecting pyzmq>=17\n", - " Downloading pyzmq-24.0.1-cp37-cp37m-manylinux_2_5_x86_64.manylinux1_x86_64.whl (1.1 MB)\n", - "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m1.1/1.1 MB\u001b[0m \u001b[31m68.4 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", - "\u001b[?25hCollecting matplotlib-inline>=0.1\n", - " Downloading matplotlib_inline-0.1.6-py3-none-any.whl (9.4 kB)\n", - "Collecting tornado>=6.1\n", - " Downloading tornado-6.2-cp37-abi3-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl (423 kB)\n", - "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m424.0/424.0 kB\u001b[0m \u001b[31m41.2 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", - "\u001b[?25hCollecting nest-asyncio\n", - " Downloading nest_asyncio-1.5.6-py3-none-any.whl (5.2 kB)\n", - "Collecting debugpy>=1.0\n", - " Downloading debugpy-1.6.3-cp37-cp37m-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_12_x86_64.manylinux2010_x86_64.whl (1.8 MB)\n", - "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m1.8/1.8 MB\u001b[0m \u001b[31m83.4 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", - "\u001b[?25hCollecting psutil\n", - " Downloading psutil-5.9.2-cp37-cp37m-manylinux_2_12_x86_64.manylinux2010_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl (281 kB)\n", - "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m281.3/281.3 kB\u001b[0m \u001b[31m33.1 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", - "\u001b[?25hCollecting 
jupyter-client>=6.1.12\n", - " Downloading jupyter_client-7.4.2-py3-none-any.whl (132 kB)\n", - "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m132.2/132.2 kB\u001b[0m \u001b[31m19.7 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", - "\u001b[?25hCollecting pickleshare\n", - " Downloading pickleshare-0.7.5-py2.py3-none-any.whl (6.9 kB)\n", - "Requirement already satisfied: setuptools>=18.5 in /usr/local/lib/python3.7/site-packages (from ipython>=6.1.0->ipywidgets>=7->nglview) (59.8.0)\n", - "Collecting backcall\n", - " Downloading backcall-0.2.0-py2.py3-none-any.whl (11 kB)\n", - "Collecting pexpect>4.3\n", - " Downloading pexpect-4.8.0-py2.py3-none-any.whl (59 kB)\n", - "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m59.0/59.0 kB\u001b[0m \u001b[31m7.3 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", - "\u001b[?25hCollecting pygments\n", - " Downloading Pygments-2.13.0-py3-none-any.whl (1.1 MB)\n", - "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m1.1/1.1 MB\u001b[0m \u001b[31m70.9 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", - "\u001b[?25hCollecting jedi>=0.16\n", - " Downloading jedi-0.18.1-py2.py3-none-any.whl (1.6 MB)\n", - "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m1.6/1.6 MB\u001b[0m \u001b[31m83.5 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", - "\u001b[?25hCollecting prompt-toolkit!=3.0.0,!=3.0.1,<3.1.0,>=2.0.0\n", - " Downloading prompt_toolkit-3.0.31-py3-none-any.whl (382 kB)\n", - "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m382.3/382.3 kB\u001b[0m \u001b[31m40.0 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", - "\u001b[?25hRequirement already satisfied: decorator in /usr/local/lib/python3.7/site-packages (from ipython>=6.1.0->ipywidgets>=7->nglview) (4.4.2)\n", - "Collecting parso<0.9.0,>=0.8.0\n", - " Downloading parso-0.8.3-py2.py3-none-any.whl (100 kB)\n", - "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m100.8/100.8 kB\u001b[0m \u001b[31m14.7 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", - "\u001b[?25hRequirement already satisfied: python-dateutil>=2.8.2 in /usr/local/lib/python3.7/site-packages (from jupyter-client>=6.1.12->ipykernel>=4.5.1->ipywidgets>=7->nglview) (2.8.2)\n", - "Collecting entrypoints\n", - " Downloading entrypoints-0.4-py3-none-any.whl (5.3 kB)\n", - "Collecting jupyter-core>=4.9.2\n", - " Downloading jupyter_core-4.11.1-py3-none-any.whl (88 kB)\n", - "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m88.4/88.4 kB\u001b[0m \u001b[31m14.0 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", - "\u001b[?25hCollecting ptyprocess>=0.5\n", - " Downloading ptyprocess-0.7.0-py2.py3-none-any.whl (13 kB)\n", - "Collecting wcwidth\n", - " Downloading wcwidth-0.2.5-py2.py3-none-any.whl (30 kB)\n", - "Requirement already satisfied: pyparsing!=3.0.5,>=2.0.2 in /usr/local/lib/python3.7/site-packages (from packaging->ipykernel>=4.5.1->ipywidgets>=7->nglview) (3.0.9)\n", - "Requirement already satisfied: six>=1.5 in /usr/local/lib/python3.7/site-packages (from python-dateutil>=2.8.2->jupyter-client>=6.1.12->ipykernel>=4.5.1->ipywidgets>=7->nglview) (1.16.0)\n", - "Building wheels for collected packages: nglview\n", - " Building wheel for nglview (pyproject.toml) ... 
\u001b[?25l\u001b[?25hdone\n", - " Created wheel for nglview: filename=nglview-3.0.3-py3-none-any.whl size=8057538 sha256=b7e1071bb91822e48515bf27f4e6b197c6e85e06b90912b3439edc8be1e29514\n", - " Stored in directory: /root/.cache/pip/wheels/01/0c/49/c6f79d8edba8fe89752bf20de2d99040bfa57db0548975c5d5\n", - "Successfully built nglview\n", - "Installing collected packages: wcwidth, ptyprocess, pickleshare, backcall, widgetsnbextension, traitlets, tornado, pyzmq, pygments, psutil, prompt-toolkit, pexpect, parso, nest-asyncio, jupyterlab-widgets, entrypoints, debugpy, matplotlib-inline, jupyter-core, jedi, jupyter-client, ipython, ipykernel, ipywidgets, nglview\n", - "Successfully installed backcall-0.2.0 debugpy-1.6.3 entrypoints-0.4 ipykernel-6.16.0 ipython-7.34.0 ipywidgets-8.0.2 jedi-0.18.1 jupyter-client-7.4.2 jupyter-core-4.11.1 jupyterlab-widgets-3.0.3 matplotlib-inline-0.1.6 nest-asyncio-1.5.6 nglview-3.0.3 parso-0.8.3 pexpect-4.8.0 pickleshare-0.7.5 prompt-toolkit-3.0.31 psutil-5.9.2 ptyprocess-0.7.0 pygments-2.13.0 pyzmq-24.0.1 tornado-6.2 traitlets-5.4.0 wcwidth-0.2.5 widgetsnbextension-4.0.3\n", - "\u001b[33mWARNING: Running pip as the 'root' user can result in broken permissions and conflicting behaviour with the system package manager. It is recommended to use a virtual environment instead: https://pip.pypa.io/warnings/venv\u001b[0m\u001b[33m\n", - "\u001b[0m" - ] - }, - { - "output_type": "display_data", - "data": { - "application/vnd.colab-display-data+json": { - "pip_warning": { - "packages": [ - "pexpect", - "pickleshare", - "wcwidth" - ] - } - } - }, - "metadata": {} + }, + "metadata": { + "application/vnd.jupyter.widget-view+json": { + "colab": { + "custom_widget_manager": { + "url": "https://ssl.gstatic.com/colaboratory-static/widgets/colab-cdn-widget-manager/d2e234f7cc04bf79/manager.min.js" } - ], - "source": [ - "!pip install nglview" - ] - }, - { - "cell_type": "markdown", - "source": [ - "## Create a diffusion model" - ], - "metadata": { - "id": "8t8_e_uVLdKB" + } } - }, - { - "cell_type": "markdown", - "source": [ - "### Model class(es)" - ], - "metadata": { - "id": "G0rMncVtNSqU" - } - }, - { - "cell_type": "markdown", - "source": [ - "Imports" - ], - "metadata": { - "id": "L5FEXz5oXkzt" - } - }, - { - "cell_type": "code", - "source": [ - "# Model adapted from GeoDiff https://github.com/MinkaiXu/GeoDiff\n", - "# Model inspired by https://github.com/DeepGraphLearning/torchdrug/tree/master/torchdrug/models\n", - "from dataclasses import dataclass\n", - "from typing import Callable, Tuple, Union\n", - "\n", - "import numpy as np\n", - "import torch\n", - "import torch.nn.functional as F\n", - "from torch import Tensor, nn\n", - "from torch.nn import Embedding, Linear, Module, ModuleList, Sequential\n", - "\n", - "from torch_geometric.nn import MessagePassing, radius, radius_graph\n", - "from torch_geometric.typing import Adj, OptPairTensor, OptTensor, Size\n", - "from torch_geometric.utils import dense_to_sparse, to_dense_adj\n", - "from torch_scatter import scatter_add\n", - "from torch_sparse import SparseTensor, coalesce\n", - "\n", - "from diffusers.configuration_utils import ConfigMixin, register_to_config\n", - "from diffusers.modeling_utils import ModelMixin\n", - "from diffusers.utils import BaseOutput\n" - ], - "metadata": { - "id": "-3-P4w5sXkRU" - }, - "execution_count": null, - "outputs": [] - }, - { - "cell_type": "markdown", - "source": [ - "Helper classes" - ], - "metadata": { - "id": "EzJQXPN_XrMX" - } - }, - { - "cell_type": "code", - "source": [ - 
"@dataclass\n", - "class MoleculeGNNOutput(BaseOutput):\n", - " \"\"\"\n", - " Args:\n", - " sample (`torch.Tensor` of shape `(batch_size, num_channels, height, width)`):\n", - " Hidden states output. Output of last layer of model.\n", - " \"\"\"\n", - "\n", - " sample: torch.Tensor\n", - "\n", - "\n", - "class MultiLayerPerceptron(nn.Module):\n", - " \"\"\"\n", - " Multi-layer Perceptron. Note there is no activation or dropout in the last layer.\n", - " Args:\n", - " input_dim (int): input dimension\n", - " hidden_dim (list of int): hidden dimensions\n", - " activation (str or function, optional): activation function\n", - " dropout (float, optional): dropout rate\n", - " \"\"\"\n", - "\n", - " def __init__(self, input_dim, hidden_dims, activation=\"relu\", dropout=0):\n", - " super(MultiLayerPerceptron, self).__init__()\n", - "\n", - " self.dims = [input_dim] + hidden_dims\n", - " if isinstance(activation, str):\n", - " self.activation = getattr(F, activation)\n", - " else:\n", - " print(f\"Warning, activation passed {activation} is not string and ignored\")\n", - " self.activation = None\n", - " if dropout > 0:\n", - " self.dropout = nn.Dropout(dropout)\n", - " else:\n", - " self.dropout = None\n", - "\n", - " self.layers = nn.ModuleList()\n", - " for i in range(len(self.dims) - 1):\n", - " self.layers.append(nn.Linear(self.dims[i], self.dims[i + 1]))\n", - "\n", - " def forward(self, x):\n", - " \"\"\"\"\"\"\n", - " for i, layer in enumerate(self.layers):\n", - " x = layer(x)\n", - " if i < len(self.layers) - 1:\n", - " if self.activation:\n", - " x = self.activation(x)\n", - " if self.dropout:\n", - " x = self.dropout(x)\n", - " return x\n", - "\n", - "\n", - "class ShiftedSoftplus(torch.nn.Module):\n", - " def __init__(self):\n", - " super(ShiftedSoftplus, self).__init__()\n", - " self.shift = torch.log(torch.tensor(2.0)).item()\n", - "\n", - " def forward(self, x):\n", - " return F.softplus(x) - self.shift\n", - "\n", - "\n", - "class CFConv(MessagePassing):\n", - " def __init__(self, in_channels, out_channels, num_filters, mlp, cutoff, smooth):\n", - " super(CFConv, self).__init__(aggr=\"add\")\n", - " self.lin1 = Linear(in_channels, num_filters, bias=False)\n", - " self.lin2 = Linear(num_filters, out_channels)\n", - " self.nn = mlp\n", - " self.cutoff = cutoff\n", - " self.smooth = smooth\n", - "\n", - " self.reset_parameters()\n", - "\n", - " def reset_parameters(self):\n", - " torch.nn.init.xavier_uniform_(self.lin1.weight)\n", - " torch.nn.init.xavier_uniform_(self.lin2.weight)\n", - " self.lin2.bias.data.fill_(0)\n", - "\n", - " def forward(self, x, edge_index, edge_length, edge_attr):\n", - " if self.smooth:\n", - " C = 0.5 * (torch.cos(edge_length * np.pi / self.cutoff) + 1.0)\n", - " C = C * (edge_length <= self.cutoff) * (edge_length >= 0.0) # Modification: cutoff\n", - " else:\n", - " C = (edge_length <= self.cutoff).float()\n", - " W = self.nn(edge_attr) * C.view(-1, 1)\n", - "\n", - " x = self.lin1(x)\n", - " x = self.propagate(edge_index, x=x, W=W)\n", - " x = self.lin2(x)\n", - " return x\n", - "\n", - " def message(self, x_j: torch.Tensor, W) -> torch.Tensor:\n", - " return x_j * W\n", - "\n", - "\n", - "class InteractionBlock(torch.nn.Module):\n", - " def __init__(self, hidden_channels, num_gaussians, num_filters, cutoff, smooth):\n", - " super(InteractionBlock, self).__init__()\n", - " mlp = Sequential(\n", - " Linear(num_gaussians, num_filters),\n", - " ShiftedSoftplus(),\n", - " Linear(num_filters, num_filters),\n", - " )\n", - " self.conv = 
CFConv(hidden_channels, hidden_channels, num_filters, mlp, cutoff, smooth)\n", - " self.act = ShiftedSoftplus()\n", - " self.lin = Linear(hidden_channels, hidden_channels)\n", - "\n", - " def forward(self, x, edge_index, edge_length, edge_attr):\n", - " x = self.conv(x, edge_index, edge_length, edge_attr)\n", - " x = self.act(x)\n", - " x = self.lin(x)\n", - " return x\n", - "\n", - "\n", - "class SchNetEncoder(Module):\n", - " def __init__(\n", - " self, hidden_channels=128, num_filters=128, num_interactions=6, edge_channels=100, cutoff=10.0, smooth=False\n", - " ):\n", - " super().__init__()\n", - "\n", - " self.hidden_channels = hidden_channels\n", - " self.num_filters = num_filters\n", - " self.num_interactions = num_interactions\n", - " self.cutoff = cutoff\n", - "\n", - " self.embedding = Embedding(100, hidden_channels, max_norm=10.0)\n", - "\n", - " self.interactions = ModuleList()\n", - " for _ in range(num_interactions):\n", - " block = InteractionBlock(hidden_channels, edge_channels, num_filters, cutoff, smooth)\n", - " self.interactions.append(block)\n", - "\n", - " def forward(self, z, edge_index, edge_length, edge_attr, embed_node=True):\n", - " if embed_node:\n", - " assert z.dim() == 1 and z.dtype == torch.long\n", - " h = self.embedding(z)\n", - " else:\n", - " h = z\n", - " for interaction in self.interactions:\n", - " h = h + interaction(h, edge_index, edge_length, edge_attr)\n", - "\n", - " return h\n", - "\n", - "\n", - "class GINEConv(MessagePassing):\n", - " \"\"\"\n", - " Custom class of the graph isomorphism operator from the \"How Powerful are Graph Neural Networks?\n", - " https://arxiv.org/abs/1810.00826 paper. Note that this implementation has the added option of a custom activation.\n", - " \"\"\"\n", - "\n", - " def __init__(self, mlp: Callable, eps: float = 0.0, train_eps: bool = False, activation=\"softplus\", **kwargs):\n", - " super(GINEConv, self).__init__(aggr=\"add\", **kwargs)\n", - " self.nn = mlp\n", - " self.initial_eps = eps\n", - "\n", - " if isinstance(activation, str):\n", - " self.activation = getattr(F, activation)\n", - " else:\n", - " self.activation = None\n", - "\n", - " if train_eps:\n", - " self.eps = torch.nn.Parameter(torch.Tensor([eps]))\n", - " else:\n", - " self.register_buffer(\"eps\", torch.Tensor([eps]))\n", - "\n", - " def forward(\n", - " self, x: Union[Tensor, OptPairTensor], edge_index: Adj, edge_attr: OptTensor = None, size: Size = None\n", - " ) -> torch.Tensor:\n", - " \"\"\"\"\"\"\n", - " if isinstance(x, torch.Tensor):\n", - " x: OptPairTensor = (x, x)\n", - "\n", - " # Node and edge feature dimensionalites need to match.\n", - " if isinstance(edge_index, torch.Tensor):\n", - " assert edge_attr is not None\n", - " assert x[0].size(-1) == edge_attr.size(-1)\n", - " elif isinstance(edge_index, SparseTensor):\n", - " assert x[0].size(-1) == edge_index.size(-1)\n", - "\n", - " # propagate_type: (x: OptPairTensor, edge_attr: OptTensor)\n", - " out = self.propagate(edge_index, x=x, edge_attr=edge_attr, size=size)\n", - "\n", - " x_r = x[1]\n", - " if x_r is not None:\n", - " out += (1 + self.eps) * x_r\n", - "\n", - " return self.nn(out)\n", - "\n", - " def message(self, x_j: torch.Tensor, edge_attr: torch.Tensor) -> torch.Tensor:\n", - " if self.activation:\n", - " return self.activation(x_j + edge_attr)\n", - " else:\n", - " return x_j + edge_attr\n", - "\n", - " def __repr__(self):\n", - " return \"{}(nn={})\".format(self.__class__.__name__, self.nn)\n", - "\n", - "\n", - "class GINEncoder(torch.nn.Module):\n", - " def 
__init__(self, hidden_dim, num_convs=3, activation=\"relu\", short_cut=True, concat_hidden=False):\n", - " super().__init__()\n", - "\n", - " self.hidden_dim = hidden_dim\n", - " self.num_convs = num_convs\n", - " self.short_cut = short_cut\n", - " self.concat_hidden = concat_hidden\n", - " self.node_emb = nn.Embedding(100, hidden_dim)\n", - "\n", - " if isinstance(activation, str):\n", - " self.activation = getattr(F, activation)\n", - " else:\n", - " self.activation = None\n", - "\n", - " self.convs = nn.ModuleList()\n", - " for i in range(self.num_convs):\n", - " self.convs.append(\n", - " GINEConv(\n", - " MultiLayerPerceptron(hidden_dim, [hidden_dim, hidden_dim], activation=activation),\n", - " activation=activation,\n", - " )\n", - " )\n", - "\n", - " def forward(self, z, edge_index, edge_attr):\n", - " \"\"\"\n", - " Input:\n", - " data: (torch_geometric.data.Data): batched graph edge_index: bond indices of the original graph (num_node,\n", - " hidden) edge_attr: edge feature tensor with shape (num_edge, hidden)\n", - " Output:\n", - " node_feature: graph feature\n", - " \"\"\"\n", - "\n", - " node_attr = self.node_emb(z) # (num_node, hidden)\n", - "\n", - " hiddens = []\n", - " conv_input = node_attr # (num_node, hidden)\n", - "\n", - " for conv_idx, conv in enumerate(self.convs):\n", - " hidden = conv(conv_input, edge_index, edge_attr)\n", - " if conv_idx < len(self.convs) - 1 and self.activation is not None:\n", - " hidden = self.activation(hidden)\n", - " assert hidden.shape == conv_input.shape\n", - " if self.short_cut and hidden.shape == conv_input.shape:\n", - " hidden += conv_input\n", - "\n", - " hiddens.append(hidden)\n", - " conv_input = hidden\n", - "\n", - " if self.concat_hidden:\n", - " node_feature = torch.cat(hiddens, dim=-1)\n", - " else:\n", - " node_feature = hiddens[-1]\n", - "\n", - " return node_feature\n", - "\n", - "\n", - "class MLPEdgeEncoder(Module):\n", - " def __init__(self, hidden_dim=100, activation=\"relu\"):\n", - " super().__init__()\n", - " self.hidden_dim = hidden_dim\n", - " self.bond_emb = Embedding(100, embedding_dim=self.hidden_dim)\n", - " self.mlp = MultiLayerPerceptron(1, [self.hidden_dim, self.hidden_dim], activation=activation)\n", - "\n", - " @property\n", - " def out_channels(self):\n", - " return self.hidden_dim\n", - "\n", - " def forward(self, edge_length, edge_type):\n", - " \"\"\"\n", - " Input:\n", - " edge_length: The length of edges, shape=(E, 1). edge_type: The type pf edges, shape=(E,)\n", - " Returns:\n", - " edge_attr: The representation of edges. (E, 2 * num_gaussians)\n", - " \"\"\"\n", - " d_emb = self.mlp(edge_length) # (num_edge, hidden_dim)\n", - " edge_attr = self.bond_emb(edge_type) # (num_edge, hidden_dim)\n", - " return d_emb * edge_attr # (num_edge, hidden)\n", - "\n", - "\n", - "def assemble_atom_pair_feature(node_attr, edge_index, edge_attr):\n", - " h_row, h_col = node_attr[edge_index[0]], node_attr[edge_index[1]]\n", - " h_pair = torch.cat([h_row * h_col, edge_attr], dim=-1) # (E, 2H)\n", - " return h_pair\n", - "\n", - "\n", - "def _extend_graph_order(num_nodes, edge_index, edge_type, order=3):\n", - " \"\"\"\n", - " Args:\n", - " num_nodes: Number of atoms.\n", - " edge_index: Bond indices of the original graph.\n", - " edge_type: Bond types of the original graph.\n", - " order: Extension order.\n", - " Returns:\n", - " new_edge_index: Extended edge indices. 
new_edge_type: Extended edge types.\n", - " \"\"\"\n", - "\n", - " def binarize(x):\n", - " return torch.where(x > 0, torch.ones_like(x), torch.zeros_like(x))\n", - "\n", - " def get_higher_order_adj_matrix(adj, order):\n", - " \"\"\"\n", - " Args:\n", - " adj: (N, N)\n", - " type_mat: (N, N)\n", - " Returns:\n", - " Following attributes will be updated:\n", - " - edge_index\n", - " - edge_type\n", - " Following attributes will be added to the data object:\n", - " - bond_edge_index: Original edge_index.\n", - " \"\"\"\n", - " adj_mats = [\n", - " torch.eye(adj.size(0), dtype=torch.long, device=adj.device),\n", - " binarize(adj + torch.eye(adj.size(0), dtype=torch.long, device=adj.device)),\n", - " ]\n", - "\n", - " for i in range(2, order + 1):\n", - " adj_mats.append(binarize(adj_mats[i - 1] @ adj_mats[1]))\n", - " order_mat = torch.zeros_like(adj)\n", - "\n", - " for i in range(1, order + 1):\n", - " order_mat += (adj_mats[i] - adj_mats[i - 1]) * i\n", - "\n", - " return order_mat\n", - "\n", - " num_types = 22\n", - " # given from len(BOND_TYPES), where BOND_TYPES = {t: i for i, t in enumerate(BT.names.values())}\n", - " # from rdkit.Chem.rdchem import BondType as BT\n", - " N = num_nodes\n", - " adj = to_dense_adj(edge_index).squeeze(0)\n", - " adj_order = get_higher_order_adj_matrix(adj, order) # (N, N)\n", - "\n", - " type_mat = to_dense_adj(edge_index, edge_attr=edge_type).squeeze(0) # (N, N)\n", - " type_highorder = torch.where(adj_order > 1, num_types + adj_order - 1, torch.zeros_like(adj_order))\n", - " assert (type_mat * type_highorder == 0).all()\n", - " type_new = type_mat + type_highorder\n", - "\n", - " new_edge_index, new_edge_type = dense_to_sparse(type_new)\n", - " _, edge_order = dense_to_sparse(adj_order)\n", - "\n", - " # data.bond_edge_index = data.edge_index # Save original edges\n", - " new_edge_index, new_edge_type = coalesce(new_edge_index, new_edge_type.long(), N, N) # modify data\n", - "\n", - " return new_edge_index, new_edge_type\n", - "\n", - "\n", - "def _extend_to_radius_graph(pos, edge_index, edge_type, cutoff, batch, unspecified_type_number=0, is_sidechain=None):\n", - " assert edge_type.dim() == 1\n", - " N = pos.size(0)\n", - "\n", - " bgraph_adj = torch.sparse.LongTensor(edge_index, edge_type, torch.Size([N, N]))\n", - "\n", - " if is_sidechain is None:\n", - " rgraph_edge_index = radius_graph(pos, r=cutoff, batch=batch) # (2, E_r)\n", - " else:\n", - " # fetch sidechain and its batch index\n", - " is_sidechain = is_sidechain.bool()\n", - " dummy_index = torch.arange(pos.size(0), device=pos.device)\n", - " sidechain_pos = pos[is_sidechain]\n", - " sidechain_index = dummy_index[is_sidechain]\n", - " sidechain_batch = batch[is_sidechain]\n", - "\n", - " assign_index = radius(x=pos, y=sidechain_pos, r=cutoff, batch_x=batch, batch_y=sidechain_batch)\n", - " r_edge_index_x = assign_index[1]\n", - " r_edge_index_y = assign_index[0]\n", - " r_edge_index_y = sidechain_index[r_edge_index_y]\n", - "\n", - " rgraph_edge_index1 = torch.stack((r_edge_index_x, r_edge_index_y)) # (2, E)\n", - " rgraph_edge_index2 = torch.stack((r_edge_index_y, r_edge_index_x)) # (2, E)\n", - " rgraph_edge_index = torch.cat((rgraph_edge_index1, rgraph_edge_index2), dim=-1) # (2, 2E)\n", - " # delete self loop\n", - " rgraph_edge_index = rgraph_edge_index[:, (rgraph_edge_index[0] != rgraph_edge_index[1])]\n", - "\n", - " rgraph_adj = torch.sparse.LongTensor(\n", - " rgraph_edge_index,\n", - " torch.ones(rgraph_edge_index.size(1)).long().to(pos.device) * unspecified_type_number,\n", - " 
torch.Size([N, N]),\n", - " )\n", - "\n", - " composed_adj = (bgraph_adj + rgraph_adj).coalesce() # Sparse (N, N, T)\n", - "\n", - " new_edge_index = composed_adj.indices()\n", - " new_edge_type = composed_adj.values().long()\n", - "\n", - " return new_edge_index, new_edge_type\n", - "\n", - "\n", - "def extend_graph_order_radius(\n", - " num_nodes,\n", - " pos,\n", - " edge_index,\n", - " edge_type,\n", - " batch,\n", - " order=3,\n", - " cutoff=10.0,\n", - " extend_order=True,\n", - " extend_radius=True,\n", - " is_sidechain=None,\n", - "):\n", - " if extend_order:\n", - " edge_index, edge_type = _extend_graph_order(\n", - " num_nodes=num_nodes, edge_index=edge_index, edge_type=edge_type, order=order\n", - " )\n", - "\n", - " if extend_radius:\n", - " edge_index, edge_type = _extend_to_radius_graph(\n", - " pos=pos, edge_index=edge_index, edge_type=edge_type, cutoff=cutoff, batch=batch, is_sidechain=is_sidechain\n", - " )\n", - "\n", - " return edge_index, edge_type\n", - "\n", - "\n", - "def get_distance(pos, edge_index):\n", - " return (pos[edge_index[0]] - pos[edge_index[1]]).norm(dim=-1)\n", - "\n", - "\n", - "def graph_field_network(score_d, pos, edge_index, edge_length):\n", - " \"\"\"\n", - " Transformation to make the epsilon predicted from the diffusion model roto-translational equivariant. See equations\n", - " 5-7 of the GeoDiff Paper https://arxiv.org/pdf/2203.02923.pdf\n", - " \"\"\"\n", - " N = pos.size(0)\n", - " dd_dr = (1.0 / edge_length) * (pos[edge_index[0]] - pos[edge_index[1]]) # (E, 3)\n", - " score_pos = scatter_add(dd_dr * score_d, edge_index[0], dim=0, dim_size=N) + scatter_add(\n", - " -dd_dr * score_d, edge_index[1], dim=0, dim_size=N\n", - " ) # (N, 3)\n", - " return score_pos\n", - "\n", - "\n", - "def clip_norm(vec, limit, p=2):\n", - " norm = torch.norm(vec, dim=-1, p=2, keepdim=True)\n", - " denom = torch.where(norm > limit, limit / norm, torch.ones_like(norm))\n", - " return vec * denom\n", - "\n", - "\n", - "def is_local_edge(edge_type):\n", - " return edge_type > 0\n" + }, + "output_type": "display_data" + } + ], + "source": [ + "# new molecule\n", + "show(mols_gen[idx])" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "KJr4h2mwXeTo" + }, + "outputs": [], + "source": [] + } + ], + "metadata": { + "accelerator": "GPU", + "colab": { + "provenance": [] + }, + "gpuClass": "standard", + "kernelspec": { + "display_name": "Python 3", + "name": "python3" + }, + "language_info": { + "name": "python" + }, + "widgets": { + "application/vnd.jupyter.widget-state+json": { + "01e0ba4e5da04914b4652b8d58565d7b": { + "model_module": "@jupyter-widgets/controls", + "model_module_version": "1.5.0", + "model_name": "HBoxModel", + "state": { + "_dom_classes": [], + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "_model_name": "HBoxModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/controls", + "_view_module_version": "1.5.0", + "_view_name": "HBoxView", + "box_style": "", + "children": [ + "IPY_MODEL_e5c0d75eb5e1447abd560c8f2c6017e1", + "IPY_MODEL_5146907ef6764654ad7d598baebc8b58" ], - "metadata": { - "id": "oR1Y56QiLY90" - }, - "execution_count": null, - "outputs": [] - }, - { - "cell_type": "markdown", - "source": [ - "Main model class!" 
+ "layout": "IPY_MODEL_144ec959b7604a2cabb5ca46ae5e5379" + } + }, + "144ec959b7604a2cabb5ca46ae5e5379": { + "model_module": "@jupyter-widgets/base", + "model_module_version": "1.2.0", + "model_name": "LayoutModel", + "state": { + "_model_module": "@jupyter-widgets/base", + "_model_module_version": "1.2.0", + "_model_name": "LayoutModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "1.2.0", + "_view_name": "LayoutView", + "align_content": null, + "align_items": null, + "align_self": null, + "border": null, + "bottom": null, + "display": null, + "flex": null, + "flex_flow": null, + "grid_area": null, + "grid_auto_columns": null, + "grid_auto_flow": null, + "grid_auto_rows": null, + "grid_column": null, + "grid_gap": null, + "grid_row": null, + "grid_template_areas": null, + "grid_template_columns": null, + "grid_template_rows": null, + "height": null, + "justify_content": null, + "justify_items": null, + "left": null, + "margin": null, + "max_height": null, + "max_width": null, + "min_height": null, + "min_width": null, + "object_fit": null, + "object_position": null, + "order": null, + "overflow": null, + "overflow_x": null, + "overflow_y": null, + "padding": null, + "right": null, + "top": null, + "visibility": null, + "width": null + } + }, + "19308ccac642498ab8b58462e3f1b0bb": { + "model_module": "@jupyter-widgets/base", + "model_module_version": "1.2.0", + "model_name": "LayoutModel", + "state": { + "_model_module": "@jupyter-widgets/base", + "_model_module_version": "1.2.0", + "_model_name": "LayoutModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "1.2.0", + "_view_name": "LayoutView", + "align_content": null, + "align_items": null, + "align_self": null, + "border": null, + "bottom": null, + "display": null, + "flex": null, + "flex_flow": null, + "grid_area": null, + "grid_auto_columns": null, + "grid_auto_flow": null, + "grid_auto_rows": null, + "grid_column": null, + "grid_gap": null, + "grid_row": null, + "grid_template_areas": null, + "grid_template_columns": null, + "grid_template_rows": null, + "height": null, + "justify_content": null, + "justify_items": null, + "left": null, + "margin": null, + "max_height": null, + "max_width": null, + "min_height": null, + "min_width": null, + "object_fit": null, + "object_position": null, + "order": null, + "overflow": null, + "overflow_x": null, + "overflow_y": null, + "padding": null, + "right": null, + "top": null, + "visibility": null, + "width": null + } + }, + "1c6246f15b654f4daa11c9bcf997b78c": { + "model_module": "@jupyter-widgets/controls", + "model_module_version": "1.5.0", + "model_name": "HTMLModel", + "state": { + "_dom_classes": [], + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "_model_name": "HTMLModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/controls", + "_view_module_version": "1.5.0", + "_view_name": "HTMLView", + "description": "", + "description_tooltip": null, + "layout": "IPY_MODEL_bbef741e76ec41b7ab7187b487a383df", + "placeholder": "​", + "style": "IPY_MODEL_561f742d418d4721b0670cc8dd62e22c", + "value": "Downloading: 100%" + } + }, + "2489b5e5648541fbbdceadb05632a050": { + "model_module": "@jupyter-widgets/controls", + "model_module_version": "1.5.0", + "model_name": "ButtonModel", + "state": { + "_dom_classes": [], + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "_model_name": "ButtonModel", + "_view_count": null, + 
"_view_module": "@jupyter-widgets/controls", + "_view_module_version": "1.5.0", + "_view_name": "ButtonView", + "button_style": "", + "description": "", + "disabled": false, + "icon": "compress", + "layout": "IPY_MODEL_abce2a80e6304df3899109c6d6cac199", + "style": "IPY_MODEL_65195cb7a4134f4887e9dd19f3676462", + "tooltip": "" + } + }, + "24d31fc3576e43dd9f8301d2ef3a37ab": { + "model_module": "@jupyter-widgets/base", + "model_module_version": "1.2.0", + "model_name": "LayoutModel", + "state": { + "_model_module": "@jupyter-widgets/base", + "_model_module_version": "1.2.0", + "_model_name": "LayoutModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "1.2.0", + "_view_name": "LayoutView", + "align_content": null, + "align_items": null, + "align_self": null, + "border": null, + "bottom": null, + "display": null, + "flex": null, + "flex_flow": null, + "grid_area": null, + "grid_auto_columns": null, + "grid_auto_flow": null, + "grid_auto_rows": null, + "grid_column": null, + "grid_gap": null, + "grid_row": null, + "grid_template_areas": null, + "grid_template_columns": null, + "grid_template_rows": null, + "height": null, + "justify_content": null, + "justify_items": null, + "left": null, + "margin": null, + "max_height": null, + "max_width": null, + "min_height": null, + "min_width": null, + "object_fit": null, + "object_position": null, + "order": null, + "overflow": null, + "overflow_x": null, + "overflow_y": null, + "padding": null, + "right": null, + "top": null, + "visibility": null, + "width": null + } + }, + "2918bfaadc8d4b1a9832522c40dfefb8": { + "model_module": "@jupyter-widgets/base", + "model_module_version": "1.2.0", + "model_name": "LayoutModel", + "state": { + "_model_module": "@jupyter-widgets/base", + "_model_module_version": "1.2.0", + "_model_name": "LayoutModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "1.2.0", + "_view_name": "LayoutView", + "align_content": null, + "align_items": null, + "align_self": null, + "border": null, + "bottom": null, + "display": null, + "flex": null, + "flex_flow": null, + "grid_area": null, + "grid_auto_columns": null, + "grid_auto_flow": null, + "grid_auto_rows": null, + "grid_column": null, + "grid_gap": null, + "grid_row": null, + "grid_template_areas": null, + "grid_template_columns": null, + "grid_template_rows": null, + "height": null, + "justify_content": null, + "justify_items": null, + "left": null, + "margin": null, + "max_height": null, + "max_width": null, + "min_height": null, + "min_width": null, + "object_fit": null, + "object_position": null, + "order": null, + "overflow": null, + "overflow_x": null, + "overflow_y": null, + "padding": null, + "right": null, + "top": null, + "visibility": null, + "width": null + } + }, + "2c9362906e4b40189f16d14aa9a348da": { + "model_module": "@jupyter-widgets/controls", + "model_module_version": "1.5.0", + "model_name": "DescriptionStyleModel", + "state": { + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "_model_name": "DescriptionStyleModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "1.2.0", + "_view_name": "StyleView", + "description_width": "" + } + }, + "42f6c3db29d7484ba6b4f73590abd2f4": { + "model_module": "@jupyter-widgets/controls", + "model_module_version": "1.5.0", + "model_name": "DescriptionStyleModel", + "state": { + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + 
"_model_name": "DescriptionStyleModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "1.2.0", + "_view_name": "StyleView", + "description_width": "" + } + }, + "4a081cdc2ec3421ca79dd933b7e2b0c4": { + "model_module": "@jupyter-widgets/controls", + "model_module_version": "1.5.0", + "model_name": "SliderStyleModel", + "state": { + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "_model_name": "SliderStyleModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "1.2.0", + "_view_name": "StyleView", + "description_width": "", + "handle_color": null + } + }, + "5146907ef6764654ad7d598baebc8b58": { + "model_module": "@jupyter-widgets/controls", + "model_module_version": "1.5.0", + "model_name": "IntSliderModel", + "state": { + "_dom_classes": [], + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "_model_name": "IntSliderModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/controls", + "_view_module_version": "1.5.0", + "_view_name": "IntSliderView", + "continuous_update": true, + "description": "", + "description_tooltip": null, + "disabled": false, + "layout": "IPY_MODEL_19308ccac642498ab8b58462e3f1b0bb", + "max": 0, + "min": 0, + "orientation": "horizontal", + "readout": true, + "readout_format": "d", + "step": 1, + "style": "IPY_MODEL_4a081cdc2ec3421ca79dd933b7e2b0c4", + "value": 0 + } + }, + "561f742d418d4721b0670cc8dd62e22c": { + "model_module": "@jupyter-widgets/controls", + "model_module_version": "1.5.0", + "model_name": "DescriptionStyleModel", + "state": { + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "_model_name": "DescriptionStyleModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "1.2.0", + "_view_name": "StyleView", + "description_width": "" + } + }, + "6010fc8daa7a44d5aec4b830ec2ebaa1": { + "model_module": "@jupyter-widgets/controls", + "model_module_version": "1.5.0", + "model_name": "HBoxModel", + "state": { + "_dom_classes": [], + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "_model_name": "HBoxModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/controls", + "_view_module_version": "1.5.0", + "_view_name": "HBoxView", + "box_style": "", + "children": [ + "IPY_MODEL_7e0bb1b8d65249d3974200686b193be2", + "IPY_MODEL_ba98aa6d6a884e4ab8bbb5dfb5e4cf7a", + "IPY_MODEL_6526646be5ed415c84d1245b040e629b" ], - "metadata": { - "id": "QWrHJFcYXyUB" - } - }, - { - "cell_type": "code", - "source": [ - "class MoleculeGNN(ModelMixin, ConfigMixin):\n", - " @register_to_config\n", - " def __init__(\n", - " self,\n", - " hidden_dim=128,\n", - " num_convs=6,\n", - " num_convs_local=4,\n", - " cutoff=10.0,\n", - " mlp_act=\"relu\",\n", - " edge_order=3,\n", - " edge_encoder=\"mlp\",\n", - " smooth_conv=True,\n", - " ):\n", - " super().__init__()\n", - " self.cutoff = cutoff\n", - " self.edge_encoder = edge_encoder\n", - " self.edge_order = edge_order\n", - "\n", - " \"\"\"\n", - " edge_encoder: Takes both edge type and edge length as input and outputs a vector [Note]: node embedding is done\n", - " in SchNetEncoder\n", - " \"\"\"\n", - " self.edge_encoder_global = MLPEdgeEncoder(hidden_dim, mlp_act) # get_edge_encoder(config)\n", - " self.edge_encoder_local = MLPEdgeEncoder(hidden_dim, mlp_act) # get_edge_encoder(config)\n", - "\n", - " \"\"\"\n", - " The graph neural network that 
extracts node-wise features.\n", - " \"\"\"\n", - " self.encoder_global = SchNetEncoder(\n", - " hidden_channels=hidden_dim,\n", - " num_filters=hidden_dim,\n", - " num_interactions=num_convs,\n", - " edge_channels=self.edge_encoder_global.out_channels,\n", - " cutoff=cutoff,\n", - " smooth=smooth_conv,\n", - " )\n", - " self.encoder_local = GINEncoder(\n", - " hidden_dim=hidden_dim,\n", - " num_convs=num_convs_local,\n", - " )\n", - "\n", - " \"\"\"\n", - " `output_mlp` takes a mixture of two nodewise features and edge features as input and outputs\n", - " gradients w.r.t. edge_length (out_dim = 1).\n", - " \"\"\"\n", - " self.grad_global_dist_mlp = MultiLayerPerceptron(\n", - " 2 * hidden_dim, [hidden_dim, hidden_dim // 2, 1], activation=mlp_act\n", - " )\n", - "\n", - " self.grad_local_dist_mlp = MultiLayerPerceptron(\n", - " 2 * hidden_dim, [hidden_dim, hidden_dim // 2, 1], activation=mlp_act\n", - " )\n", - "\n", - " \"\"\"\n", - " Incorporate parameters together\n", - " \"\"\"\n", - " self.model_global = nn.ModuleList([self.edge_encoder_global, self.encoder_global, self.grad_global_dist_mlp])\n", - " self.model_local = nn.ModuleList([self.edge_encoder_local, self.encoder_local, self.grad_local_dist_mlp])\n", - "\n", - " def _forward(\n", - " self,\n", - " atom_type,\n", - " pos,\n", - " bond_index,\n", - " bond_type,\n", - " batch,\n", - " time_step, # NOTE, model trained without timestep performed best\n", - " edge_index=None,\n", - " edge_type=None,\n", - " edge_length=None,\n", - " return_edges=False,\n", - " extend_order=True,\n", - " extend_radius=True,\n", - " is_sidechain=None,\n", - " ):\n", - " \"\"\"\n", - " Args:\n", - " atom_type: Types of atoms, (N, ).\n", - " bond_index: Indices of bonds (not extended, not radius-graph), (2, E).\n", - " bond_type: Bond types, (E, ).\n", - " batch: Node index to graph index, (N, ).\n", - " \"\"\"\n", - " N = atom_type.size(0)\n", - " if edge_index is None or edge_type is None or edge_length is None:\n", - " edge_index, edge_type = extend_graph_order_radius(\n", - " num_nodes=N,\n", - " pos=pos,\n", - " edge_index=bond_index,\n", - " edge_type=bond_type,\n", - " batch=batch,\n", - " order=self.edge_order,\n", - " cutoff=self.cutoff,\n", - " extend_order=extend_order,\n", - " extend_radius=extend_radius,\n", - " is_sidechain=is_sidechain,\n", - " )\n", - " edge_length = get_distance(pos, edge_index).unsqueeze(-1) # (E, 1)\n", - " local_edge_mask = is_local_edge(edge_type) # (E, )\n", - "\n", - " # with the parameterization of NCSNv2\n", - " # DDPM loss implicit handle the noise variance scale conditioning\n", - " sigma_edge = torch.ones(size=(edge_index.size(1), 1), device=pos.device) # (E, 1)\n", - "\n", - " # Encoding global\n", - " edge_attr_global = self.edge_encoder_global(edge_length=edge_length, edge_type=edge_type) # Embed edges\n", - "\n", - " # Global\n", - " node_attr_global = self.encoder_global(\n", - " z=atom_type,\n", - " edge_index=edge_index,\n", - " edge_length=edge_length,\n", - " edge_attr=edge_attr_global,\n", - " )\n", - " # Assemble pairwise features\n", - " h_pair_global = assemble_atom_pair_feature(\n", - " node_attr=node_attr_global,\n", - " edge_index=edge_index,\n", - " edge_attr=edge_attr_global,\n", - " ) # (E_global, 2H)\n", - " # Invariant features of edges (radius graph, global)\n", - " edge_inv_global = self.grad_global_dist_mlp(h_pair_global) * (1.0 / sigma_edge) # (E_global, 1)\n", - "\n", - " # Encoding local\n", - " edge_attr_local = self.edge_encoder_global(edge_length=edge_length, 
edge_type=edge_type) # Embed edges\n", - " # edge_attr += temb_edge\n", - "\n", - " # Local\n", - " node_attr_local = self.encoder_local(\n", - " z=atom_type,\n", - " edge_index=edge_index[:, local_edge_mask],\n", - " edge_attr=edge_attr_local[local_edge_mask],\n", - " )\n", - " # Assemble pairwise features\n", - " h_pair_local = assemble_atom_pair_feature(\n", - " node_attr=node_attr_local,\n", - " edge_index=edge_index[:, local_edge_mask],\n", - " edge_attr=edge_attr_local[local_edge_mask],\n", - " ) # (E_local, 2H)\n", - "\n", - " # Invariant features of edges (bond graph, local)\n", - " if isinstance(sigma_edge, torch.Tensor):\n", - " edge_inv_local = self.grad_local_dist_mlp(h_pair_local) * (\n", - " 1.0 / sigma_edge[local_edge_mask]\n", - " ) # (E_local, 1)\n", - " else:\n", - " edge_inv_local = self.grad_local_dist_mlp(h_pair_local) * (1.0 / sigma_edge) # (E_local, 1)\n", - "\n", - " if return_edges:\n", - " return edge_inv_global, edge_inv_local, edge_index, edge_type, edge_length, local_edge_mask\n", - " else:\n", - " return edge_inv_global, edge_inv_local\n", - "\n", - " def forward(\n", - " self,\n", - " sample,\n", - " timestep: Union[torch.Tensor, float, int],\n", - " return_dict: bool = True,\n", - " sigma=1.0,\n", - " global_start_sigma=0.5,\n", - " w_global=1.0,\n", - " extend_order=False,\n", - " extend_radius=True,\n", - " clip_local=None,\n", - " clip_global=1000.0,\n", - " ) -> Union[MoleculeGNNOutput, Tuple]:\n", - " r\"\"\"\n", - " Args:\n", - " sample: packed torch geometric object\n", - " timestep (`torch.Tensor` or `float` or `int): TODO verify type and shape (batch) timesteps\n", - " return_dict (`bool`, *optional*, defaults to `True`):\n", - " Whether or not to return a [`~models.molecule_gnn.MoleculeGNNOutput`] instead of a plain tuple.\n", - " Returns:\n", - " [`~models.molecule_gnn.MoleculeGNNOutput`] or `tuple`: [`~models.molecule_gnn.MoleculeGNNOutput`] if\n", - " `return_dict` is True, otherwise a `tuple`. 
When returning a tuple, the first element is the sample tensor.\n", - " \"\"\"\n", - "\n", - " # unpack sample\n", - " atom_type = sample.atom_type\n", - " bond_index = sample.edge_index\n", - " bond_type = sample.edge_type\n", - " num_graphs = sample.num_graphs\n", - " pos = sample.pos\n", - "\n", - " timesteps = torch.full(size=(num_graphs,), fill_value=timestep, dtype=torch.long, device=pos.device)\n", - "\n", - " edge_inv_global, edge_inv_local, edge_index, edge_type, edge_length, local_edge_mask = self._forward(\n", - " atom_type=atom_type,\n", - " pos=sample.pos,\n", - " bond_index=bond_index,\n", - " bond_type=bond_type,\n", - " batch=sample.batch,\n", - " time_step=timesteps,\n", - " return_edges=True,\n", - " extend_order=extend_order,\n", - " extend_radius=extend_radius,\n", - " ) # (E_global, 1), (E_local, 1)\n", - "\n", - " # Important equation in the paper for equivariant features - eqns 5-7 of GeoDiff\n", - " node_eq_local = graph_field_network(\n", - " edge_inv_local, pos, edge_index[:, local_edge_mask], edge_length[local_edge_mask]\n", - " )\n", - " if clip_local is not None:\n", - " node_eq_local = clip_norm(node_eq_local, limit=clip_local)\n", - "\n", - " # Global\n", - " if sigma < global_start_sigma:\n", - " edge_inv_global = edge_inv_global * (1 - local_edge_mask.view(-1, 1).float())\n", - " node_eq_global = graph_field_network(edge_inv_global, pos, edge_index, edge_length)\n", - " node_eq_global = clip_norm(node_eq_global, limit=clip_global)\n", - " else:\n", - " node_eq_global = 0\n", - "\n", - " # Sum\n", - " eps_pos = node_eq_local + node_eq_global * w_global\n", - "\n", - " if not return_dict:\n", - " return (-eps_pos,)\n", - "\n", - " return MoleculeGNNOutput(sample=torch.Tensor(-eps_pos).to(pos.device))" + "layout": "IPY_MODEL_24d31fc3576e43dd9f8301d2ef3a37ab" + } + }, + "65195cb7a4134f4887e9dd19f3676462": { + "model_module": "@jupyter-widgets/controls", + "model_module_version": "1.5.0", + "model_name": "ButtonStyleModel", + "state": { + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "_model_name": "ButtonStyleModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "1.2.0", + "_view_name": "StyleView", + "button_color": null, + "font_weight": "" + } + }, + "6526646be5ed415c84d1245b040e629b": { + "model_module": "@jupyter-widgets/controls", + "model_module_version": "1.5.0", + "model_name": "HTMLModel", + "state": { + "_dom_classes": [], + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "_model_name": "HTMLModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/controls", + "_view_module_version": "1.5.0", + "_view_name": "HTMLView", + "description": "", + "description_tooltip": null, + "layout": "IPY_MODEL_a9e388f22a9742aaaf538e22575c9433", + "placeholder": "​", + "style": "IPY_MODEL_42f6c3db29d7484ba6b4f73590abd2f4", + "value": " 401/401 [00:00<00:00, 13.5kB/s]" + } + }, + "695ab5bbf30a4ab19df1f9f33469f314": { + "model_module": "nglview-js-widgets", + "model_module_version": "3.0.1", + "model_name": "ColormakerRegistryModel", + "state": { + "_dom_classes": [], + "_model_module": "nglview-js-widgets", + "_model_module_version": "3.0.1", + "_model_name": "ColormakerRegistryModel", + "_msg_ar": [], + "_msg_q": [], + "_ready": false, + "_view_count": null, + "_view_module": "nglview-js-widgets", + "_view_module_version": "3.0.1", + "_view_name": "ColormakerRegistryView", + "layout": "IPY_MODEL_eac6a8dcdc9d4335a2e51031793ead29" + } + }, + 
"7e0bb1b8d65249d3974200686b193be2": { + "model_module": "@jupyter-widgets/controls", + "model_module_version": "1.5.0", + "model_name": "HTMLModel", + "state": { + "_dom_classes": [], + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "_model_name": "HTMLModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/controls", + "_view_module_version": "1.5.0", + "_view_name": "HTMLView", + "description": "", + "description_tooltip": null, + "layout": "IPY_MODEL_2918bfaadc8d4b1a9832522c40dfefb8", + "placeholder": "​", + "style": "IPY_MODEL_a4bfdca35cc54dae8812720f1b276a08", + "value": "Downloading: 100%" + } + }, + "872915dd1bb84f538c44e26badabafdd": { + "model_module": "@jupyter-widgets/base", + "model_module_version": "1.2.0", + "model_name": "LayoutModel", + "state": { + "_model_module": "@jupyter-widgets/base", + "_model_module_version": "1.2.0", + "_model_name": "LayoutModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "1.2.0", + "_view_name": "LayoutView", + "align_content": null, + "align_items": null, + "align_self": null, + "border": null, + "bottom": null, + "display": null, + "flex": null, + "flex_flow": null, + "grid_area": null, + "grid_auto_columns": null, + "grid_auto_flow": null, + "grid_auto_rows": null, + "grid_column": null, + "grid_gap": null, + "grid_row": null, + "grid_template_areas": null, + "grid_template_columns": null, + "grid_template_rows": null, + "height": null, + "justify_content": null, + "justify_items": null, + "left": null, + "margin": null, + "max_height": null, + "max_width": null, + "min_height": null, + "min_width": null, + "object_fit": null, + "object_position": null, + "order": null, + "overflow": null, + "overflow_x": null, + "overflow_y": null, + "padding": null, + "right": null, + "top": null, + "visibility": null, + "width": null + } + }, + "a4bfdca35cc54dae8812720f1b276a08": { + "model_module": "@jupyter-widgets/controls", + "model_module_version": "1.5.0", + "model_name": "DescriptionStyleModel", + "state": { + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "_model_name": "DescriptionStyleModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "1.2.0", + "_view_name": "StyleView", + "description_width": "" + } + }, + "a9e388f22a9742aaaf538e22575c9433": { + "model_module": "@jupyter-widgets/base", + "model_module_version": "1.2.0", + "model_name": "LayoutModel", + "state": { + "_model_module": "@jupyter-widgets/base", + "_model_module_version": "1.2.0", + "_model_name": "LayoutModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "1.2.0", + "_view_name": "LayoutView", + "align_content": null, + "align_items": null, + "align_self": null, + "border": null, + "bottom": null, + "display": null, + "flex": null, + "flex_flow": null, + "grid_area": null, + "grid_auto_columns": null, + "grid_auto_flow": null, + "grid_auto_rows": null, + "grid_column": null, + "grid_gap": null, + "grid_row": null, + "grid_template_areas": null, + "grid_template_columns": null, + "grid_template_rows": null, + "height": null, + "justify_content": null, + "justify_items": null, + "left": null, + "margin": null, + "max_height": null, + "max_width": null, + "min_height": null, + "min_width": null, + "object_fit": null, + "object_position": null, + "order": null, + "overflow": null, + "overflow_x": null, + "overflow_y": null, + "padding": null, + "right": null, + 
"top": null, + "visibility": null, + "width": null + } + }, + "abce2a80e6304df3899109c6d6cac199": { + "model_module": "@jupyter-widgets/base", + "model_module_version": "1.2.0", + "model_name": "LayoutModel", + "state": { + "_model_module": "@jupyter-widgets/base", + "_model_module_version": "1.2.0", + "_model_name": "LayoutModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "1.2.0", + "_view_name": "LayoutView", + "align_content": null, + "align_items": null, + "align_self": null, + "border": null, + "bottom": null, + "display": null, + "flex": null, + "flex_flow": null, + "grid_area": null, + "grid_auto_columns": null, + "grid_auto_flow": null, + "grid_auto_rows": null, + "grid_column": null, + "grid_gap": null, + "grid_row": null, + "grid_template_areas": null, + "grid_template_columns": null, + "grid_template_rows": null, + "height": null, + "justify_content": null, + "justify_items": null, + "left": null, + "margin": null, + "max_height": null, + "max_width": null, + "min_height": null, + "min_width": null, + "object_fit": null, + "object_position": null, + "order": null, + "overflow": null, + "overflow_x": null, + "overflow_y": null, + "padding": null, + "right": null, + "top": null, + "visibility": null, + "width": "34px" + } + }, + "b7feb522161f4cf4b7cc7c1a078ff12d": { + "model_module": "@jupyter-widgets/controls", + "model_module_version": "1.5.0", + "model_name": "HTMLModel", + "state": { + "_dom_classes": [], + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "_model_name": "HTMLModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/controls", + "_view_module_version": "1.5.0", + "_view_name": "HTMLView", + "description": "", + "description_tooltip": null, + "layout": "IPY_MODEL_fdc393f3468c432aa0ada05e238a5436", + "placeholder": "​", + "style": "IPY_MODEL_2c9362906e4b40189f16d14aa9a348da", + "value": " 3.27M/3.27M [00:01<00:00, 3.25MB/s]" + } + }, + "ba98aa6d6a884e4ab8bbb5dfb5e4cf7a": { + "model_module": "@jupyter-widgets/controls", + "model_module_version": "1.5.0", + "model_name": "FloatProgressModel", + "state": { + "_dom_classes": [], + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "_model_name": "FloatProgressModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/controls", + "_view_module_version": "1.5.0", + "_view_name": "ProgressView", + "bar_style": "success", + "description": "", + "description_tooltip": null, + "layout": "IPY_MODEL_e4901541199b45c6a18824627692fc39", + "max": 401, + "min": 0, + "orientation": "horizontal", + "style": "IPY_MODEL_f915cf874246446595206221e900b2fe", + "value": 401 + } + }, + "bbef741e76ec41b7ab7187b487a383df": { + "model_module": "@jupyter-widgets/base", + "model_module_version": "1.2.0", + "model_name": "LayoutModel", + "state": { + "_model_module": "@jupyter-widgets/base", + "_model_module_version": "1.2.0", + "_model_name": "LayoutModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "1.2.0", + "_view_name": "LayoutView", + "align_content": null, + "align_items": null, + "align_self": null, + "border": null, + "bottom": null, + "display": null, + "flex": null, + "flex_flow": null, + "grid_area": null, + "grid_auto_columns": null, + "grid_auto_flow": null, + "grid_auto_rows": null, + "grid_column": null, + "grid_gap": null, + "grid_row": null, + "grid_template_areas": null, + "grid_template_columns": null, + "grid_template_rows": null, + "height": 
null, + "justify_content": null, + "justify_items": null, + "left": null, + "margin": null, + "max_height": null, + "max_width": null, + "min_height": null, + "min_width": null, + "object_fit": null, + "object_position": null, + "order": null, + "overflow": null, + "overflow_x": null, + "overflow_y": null, + "padding": null, + "right": null, + "top": null, + "visibility": null, + "width": null + } + }, + "be446195da2b4ff2aec21ec5ff963a54": { + "model_module": "nglview-js-widgets", + "model_module_version": "3.0.1", + "model_name": "NGLModel", + "state": { + "_camera_orientation": [ + -15.519693580202304, + -14.065056548036177, + -23.53197484807691, + 0, + -23.357853515109753, + 20.94055073042662, + 2.888695042134944, + 0, + 14.352363398292775, + 18.870825741878015, + -20.744689572909344, + 0, + 0.2724999189376831, + 0.6940000057220459, + -0.3734999895095825, + 1 ], - "metadata": { - "id": "MCeZA1qQXzoK" + "_camera_str": "orthographic", + "_dom_classes": [], + "_gui_theme": null, + "_ibtn_fullscreen": "IPY_MODEL_2489b5e5648541fbbdceadb05632a050", + "_igui": null, + "_iplayer": "IPY_MODEL_01e0ba4e5da04914b4652b8d58565d7b", + "_model_module": "nglview-js-widgets", + "_model_module_version": "3.0.1", + "_model_name": "NGLModel", + "_ngl_color_dict": {}, + "_ngl_coordinate_resource": {}, + "_ngl_full_stage_parameters": { + "ambientColor": 14540253, + "ambientIntensity": 0.2, + "backgroundColor": "white", + "cameraEyeSep": 0.3, + "cameraFov": 40, + "cameraType": "perspective", + "clipDist": 10, + "clipFar": 100, + "clipNear": 0, + "fogFar": 100, + "fogNear": 50, + "hoverTimeout": 0, + "impostor": true, + "lightColor": 14540253, + "lightIntensity": 1, + "mousePreset": "default", + "panSpeed": 1, + "quality": "medium", + "rotateSpeed": 2, + "sampleLevel": 0, + "tooltip": true, + "workerDefault": true, + "zoomSpeed": 1.2 }, - "execution_count": null, - "outputs": [] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "CCIrPYSJj9wd" - }, - "source": [ - "### Load pretrained model" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "YdrAr6Ch--Ab" - }, - "source": [ - "#### Load a model\n", - "The model used is a design an\n", - "equivariant convolutional layer, named graph field network (GFN).\n", - "\n", - "The warning about `betas` and `alphas` can be ignored, those were moved to the scheduler." 
- ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "id": "DyCo0nsqjbml", - "colab": { - "base_uri": "https://localhost:8080/", - "height": 172, - "referenced_widgets": [ - "d90f304e9560472eacfbdd11e46765eb", - "1c6246f15b654f4daa11c9bcf997b78c", - "c2321b3bff6f490ca12040a20308f555", - "b7feb522161f4cf4b7cc7c1a078ff12d", - "e2d368556e494ae7ae4e2e992af2cd4f", - "bbef741e76ec41b7ab7187b487a383df", - "561f742d418d4721b0670cc8dd62e22c", - "872915dd1bb84f538c44e26badabafdd", - "d022575f1fa2446d891650897f187b4d", - "fdc393f3468c432aa0ada05e238a5436", - "2c9362906e4b40189f16d14aa9a348da", - "6010fc8daa7a44d5aec4b830ec2ebaa1", - "7e0bb1b8d65249d3974200686b193be2", - "ba98aa6d6a884e4ab8bbb5dfb5e4cf7a", - "6526646be5ed415c84d1245b040e629b", - "24d31fc3576e43dd9f8301d2ef3a37ab", - "2918bfaadc8d4b1a9832522c40dfefb8", - "a4bfdca35cc54dae8812720f1b276a08", - "e4901541199b45c6a18824627692fc39", - "f915cf874246446595206221e900b2fe", - "a9e388f22a9742aaaf538e22575c9433", - "42f6c3db29d7484ba6b4f73590abd2f4" - ] + "_ngl_msg_archive": [ + { + "args": [ + { + "binary": false, + "data": "HETATM 1 C1 UNL 1 -0.025 3.128 2.316 1.00 0.00 C \nHETATM 2 H1 UNL 1 0.183 3.657 2.823 1.00 0.00 H \nHETATM 3 C2 UNL 1 0.590 3.559 0.963 1.00 0.00 C \nHETATM 4 C3 UNL 1 0.056 4.479 0.406 1.00 0.00 C \nHETATM 5 C4 UNL 1 -0.219 4.802 -1.065 1.00 0.00 C \nHETATM 6 H2 UNL 1 0.686 4.431 -1.575 1.00 0.00 H \nHETATM 7 H3 UNL 1 -0.524 5.217 -1.274 1.00 0.00 H \nHETATM 8 C5 UNL 1 -1.284 3.766 -1.342 1.00 0.00 C \nHETATM 9 N1 UNL 1 -1.073 2.494 -0.580 1.00 0.00 N \nHETATM 10 C6 UNL 1 -1.909 1.494 -0.964 1.00 0.00 C \nHETATM 11 O1 UNL 1 -2.487 1.531 -2.092 1.00 0.00 O \nHETATM 12 C7 UNL 1 -2.232 0.242 -0.130 1.00 0.00 C \nHETATM 13 C8 UNL 1 -2.161 -1.057 -1.037 1.00 0.00 C \nHETATM 14 C9 UNL 1 -0.744 -1.111 -1.610 1.00 0.00 C \nHETATM 15 N2 UNL 1 0.290 -0.917 -0.628 1.00 0.00 N \nHETATM 16 S1 UNL 1 1.717 -1.597 -0.914 1.00 0.00 S \nHETATM 17 O2 UNL 1 1.960 -1.671 -2.338 1.00 0.00 O \nHETATM 18 O3 UNL 1 2.713 -0.968 -0.082 1.00 0.00 O \nHETATM 19 C10 UNL 1 1.425 -3.170 -0.345 1.00 0.00 C \nHETATM 20 C11 UNL 1 1.225 -4.400 -1.271 1.00 0.00 C \nHETATM 21 C12 UNL 1 1.314 -5.913 -0.895 1.00 0.00 C \nHETATM 22 C13 UNL 1 1.823 -6.229 0.386 1.00 0.00 C \nHETATM 23 C14 UNL 1 2.031 -5.110 1.365 1.00 0.00 C \nHETATM 24 N3 UNL 1 1.850 -5.267 2.712 1.00 0.00 N \nHETATM 25 O4 UNL 1 1.382 -4.029 3.126 1.00 0.00 O \nHETATM 26 N4 UNL 1 1.300 -3.023 2.154 1.00 0.00 N \nHETATM 27 C15 UNL 1 1.731 -3.672 1.032 1.00 0.00 C \nHETATM 28 H4 UNL 1 2.380 -6.874 0.436 1.00 0.00 H \nHETATM 29 H5 UNL 1 0.704 -6.526 -1.420 1.00 0.00 H \nHETATM 30 H6 UNL 1 1.144 -4.035 -2.291 1.00 0.00 H \nHETATM 31 C16 UNL 1 0.044 -0.371 0.685 1.00 0.00 C \nHETATM 32 C17 UNL 1 -1.352 -0.045 1.077 1.00 0.00 C \nHETATM 33 H7 UNL 1 -1.395 0.770 1.768 1.00 0.00 H \nHETATM 34 H8 UNL 1 -1.792 -0.941 1.582 1.00 0.00 H \nHETATM 35 H9 UNL 1 0.583 -1.035 1.393 1.00 0.00 H \nHETATM 36 H10 UNL 1 0.664 0.613 0.663 1.00 0.00 H \nHETATM 37 H11 UNL 1 -0.631 -0.267 -2.335 1.00 0.00 H \nHETATM 38 H12 UNL 1 -0.571 -2.046 -2.098 1.00 0.00 H \nHETATM 39 H13 UNL 1 -2.872 -0.992 -1.826 1.00 0.00 H \nHETATM 40 H14 UNL 1 -2.370 -1.924 -0.444 1.00 0.00 H \nHETATM 41 H15 UNL 1 -3.258 0.364 0.197 1.00 0.00 H \nHETATM 42 C18 UNL 1 0.276 2.337 -0.078 1.00 0.00 C \nHETATM 43 H16 UNL 1 0.514 1.371 0.252 1.00 0.00 H \nHETATM 44 H17 UNL 1 0.988 2.413 -0.949 1.00 0.00 H \nHETATM 45 H18 UNL 1 -1.349 3.451 -2.379 1.00 0.00 H \nHETATM 46 H19 UNL 1 -2.224 4.055 -0.958 1.00 0.00 H \nHETATM 47 
H20 UNL 1 0.793 5.486 0.669 1.00 0.00 H \nHETATM 48 H21 UNL 1 -0.849 4.974 0.937 1.00 0.00 H \nHETATM 49 H22 UNL 1 1.667 3.431 1.070 1.00 0.00 H \nHETATM 50 H23 UNL 1 0.379 2.143 2.689 1.00 0.00 H \nHETATM 51 H24 UNL 1 -1.094 2.983 2.223 1.00 0.00 H \nCONECT 1 2 3 50 51\nCONECT 3 4 42 49\nCONECT 4 5 47 48\nCONECT 5 6 7 8\nCONECT 8 9 45 46\nCONECT 9 10 42\nCONECT 10 11 11 12\nCONECT 12 13 32 41\nCONECT 13 14 39 40\nCONECT 14 15 37 38\nCONECT 15 16 31\nCONECT 16 17 17 18 18\nCONECT 16 19\nCONECT 19 20 20 27\nCONECT 20 21 30\nCONECT 21 22 22 29\nCONECT 22 23 28\nCONECT 23 24 24 27\nCONECT 24 25\nCONECT 25 26\nCONECT 26 27 27\nCONECT 31 32 35 36\nCONECT 32 33 34\nCONECT 42 43 44\nEND\n", + "type": "blob" + } + ], + "kwargs": { + "defaultRepresentation": true, + "ext": "pdb" }, - "outputId": "d6bce9d5-c51e-43a4-e680-e1e81bdfaf45" + "methodName": "loadFile", + "reconstruc_color_scheme": false, + "target": "Stage", + "type": "call_method" + } + ], + "_ngl_original_stage_parameters": { + "ambientColor": 14540253, + "ambientIntensity": 0.2, + "backgroundColor": "white", + "cameraEyeSep": 0.3, + "cameraFov": 40, + "cameraType": "perspective", + "clipDist": 10, + "clipFar": 100, + "clipNear": 0, + "fogFar": 100, + "fogNear": 50, + "hoverTimeout": 0, + "impostor": true, + "lightColor": 14540253, + "lightIntensity": 1, + "mousePreset": "default", + "panSpeed": 1, + "quality": "medium", + "rotateSpeed": 2, + "sampleLevel": 0, + "tooltip": true, + "workerDefault": true, + "zoomSpeed": 1.2 }, - "outputs": [ - { - "output_type": "display_data", - "data": { - "text/plain": [ - "Downloading: 0%| | 0.00/3.27M [00:00] 124.78K 180KB/s in 0.7s \n", - "\n", - "2022-10-12 18:32:20 (180 KB/s) - ‘molecules.pkl’ saved [127774/127774]\n", - "\n" - ] + "metalness": 0, + "multipleBond": "off", + "opacity": 1, + "openEnded": true, + "quality": "high", + "radialSegments": 20, + "radiusData": {}, + "radiusScale": 2, + "radiusSize": 0.15, + "radiusType": "size", + "roughness": 0.4, + "sele": "", + "side": "double", + "sphereDetail": 2, + "useInteriorColor": true, + "visible": true, + "wireframe": false + }, + "type": "ball+stick" } - ], - "source": [ - "import torch\n", - "import numpy as np\n", - "\n", - "!wget https://huggingface.co/datasets/fusing/geodiff-example-data/resolve/main/data/molecules.pkl\n", - "dataset = torch.load('/content/molecules.pkl')" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "QZcmy1EvKQRk" - }, - "source": [ - "Print out one entry of the dataset, it contains molecular formulas, atom types, positions, and more." 
- ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "id": "JVjz6iH_H6Eh", - "colab": { - "base_uri": "https://localhost:8080/" - }, - "outputId": "898cb0cf-a0b3-411b-fd4c-bea1fbfd17fe" - }, - "outputs": [ - { - "output_type": "execute_result", - "data": { - "text/plain": [ - "Data(atom_type=[51], bond_edge_index=[2, 108], edge_index=[2, 598], edge_order=[598], edge_type=[598], idx=[1], is_bond=[598], num_nodes_per_graph=[1], num_pos_ref=[1], nx=, pos=[51, 3], pos_ref=[255, 3], rdmol=, smiles=\"CC1CCCN(C(=O)C2CCN(S(=O)(=O)c3cccc4nonc34)CC2)C1\")" - ] + }, + "1": { + "0": { + "params": { + "aspectRatio": 1.5, + "assembly": "default", + "bondScale": 0.3, + "bondSpacing": 0.75, + "clipCenter": { + "x": 0, + "y": 0, + "z": 0 }, - "metadata": {}, - "execution_count": 20 - } - ], - "source": [ - "dataset[0]" - ] - }, - { - "cell_type": "markdown", - "source": [ - "## Run the diffusion process" - ], - "metadata": { - "id": "vHNiZAUxNgoy" - } - }, - { - "cell_type": "markdown", - "metadata": { - "id": "jZ1KZrxKqENg" - }, - "source": [ - "#### Helper Functions" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "id": "s240tYueqKKf" - }, - "outputs": [], - "source": [ - "from torch_geometric.data import Data, Batch\n", - "from torch_scatter import scatter_add, scatter_mean\n", - "from tqdm import tqdm\n", - "import copy\n", - "import os\n", - "\n", - "def repeat_data(data: Data, num_repeat) -> Batch:\n", - " datas = [copy.deepcopy(data) for i in range(num_repeat)]\n", - " return Batch.from_data_list(datas)\n", - "\n", - "def repeat_batch(batch: Batch, num_repeat) -> Batch:\n", - " datas = batch.to_data_list()\n", - " new_data = []\n", - " for i in range(num_repeat):\n", - " new_data += copy.deepcopy(datas)\n", - " return Batch.from_data_list(new_data)" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "AMnQTk0eqT7Z" - }, - "source": [ - "#### Constants" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "id": "WYGkzqgzrHmF" - }, - "outputs": [], - "source": [ - "num_samples = 1 # solutions per molecule\n", - "num_molecules = 3\n", - "\n", - "DEVICE = 'cuda'\n", - "sampling_type = 'ddpm_noisy' #'' # paper also uses \"generalize\" and \"ld\"\n", - "# constants for inference\n", - "w_global = 0.5 #0,.3 for qm9\n", - "global_start_sigma = 0.5\n", - "eta = 1.0\n", - "clip_local = None\n", - "clip_pos = None\n", - "\n", - "# constands for data handling\n", - "save_traj = False\n", - "save_data = False\n", - "output_dir = '/content/'" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "-xD5bJ3SqM7t" - }, - "source": [ - "#### Generate samples!\n", - "Note that the 3d representation of a molecule is referred to as the **conformation**" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "id": "x9xuLUNg26z1", - "colab": { - "base_uri": "https://localhost:8080/" - }, - "outputId": "236d2a60-09ed-4c4d-97c1-6e3c0f2d26c4" - }, - "outputs": [ - { - "output_type": "stream", - "name": "stderr", - "text": [ - "/usr/local/lib/python3.7/dist-packages/ipykernel_launcher.py:4: UserWarning: To copy construct from a tensor, it is recommended to use sourceTensor.clone().detach() or sourceTensor.clone().detach().requires_grad_(True), rather than torch.tensor(sourceTensor).\n", - " after removing the cwd from sys.path.\n", - "100%|██████████| 5/5 [00:55<00:00, 11.06s/it]\n" - ] - } - ], - "source": [ - "results = []\n", - "\n", - "# define sigmas\n", - "sigmas = torch.tensor(1.0 - 
scheduler.alphas_cumprod).sqrt() / torch.tensor(scheduler.alphas_cumprod).sqrt()\n", - "sigmas = sigmas.to(DEVICE)\n", - "\n", - "for count, data in enumerate(tqdm(dataset)):\n", - " num_samples = max(data.pos_ref.size(0) // data.num_nodes, 1)\n", - "\n", - " data_input = data.clone()\n", - " data_input['pos_ref'] = None\n", - " batch = repeat_data(data_input, num_samples).to(DEVICE)\n", - "\n", - " # initial configuration\n", - " pos_init = torch.randn(batch.num_nodes, 3).to(DEVICE)\n", - "\n", - " # for logging animation of denoising\n", - " pos_traj = []\n", - " with torch.no_grad():\n", - "\n", - " # scale initial sample\n", - " pos = pos_init * sigmas[-1]\n", - " for t in scheduler.timesteps:\n", - " batch.pos = pos\n", - "\n", - " # generate geometry with model, then filter it\n", - " epsilon = model.forward(batch, t, sigma=sigmas[t], return_dict=False)[0]\n", - "\n", - " # Update\n", - " reconstructed_pos = scheduler.step(epsilon, t, pos)[\"prev_sample\"].to(DEVICE)\n", - "\n", - " pos = reconstructed_pos\n", - "\n", - " if torch.isnan(pos).any():\n", - " print(\"NaN detected. Please restart.\")\n", - " raise FloatingPointError()\n", - "\n", - " # recenter graph of positions for next iteration\n", - " pos = pos - scatter_mean(pos, batch.batch, dim=0)[batch.batch]\n", - "\n", - " # optional clipping\n", - " if clip_pos is not None:\n", - " pos = torch.clamp(pos, min=-clip_pos, max=clip_pos)\n", - " pos_traj.append(pos.clone().cpu())\n", - "\n", - " pos_gen = pos.cpu()\n", - " if save_traj:\n", - " pos_gen_traj = pos_traj.cpu()\n", - " data.pos_gen = torch.stack(pos_gen_traj)\n", - " else:\n", - " data.pos_gen = pos_gen\n", - " results.append(data)\n", - "\n", - "\n", - "if save_data:\n", - " save_path = os.path.join(output_dir, 'samples_all.pkl')\n", - "\n", - " with open(save_path, 'wb') as f:\n", - " pickle.dump(results, f)" - ] - }, - { - "cell_type": "markdown", - "source": [ - "## Render the results!" - ], - "metadata": { - "id": "fSApwSaZNndW" - } - }, - { - "cell_type": "markdown", - "metadata": { - "id": "d47Zxo2OKdgZ" - }, - "source": [ - "This function allows us to render 3d in colab." - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "id": "e9Cd0kCAv9b8" - }, - "outputs": [], - "source": [ - "from google.colab import output\n", - "output.enable_custom_widget_manager()" - ] - }, - { - "cell_type": "markdown", - "source": [ - "### Helper functions" - ], - "metadata": { - "id": "RjaVuR15NqzF" - } - }, - { - "cell_type": "markdown", - "metadata": { - "id": "28rBYa9NKhlz" - }, - "source": [ - "Here is a helper function for copying the generated tensors into a format used by RDKit & NGLViewer." 
- ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "id": "LKdKdwxcyTQ6" - }, - "outputs": [], - "source": [ - "from copy import deepcopy\n", - "def set_rdmol_positions(rdkit_mol, pos):\n", - " \"\"\"\n", - " Args:\n", - " rdkit_mol: An `rdkit.Chem.rdchem.Mol` object.\n", - " pos: (N_atoms, 3)\n", - " \"\"\"\n", - " mol = deepcopy(rdkit_mol)\n", - " set_rdmol_positions_(mol, pos)\n", - " return mol\n", - "\n", - "def set_rdmol_positions_(mol, pos):\n", - " \"\"\"\n", - " Args:\n", - " rdkit_mol: An `rdkit.Chem.rdchem.Mol` object.\n", - " pos: (N_atoms, 3)\n", - " \"\"\"\n", - " for i in range(pos.shape[0]):\n", - " mol.GetConformer(0).SetAtomPosition(i, pos[i].tolist())\n", - " return mol\n" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "NuE10hcpKmzK" - }, - "source": [ - "Process the generated data to make it easy to view." - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "id": "KieVE1vc0_Vs", - "colab": { - "base_uri": "https://localhost:8080/" - }, - "outputId": "6faa185d-b1bc-47e8-be18-30d1e557e7c8" - }, - "outputs": [ - { - "output_type": "stream", - "name": "stdout", - "text": [ - "collect 5 generated molecules in `mols`\n" - ] - } - ], - "source": [ - "# the model can generate multiple conformations per 2d geometry\n", - "num_gen = results[0]['pos_gen'].shape[0]\n", - "\n", - "# init storage objects\n", - "mols_gen = []\n", - "mols_orig = []\n", - "for to_process in results:\n", - "\n", - " # store the reference 3d position\n", - " to_process['pos_ref'] = to_process['pos_ref'].reshape(-1, to_process['rdmol'].GetNumAtoms(), 3)\n", - "\n", - " # store the generated 3d position\n", - " to_process['pos_gen'] = to_process['pos_gen'].reshape(-1, to_process['rdmol'].GetNumAtoms(), 3)\n", - "\n", - " # copy data to new object\n", - " new_mol = set_rdmol_positions(to_process.rdmol, to_process['pos_gen'][0])\n", - "\n", - " # append results\n", - " mols_gen.append(new_mol)\n", - " mols_orig.append(to_process.rdmol)\n", - "\n", - "print(f\"collect {len(mols_gen)} generated molecules in `mols`\")" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "tin89JwMKp4v" - }, - "source": [ - "Import tools to visualize the 2d chemical diagram of the molecule." - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "id": "yqV6gllSZn38" - }, - "outputs": [], - "source": [ - "from rdkit.Chem import AllChem\n", - "from rdkit import Chem\n", - "from rdkit.Chem.Draw import rdMolDraw2D as MD2\n", - "from IPython.display import SVG, display" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "TFNKmGddVoOk" - }, - "source": [ - "Select molecule to visualize" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "id": "KzuwLlrrVaGc" - }, - "outputs": [], - "source": [ - "idx = 0\n", - "assert idx < len(results), \"selected molecule that was not generated\"" - ] - }, - { - "cell_type": "markdown", - "source": [ - "### Viewing" - ], - "metadata": { - "id": "hkb8w0_SNtU8" - } - }, - { - "cell_type": "markdown", - "metadata": { - "id": "I3R4QBQeKttN" - }, - "source": [ - "This 2D rendering is the equivalent of the **input to the model**!" 
- ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "id": "gkQRWjraaKex", - "colab": { - "base_uri": "https://localhost:8080/", - "height": 321 - }, - "outputId": "9c3d1a91-a51d-475d-9e34-2be2459abc47" - }, - "outputs": [ - { - "output_type": "display_data", - "data": { - "text/plain": [ - "" - ], - "image/svg+xml": "\n\n \n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n" + "clipNear": 0, + "clipRadius": 0, + "colorMode": "hcl", + "colorReverse": false, + "colorScale": "", + "colorScheme": "element", + "colorValue": 9474192, + "cylinderOnly": false, + "defaultAssembly": "", + "depthWrite": true, + "diffuse": 16777215, + "diffuseInterior": false, + "disableImpostor": false, + "disablePicking": false, + "flatShaded": false, + "interiorColor": 2236962, + "interiorDarkening": 0, + "lazy": false, + "lineOnly": false, + "linewidth": 2, + "matrix": { + "elements": [ + 1, + 0, + 0, + 0, + 0, + 1, + 0, + 0, + 0, + 0, + 1, + 0, + 0, + 0, + 0, + 1 + ] }, - "metadata": {} + "metalness": 0, + "multipleBond": "off", + "opacity": 1, + "openEnded": true, + "quality": "high", + "radialSegments": 20, + "radiusData": {}, + "radiusScale": 2, + "radiusSize": 0.15, + "radiusType": "size", + "roughness": 0.4, + "sele": "", + "side": "double", + "sphereDetail": 2, + "useInteriorColor": true, + "visible": true, + "wireframe": false + }, + "type": "ball+stick" } - ], - "source": [ - "mc = Chem.MolFromSmiles(dataset[0]['smiles'])\n", - "molSize=(450,300)\n", - "drawer = MD2.MolDraw2DSVG(molSize[0],molSize[1])\n", - "drawer.DrawMolecule(mc)\n", - "drawer.FinishDrawing()\n", - "svg = drawer.GetDrawingText()\n", - "display(SVG(svg.replace('svg:','')))" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "z4FDMYMxKw2I" + } }, - "source": [ - "Generate the 3d molecule!" 
- ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "id": "aT1Bkb8YxJfV", - "colab": { - "base_uri": "https://localhost:8080/", - "height": 17, - "referenced_widgets": [ - "695ab5bbf30a4ab19df1f9f33469f314", - "eac6a8dcdc9d4335a2e51031793ead29" - ] - }, - "outputId": "b98870ae-049d-4386-b676-166e9526bda2" - }, - "outputs": [ - { - "output_type": "display_data", - "data": { - "text/plain": [], - "application/vnd.jupyter.widget-view+json": { - "version_major": 2, - "version_minor": 0, - "model_id": "695ab5bbf30a4ab19df1f9f33469f314" - } - }, - "metadata": { - "application/vnd.jupyter.widget-view+json": { - "colab": { - "custom_widget_manager": { - "url": "https://ssl.gstatic.com/colaboratory-static/widgets/colab-cdn-widget-manager/d2e234f7cc04bf79/manager.min.js" - } - } - } - } - } + "_ngl_serialize": false, + "_ngl_version": "", + "_ngl_view_id": [ + "FB989FD1-5B9C-446B-8914-6B58AF85446D" ], - "source": [ - "from nglview import show_rdkit as show" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "id": "pxtq8I-I18C-", - "colab": { - "base_uri": "https://localhost:8080/", - "height": 337, - "referenced_widgets": [ - "be446195da2b4ff2aec21ec5ff963a54", - "c6596896148b4a8a9c57963b67c7782f", - "2489b5e5648541fbbdceadb05632a050", - "01e0ba4e5da04914b4652b8d58565d7b", - "c30e6c2f3e2a44dbbb3d63bd519acaa4", - "f31c6e40e9b2466a9064a2669933ecd5", - "19308ccac642498ab8b58462e3f1b0bb", - "4a081cdc2ec3421ca79dd933b7e2b0c4", - "e5c0d75eb5e1447abd560c8f2c6017e1", - "5146907ef6764654ad7d598baebc8b58", - "144ec959b7604a2cabb5ca46ae5e5379", - "abce2a80e6304df3899109c6d6cac199", - "65195cb7a4134f4887e9dd19f3676462" - ] - }, - "outputId": "72ed63ac-d2ec-4f5c-a0b1-4e7c1840a4e7" - }, - "outputs": [ - { - "output_type": "display_data", - "data": { - "text/plain": [ - "NGLWidget()" - ], - "application/vnd.jupyter.widget-view+json": { - "version_major": 2, - "version_minor": 0, - "model_id": "be446195da2b4ff2aec21ec5ff963a54" - } - }, - "metadata": { - "application/vnd.jupyter.widget-view+json": { - "colab": { - "custom_widget_manager": { - "url": "https://ssl.gstatic.com/colaboratory-static/widgets/colab-cdn-widget-manager/d2e234f7cc04bf79/manager.min.js" - } - } - } - } - } + "_player_dict": {}, + "_scene_position": {}, + "_scene_rotation": {}, + "_synced_model_ids": [], + "_synced_repr_model_ids": [], + "_view_count": null, + "_view_height": "", + "_view_module": "nglview-js-widgets", + "_view_module_version": "3.0.1", + "_view_name": "NGLView", + "_view_width": "", + "background": "white", + "frame": 0, + "gui_style": null, + "layout": "IPY_MODEL_c6596896148b4a8a9c57963b67c7782f", + "max_frame": 0, + "n_components": 2, + "picked": {} + } + }, + "c2321b3bff6f490ca12040a20308f555": { + "model_module": "@jupyter-widgets/controls", + "model_module_version": "1.5.0", + "model_name": "FloatProgressModel", + "state": { + "_dom_classes": [], + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "_model_name": "FloatProgressModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/controls", + "_view_module_version": "1.5.0", + "_view_name": "ProgressView", + "bar_style": "success", + "description": "", + "description_tooltip": null, + "layout": "IPY_MODEL_872915dd1bb84f538c44e26badabafdd", + "max": 3271865, + "min": 0, + "orientation": "horizontal", + "style": "IPY_MODEL_d022575f1fa2446d891650897f187b4d", + "value": 3271865 + } + }, + "c30e6c2f3e2a44dbbb3d63bd519acaa4": { + "model_module": "@jupyter-widgets/base", + 
"model_module_version": "1.2.0", + "model_name": "LayoutModel", + "state": { + "_model_module": "@jupyter-widgets/base", + "_model_module_version": "1.2.0", + "_model_name": "LayoutModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "1.2.0", + "_view_name": "LayoutView", + "align_content": null, + "align_items": null, + "align_self": null, + "border": null, + "bottom": null, + "display": null, + "flex": null, + "flex_flow": null, + "grid_area": null, + "grid_auto_columns": null, + "grid_auto_flow": null, + "grid_auto_rows": null, + "grid_column": null, + "grid_gap": null, + "grid_row": null, + "grid_template_areas": null, + "grid_template_columns": null, + "grid_template_rows": null, + "height": null, + "justify_content": null, + "justify_items": null, + "left": null, + "margin": null, + "max_height": null, + "max_width": null, + "min_height": null, + "min_width": null, + "object_fit": null, + "object_position": null, + "order": null, + "overflow": null, + "overflow_x": null, + "overflow_y": null, + "padding": null, + "right": null, + "top": null, + "visibility": null, + "width": null + } + }, + "c6596896148b4a8a9c57963b67c7782f": { + "model_module": "@jupyter-widgets/base", + "model_module_version": "1.2.0", + "model_name": "LayoutModel", + "state": { + "_model_module": "@jupyter-widgets/base", + "_model_module_version": "1.2.0", + "_model_name": "LayoutModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "1.2.0", + "_view_name": "LayoutView", + "align_content": null, + "align_items": null, + "align_self": null, + "border": null, + "bottom": null, + "display": null, + "flex": null, + "flex_flow": null, + "grid_area": null, + "grid_auto_columns": null, + "grid_auto_flow": null, + "grid_auto_rows": null, + "grid_column": null, + "grid_gap": null, + "grid_row": null, + "grid_template_areas": null, + "grid_template_columns": null, + "grid_template_rows": null, + "height": null, + "justify_content": null, + "justify_items": null, + "left": null, + "margin": null, + "max_height": null, + "max_width": null, + "min_height": null, + "min_width": null, + "object_fit": null, + "object_position": null, + "order": null, + "overflow": null, + "overflow_x": null, + "overflow_y": null, + "padding": null, + "right": null, + "top": null, + "visibility": null, + "width": null + } + }, + "d022575f1fa2446d891650897f187b4d": { + "model_module": "@jupyter-widgets/controls", + "model_module_version": "1.5.0", + "model_name": "ProgressStyleModel", + "state": { + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "_model_name": "ProgressStyleModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "1.2.0", + "_view_name": "StyleView", + "bar_color": null, + "description_width": "" + } + }, + "d90f304e9560472eacfbdd11e46765eb": { + "model_module": "@jupyter-widgets/controls", + "model_module_version": "1.5.0", + "model_name": "HBoxModel", + "state": { + "_dom_classes": [], + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "_model_name": "HBoxModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/controls", + "_view_module_version": "1.5.0", + "_view_name": "HBoxView", + "box_style": "", + "children": [ + "IPY_MODEL_1c6246f15b654f4daa11c9bcf997b78c", + "IPY_MODEL_c2321b3bff6f490ca12040a20308f555", + "IPY_MODEL_b7feb522161f4cf4b7cc7c1a078ff12d" ], - "source": [ - "# new molecule\n", - 
"show(mols_gen[idx])" - ] - }, - { - "cell_type": "code", - "source": [], - "metadata": { - "id": "KJr4h2mwXeTo" - }, - "execution_count": null, - "outputs": [] - } - ], - "metadata": { - "accelerator": "GPU", - "colab": { - "provenance": [] - }, - "gpuClass": "standard", - "kernelspec": { - "display_name": "Python 3", - "name": "python3" - }, - "language_info": { - "name": "python" - }, - "widgets": { - "application/vnd.jupyter.widget-state+json": { - "d90f304e9560472eacfbdd11e46765eb": { - "model_module": "@jupyter-widgets/controls", - "model_name": "HBoxModel", - "model_module_version": "1.5.0", - "state": { - "_dom_classes": [], - "_model_module": "@jupyter-widgets/controls", - "_model_module_version": "1.5.0", - "_model_name": "HBoxModel", - "_view_count": null, - "_view_module": "@jupyter-widgets/controls", - "_view_module_version": "1.5.0", - "_view_name": "HBoxView", - "box_style": "", - "children": [ - "IPY_MODEL_1c6246f15b654f4daa11c9bcf997b78c", - "IPY_MODEL_c2321b3bff6f490ca12040a20308f555", - "IPY_MODEL_b7feb522161f4cf4b7cc7c1a078ff12d" - ], - "layout": "IPY_MODEL_e2d368556e494ae7ae4e2e992af2cd4f" - } - }, - "1c6246f15b654f4daa11c9bcf997b78c": { - "model_module": "@jupyter-widgets/controls", - "model_name": "HTMLModel", - "model_module_version": "1.5.0", - "state": { - "_dom_classes": [], - "_model_module": "@jupyter-widgets/controls", - "_model_module_version": "1.5.0", - "_model_name": "HTMLModel", - "_view_count": null, - "_view_module": "@jupyter-widgets/controls", - "_view_module_version": "1.5.0", - "_view_name": "HTMLView", - "description": "", - "description_tooltip": null, - "layout": "IPY_MODEL_bbef741e76ec41b7ab7187b487a383df", - "placeholder": "​", - "style": "IPY_MODEL_561f742d418d4721b0670cc8dd62e22c", - "value": "Downloading: 100%" - } - }, - "c2321b3bff6f490ca12040a20308f555": { - "model_module": "@jupyter-widgets/controls", - "model_name": "FloatProgressModel", - "model_module_version": "1.5.0", - "state": { - "_dom_classes": [], - "_model_module": "@jupyter-widgets/controls", - "_model_module_version": "1.5.0", - "_model_name": "FloatProgressModel", - "_view_count": null, - "_view_module": "@jupyter-widgets/controls", - "_view_module_version": "1.5.0", - "_view_name": "ProgressView", - "bar_style": "success", - "description": "", - "description_tooltip": null, - "layout": "IPY_MODEL_872915dd1bb84f538c44e26badabafdd", - "max": 3271865, - "min": 0, - "orientation": "horizontal", - "style": "IPY_MODEL_d022575f1fa2446d891650897f187b4d", - "value": 3271865 - } - }, - "b7feb522161f4cf4b7cc7c1a078ff12d": { - "model_module": "@jupyter-widgets/controls", - "model_name": "HTMLModel", - "model_module_version": "1.5.0", - "state": { - "_dom_classes": [], - "_model_module": "@jupyter-widgets/controls", - "_model_module_version": "1.5.0", - "_model_name": "HTMLModel", - "_view_count": null, - "_view_module": "@jupyter-widgets/controls", - "_view_module_version": "1.5.0", - "_view_name": "HTMLView", - "description": "", - "description_tooltip": null, - "layout": "IPY_MODEL_fdc393f3468c432aa0ada05e238a5436", - "placeholder": "​", - "style": "IPY_MODEL_2c9362906e4b40189f16d14aa9a348da", - "value": " 3.27M/3.27M [00:01<00:00, 3.25MB/s]" - } - }, - "e2d368556e494ae7ae4e2e992af2cd4f": { - "model_module": "@jupyter-widgets/base", - "model_name": "LayoutModel", - "model_module_version": "1.2.0", - "state": { - "_model_module": "@jupyter-widgets/base", - "_model_module_version": "1.2.0", - "_model_name": "LayoutModel", - "_view_count": null, - "_view_module": 
"@jupyter-widgets/base", - "_view_module_version": "1.2.0", - "_view_name": "LayoutView", - "align_content": null, - "align_items": null, - "align_self": null, - "border": null, - "bottom": null, - "display": null, - "flex": null, - "flex_flow": null, - "grid_area": null, - "grid_auto_columns": null, - "grid_auto_flow": null, - "grid_auto_rows": null, - "grid_column": null, - "grid_gap": null, - "grid_row": null, - "grid_template_areas": null, - "grid_template_columns": null, - "grid_template_rows": null, - "height": null, - "justify_content": null, - "justify_items": null, - "left": null, - "margin": null, - "max_height": null, - "max_width": null, - "min_height": null, - "min_width": null, - "object_fit": null, - "object_position": null, - "order": null, - "overflow": null, - "overflow_x": null, - "overflow_y": null, - "padding": null, - "right": null, - "top": null, - "visibility": null, - "width": null - } - }, - "bbef741e76ec41b7ab7187b487a383df": { - "model_module": "@jupyter-widgets/base", - "model_name": "LayoutModel", - "model_module_version": "1.2.0", - "state": { - "_model_module": "@jupyter-widgets/base", - "_model_module_version": "1.2.0", - "_model_name": "LayoutModel", - "_view_count": null, - "_view_module": "@jupyter-widgets/base", - "_view_module_version": "1.2.0", - "_view_name": "LayoutView", - "align_content": null, - "align_items": null, - "align_self": null, - "border": null, - "bottom": null, - "display": null, - "flex": null, - "flex_flow": null, - "grid_area": null, - "grid_auto_columns": null, - "grid_auto_flow": null, - "grid_auto_rows": null, - "grid_column": null, - "grid_gap": null, - "grid_row": null, - "grid_template_areas": null, - "grid_template_columns": null, - "grid_template_rows": null, - "height": null, - "justify_content": null, - "justify_items": null, - "left": null, - "margin": null, - "max_height": null, - "max_width": null, - "min_height": null, - "min_width": null, - "object_fit": null, - "object_position": null, - "order": null, - "overflow": null, - "overflow_x": null, - "overflow_y": null, - "padding": null, - "right": null, - "top": null, - "visibility": null, - "width": null - } - }, - "561f742d418d4721b0670cc8dd62e22c": { - "model_module": "@jupyter-widgets/controls", - "model_name": "DescriptionStyleModel", - "model_module_version": "1.5.0", - "state": { - "_model_module": "@jupyter-widgets/controls", - "_model_module_version": "1.5.0", - "_model_name": "DescriptionStyleModel", - "_view_count": null, - "_view_module": "@jupyter-widgets/base", - "_view_module_version": "1.2.0", - "_view_name": "StyleView", - "description_width": "" - } - }, - "872915dd1bb84f538c44e26badabafdd": { - "model_module": "@jupyter-widgets/base", - "model_name": "LayoutModel", - "model_module_version": "1.2.0", - "state": { - "_model_module": "@jupyter-widgets/base", - "_model_module_version": "1.2.0", - "_model_name": "LayoutModel", - "_view_count": null, - "_view_module": "@jupyter-widgets/base", - "_view_module_version": "1.2.0", - "_view_name": "LayoutView", - "align_content": null, - "align_items": null, - "align_self": null, - "border": null, - "bottom": null, - "display": null, - "flex": null, - "flex_flow": null, - "grid_area": null, - "grid_auto_columns": null, - "grid_auto_flow": null, - "grid_auto_rows": null, - "grid_column": null, - "grid_gap": null, - "grid_row": null, - "grid_template_areas": null, - "grid_template_columns": null, - "grid_template_rows": null, - "height": null, - "justify_content": null, - "justify_items": null, - "left": null, - 
"margin": null, - "max_height": null, - "max_width": null, - "min_height": null, - "min_width": null, - "object_fit": null, - "object_position": null, - "order": null, - "overflow": null, - "overflow_x": null, - "overflow_y": null, - "padding": null, - "right": null, - "top": null, - "visibility": null, - "width": null - } - }, - "d022575f1fa2446d891650897f187b4d": { - "model_module": "@jupyter-widgets/controls", - "model_name": "ProgressStyleModel", - "model_module_version": "1.5.0", - "state": { - "_model_module": "@jupyter-widgets/controls", - "_model_module_version": "1.5.0", - "_model_name": "ProgressStyleModel", - "_view_count": null, - "_view_module": "@jupyter-widgets/base", - "_view_module_version": "1.2.0", - "_view_name": "StyleView", - "bar_color": null, - "description_width": "" - } - }, - "fdc393f3468c432aa0ada05e238a5436": { - "model_module": "@jupyter-widgets/base", - "model_name": "LayoutModel", - "model_module_version": "1.2.0", - "state": { - "_model_module": "@jupyter-widgets/base", - "_model_module_version": "1.2.0", - "_model_name": "LayoutModel", - "_view_count": null, - "_view_module": "@jupyter-widgets/base", - "_view_module_version": "1.2.0", - "_view_name": "LayoutView", - "align_content": null, - "align_items": null, - "align_self": null, - "border": null, - "bottom": null, - "display": null, - "flex": null, - "flex_flow": null, - "grid_area": null, - "grid_auto_columns": null, - "grid_auto_flow": null, - "grid_auto_rows": null, - "grid_column": null, - "grid_gap": null, - "grid_row": null, - "grid_template_areas": null, - "grid_template_columns": null, - "grid_template_rows": null, - "height": null, - "justify_content": null, - "justify_items": null, - "left": null, - "margin": null, - "max_height": null, - "max_width": null, - "min_height": null, - "min_width": null, - "object_fit": null, - "object_position": null, - "order": null, - "overflow": null, - "overflow_x": null, - "overflow_y": null, - "padding": null, - "right": null, - "top": null, - "visibility": null, - "width": null - } - }, - "2c9362906e4b40189f16d14aa9a348da": { - "model_module": "@jupyter-widgets/controls", - "model_name": "DescriptionStyleModel", - "model_module_version": "1.5.0", - "state": { - "_model_module": "@jupyter-widgets/controls", - "_model_module_version": "1.5.0", - "_model_name": "DescriptionStyleModel", - "_view_count": null, - "_view_module": "@jupyter-widgets/base", - "_view_module_version": "1.2.0", - "_view_name": "StyleView", - "description_width": "" - } - }, - "6010fc8daa7a44d5aec4b830ec2ebaa1": { - "model_module": "@jupyter-widgets/controls", - "model_name": "HBoxModel", - "model_module_version": "1.5.0", - "state": { - "_dom_classes": [], - "_model_module": "@jupyter-widgets/controls", - "_model_module_version": "1.5.0", - "_model_name": "HBoxModel", - "_view_count": null, - "_view_module": "@jupyter-widgets/controls", - "_view_module_version": "1.5.0", - "_view_name": "HBoxView", - "box_style": "", - "children": [ - "IPY_MODEL_7e0bb1b8d65249d3974200686b193be2", - "IPY_MODEL_ba98aa6d6a884e4ab8bbb5dfb5e4cf7a", - "IPY_MODEL_6526646be5ed415c84d1245b040e629b" - ], - "layout": "IPY_MODEL_24d31fc3576e43dd9f8301d2ef3a37ab" - } - }, - "7e0bb1b8d65249d3974200686b193be2": { - "model_module": "@jupyter-widgets/controls", - "model_name": "HTMLModel", - "model_module_version": "1.5.0", - "state": { - "_dom_classes": [], - "_model_module": "@jupyter-widgets/controls", - "_model_module_version": "1.5.0", - "_model_name": "HTMLModel", - "_view_count": null, - "_view_module": 
"@jupyter-widgets/controls", - "_view_module_version": "1.5.0", - "_view_name": "HTMLView", - "description": "", - "description_tooltip": null, - "layout": "IPY_MODEL_2918bfaadc8d4b1a9832522c40dfefb8", - "placeholder": "​", - "style": "IPY_MODEL_a4bfdca35cc54dae8812720f1b276a08", - "value": "Downloading: 100%" - } - }, - "ba98aa6d6a884e4ab8bbb5dfb5e4cf7a": { - "model_module": "@jupyter-widgets/controls", - "model_name": "FloatProgressModel", - "model_module_version": "1.5.0", - "state": { - "_dom_classes": [], - "_model_module": "@jupyter-widgets/controls", - "_model_module_version": "1.5.0", - "_model_name": "FloatProgressModel", - "_view_count": null, - "_view_module": "@jupyter-widgets/controls", - "_view_module_version": "1.5.0", - "_view_name": "ProgressView", - "bar_style": "success", - "description": "", - "description_tooltip": null, - "layout": "IPY_MODEL_e4901541199b45c6a18824627692fc39", - "max": 401, - "min": 0, - "orientation": "horizontal", - "style": "IPY_MODEL_f915cf874246446595206221e900b2fe", - "value": 401 - } - }, - "6526646be5ed415c84d1245b040e629b": { - "model_module": "@jupyter-widgets/controls", - "model_name": "HTMLModel", - "model_module_version": "1.5.0", - "state": { - "_dom_classes": [], - "_model_module": "@jupyter-widgets/controls", - "_model_module_version": "1.5.0", - "_model_name": "HTMLModel", - "_view_count": null, - "_view_module": "@jupyter-widgets/controls", - "_view_module_version": "1.5.0", - "_view_name": "HTMLView", - "description": "", - "description_tooltip": null, - "layout": "IPY_MODEL_a9e388f22a9742aaaf538e22575c9433", - "placeholder": "​", - "style": "IPY_MODEL_42f6c3db29d7484ba6b4f73590abd2f4", - "value": " 401/401 [00:00<00:00, 13.5kB/s]" - } - }, - "24d31fc3576e43dd9f8301d2ef3a37ab": { - "model_module": "@jupyter-widgets/base", - "model_name": "LayoutModel", - "model_module_version": "1.2.0", - "state": { - "_model_module": "@jupyter-widgets/base", - "_model_module_version": "1.2.0", - "_model_name": "LayoutModel", - "_view_count": null, - "_view_module": "@jupyter-widgets/base", - "_view_module_version": "1.2.0", - "_view_name": "LayoutView", - "align_content": null, - "align_items": null, - "align_self": null, - "border": null, - "bottom": null, - "display": null, - "flex": null, - "flex_flow": null, - "grid_area": null, - "grid_auto_columns": null, - "grid_auto_flow": null, - "grid_auto_rows": null, - "grid_column": null, - "grid_gap": null, - "grid_row": null, - "grid_template_areas": null, - "grid_template_columns": null, - "grid_template_rows": null, - "height": null, - "justify_content": null, - "justify_items": null, - "left": null, - "margin": null, - "max_height": null, - "max_width": null, - "min_height": null, - "min_width": null, - "object_fit": null, - "object_position": null, - "order": null, - "overflow": null, - "overflow_x": null, - "overflow_y": null, - "padding": null, - "right": null, - "top": null, - "visibility": null, - "width": null - } - }, - "2918bfaadc8d4b1a9832522c40dfefb8": { - "model_module": "@jupyter-widgets/base", - "model_name": "LayoutModel", - "model_module_version": "1.2.0", - "state": { - "_model_module": "@jupyter-widgets/base", - "_model_module_version": "1.2.0", - "_model_name": "LayoutModel", - "_view_count": null, - "_view_module": "@jupyter-widgets/base", - "_view_module_version": "1.2.0", - "_view_name": "LayoutView", - "align_content": null, - "align_items": null, - "align_self": null, - "border": null, - "bottom": null, - "display": null, - "flex": null, - "flex_flow": null, - "grid_area": 
null, - "grid_auto_columns": null, - "grid_auto_flow": null, - "grid_auto_rows": null, - "grid_column": null, - "grid_gap": null, - "grid_row": null, - "grid_template_areas": null, - "grid_template_columns": null, - "grid_template_rows": null, - "height": null, - "justify_content": null, - "justify_items": null, - "left": null, - "margin": null, - "max_height": null, - "max_width": null, - "min_height": null, - "min_width": null, - "object_fit": null, - "object_position": null, - "order": null, - "overflow": null, - "overflow_x": null, - "overflow_y": null, - "padding": null, - "right": null, - "top": null, - "visibility": null, - "width": null - } - }, - "a4bfdca35cc54dae8812720f1b276a08": { - "model_module": "@jupyter-widgets/controls", - "model_name": "DescriptionStyleModel", - "model_module_version": "1.5.0", - "state": { - "_model_module": "@jupyter-widgets/controls", - "_model_module_version": "1.5.0", - "_model_name": "DescriptionStyleModel", - "_view_count": null, - "_view_module": "@jupyter-widgets/base", - "_view_module_version": "1.2.0", - "_view_name": "StyleView", - "description_width": "" - } - }, - "e4901541199b45c6a18824627692fc39": { - "model_module": "@jupyter-widgets/base", - "model_name": "LayoutModel", - "model_module_version": "1.2.0", - "state": { - "_model_module": "@jupyter-widgets/base", - "_model_module_version": "1.2.0", - "_model_name": "LayoutModel", - "_view_count": null, - "_view_module": "@jupyter-widgets/base", - "_view_module_version": "1.2.0", - "_view_name": "LayoutView", - "align_content": null, - "align_items": null, - "align_self": null, - "border": null, - "bottom": null, - "display": null, - "flex": null, - "flex_flow": null, - "grid_area": null, - "grid_auto_columns": null, - "grid_auto_flow": null, - "grid_auto_rows": null, - "grid_column": null, - "grid_gap": null, - "grid_row": null, - "grid_template_areas": null, - "grid_template_columns": null, - "grid_template_rows": null, - "height": null, - "justify_content": null, - "justify_items": null, - "left": null, - "margin": null, - "max_height": null, - "max_width": null, - "min_height": null, - "min_width": null, - "object_fit": null, - "object_position": null, - "order": null, - "overflow": null, - "overflow_x": null, - "overflow_y": null, - "padding": null, - "right": null, - "top": null, - "visibility": null, - "width": null - } - }, - "f915cf874246446595206221e900b2fe": { - "model_module": "@jupyter-widgets/controls", - "model_name": "ProgressStyleModel", - "model_module_version": "1.5.0", - "state": { - "_model_module": "@jupyter-widgets/controls", - "_model_module_version": "1.5.0", - "_model_name": "ProgressStyleModel", - "_view_count": null, - "_view_module": "@jupyter-widgets/base", - "_view_module_version": "1.2.0", - "_view_name": "StyleView", - "bar_color": null, - "description_width": "" - } - }, - "a9e388f22a9742aaaf538e22575c9433": { - "model_module": "@jupyter-widgets/base", - "model_name": "LayoutModel", - "model_module_version": "1.2.0", - "state": { - "_model_module": "@jupyter-widgets/base", - "_model_module_version": "1.2.0", - "_model_name": "LayoutModel", - "_view_count": null, - "_view_module": "@jupyter-widgets/base", - "_view_module_version": "1.2.0", - "_view_name": "LayoutView", - "align_content": null, - "align_items": null, - "align_self": null, - "border": null, - "bottom": null, - "display": null, - "flex": null, - "flex_flow": null, - "grid_area": null, - "grid_auto_columns": null, - "grid_auto_flow": null, - "grid_auto_rows": null, - "grid_column": null, - 
"grid_gap": null, - "grid_row": null, - "grid_template_areas": null, - "grid_template_columns": null, - "grid_template_rows": null, - "height": null, - "justify_content": null, - "justify_items": null, - "left": null, - "margin": null, - "max_height": null, - "max_width": null, - "min_height": null, - "min_width": null, - "object_fit": null, - "object_position": null, - "order": null, - "overflow": null, - "overflow_x": null, - "overflow_y": null, - "padding": null, - "right": null, - "top": null, - "visibility": null, - "width": null - } - }, - "42f6c3db29d7484ba6b4f73590abd2f4": { - "model_module": "@jupyter-widgets/controls", - "model_name": "DescriptionStyleModel", - "model_module_version": "1.5.0", - "state": { - "_model_module": "@jupyter-widgets/controls", - "_model_module_version": "1.5.0", - "_model_name": "DescriptionStyleModel", - "_view_count": null, - "_view_module": "@jupyter-widgets/base", - "_view_module_version": "1.2.0", - "_view_name": "StyleView", - "description_width": "" - } - }, - "695ab5bbf30a4ab19df1f9f33469f314": { - "model_module": "nglview-js-widgets", - "model_name": "ColormakerRegistryModel", - "model_module_version": "3.0.1", - "state": { - "_dom_classes": [], - "_model_module": "nglview-js-widgets", - "_model_module_version": "3.0.1", - "_model_name": "ColormakerRegistryModel", - "_msg_ar": [], - "_msg_q": [], - "_ready": false, - "_view_count": null, - "_view_module": "nglview-js-widgets", - "_view_module_version": "3.0.1", - "_view_name": "ColormakerRegistryView", - "layout": "IPY_MODEL_eac6a8dcdc9d4335a2e51031793ead29" - } - }, - "eac6a8dcdc9d4335a2e51031793ead29": { - "model_module": "@jupyter-widgets/base", - "model_name": "LayoutModel", - "model_module_version": "1.2.0", - "state": { - "_model_module": "@jupyter-widgets/base", - "_model_module_version": "1.2.0", - "_model_name": "LayoutModel", - "_view_count": null, - "_view_module": "@jupyter-widgets/base", - "_view_module_version": "1.2.0", - "_view_name": "LayoutView", - "align_content": null, - "align_items": null, - "align_self": null, - "border": null, - "bottom": null, - "display": null, - "flex": null, - "flex_flow": null, - "grid_area": null, - "grid_auto_columns": null, - "grid_auto_flow": null, - "grid_auto_rows": null, - "grid_column": null, - "grid_gap": null, - "grid_row": null, - "grid_template_areas": null, - "grid_template_columns": null, - "grid_template_rows": null, - "height": null, - "justify_content": null, - "justify_items": null, - "left": null, - "margin": null, - "max_height": null, - "max_width": null, - "min_height": null, - "min_width": null, - "object_fit": null, - "object_position": null, - "order": null, - "overflow": null, - "overflow_x": null, - "overflow_y": null, - "padding": null, - "right": null, - "top": null, - "visibility": null, - "width": null - } - }, - "be446195da2b4ff2aec21ec5ff963a54": { - "model_module": "nglview-js-widgets", - "model_name": "NGLModel", - "model_module_version": "3.0.1", - "state": { - "_camera_orientation": [ - -15.519693580202304, - -14.065056548036177, - -23.53197484807691, - 0, - -23.357853515109753, - 20.94055073042662, - 2.888695042134944, - 0, - 14.352363398292777, - 18.870825741878015, - -20.744689572909344, - 0, - 0.2724999189376831, - 0.6940000057220459, - -0.3734999895095825, - 1 - ], - "_camera_str": "orthographic", - "_dom_classes": [], - "_gui_theme": null, - "_ibtn_fullscreen": "IPY_MODEL_2489b5e5648541fbbdceadb05632a050", - "_igui": null, - "_iplayer": "IPY_MODEL_01e0ba4e5da04914b4652b8d58565d7b", - "_model_module": 
"nglview-js-widgets", - "_model_module_version": "3.0.1", - "_model_name": "NGLModel", - "_ngl_color_dict": {}, - "_ngl_coordinate_resource": {}, - "_ngl_full_stage_parameters": { - "impostor": true, - "quality": "medium", - "workerDefault": true, - "sampleLevel": 0, - "backgroundColor": "white", - "rotateSpeed": 2, - "zoomSpeed": 1.2, - "panSpeed": 1, - "clipNear": 0, - "clipFar": 100, - "clipDist": 10, - "fogNear": 50, - "fogFar": 100, - "cameraFov": 40, - "cameraEyeSep": 0.3, - "cameraType": "perspective", - "lightColor": 14540253, - "lightIntensity": 1, - "ambientColor": 14540253, - "ambientIntensity": 0.2, - "hoverTimeout": 0, - "tooltip": true, - "mousePreset": "default" - }, - "_ngl_msg_archive": [ - { - "target": "Stage", - "type": "call_method", - "methodName": "loadFile", - "reconstruc_color_scheme": false, - "args": [ - { - "type": "blob", - "data": "HETATM 1 C1 UNL 1 -0.025 3.128 2.316 1.00 0.00 C \nHETATM 2 H1 UNL 1 0.183 3.657 2.823 1.00 0.00 H \nHETATM 3 C2 UNL 1 0.590 3.559 0.963 1.00 0.00 C \nHETATM 4 C3 UNL 1 0.056 4.479 0.406 1.00 0.00 C \nHETATM 5 C4 UNL 1 -0.219 4.802 -1.065 1.00 0.00 C \nHETATM 6 H2 UNL 1 0.686 4.431 -1.575 1.00 0.00 H \nHETATM 7 H3 UNL 1 -0.524 5.217 -1.274 1.00 0.00 H \nHETATM 8 C5 UNL 1 -1.284 3.766 -1.342 1.00 0.00 C \nHETATM 9 N1 UNL 1 -1.073 2.494 -0.580 1.00 0.00 N \nHETATM 10 C6 UNL 1 -1.909 1.494 -0.964 1.00 0.00 C \nHETATM 11 O1 UNL 1 -2.487 1.531 -2.092 1.00 0.00 O \nHETATM 12 C7 UNL 1 -2.232 0.242 -0.130 1.00 0.00 C \nHETATM 13 C8 UNL 1 -2.161 -1.057 -1.037 1.00 0.00 C \nHETATM 14 C9 UNL 1 -0.744 -1.111 -1.610 1.00 0.00 C \nHETATM 15 N2 UNL 1 0.290 -0.917 -0.628 1.00 0.00 N \nHETATM 16 S1 UNL 1 1.717 -1.597 -0.914 1.00 0.00 S \nHETATM 17 O2 UNL 1 1.960 -1.671 -2.338 1.00 0.00 O \nHETATM 18 O3 UNL 1 2.713 -0.968 -0.082 1.00 0.00 O \nHETATM 19 C10 UNL 1 1.425 -3.170 -0.345 1.00 0.00 C \nHETATM 20 C11 UNL 1 1.225 -4.400 -1.271 1.00 0.00 C \nHETATM 21 C12 UNL 1 1.314 -5.913 -0.895 1.00 0.00 C \nHETATM 22 C13 UNL 1 1.823 -6.229 0.386 1.00 0.00 C \nHETATM 23 C14 UNL 1 2.031 -5.110 1.365 1.00 0.00 C \nHETATM 24 N3 UNL 1 1.850 -5.267 2.712 1.00 0.00 N \nHETATM 25 O4 UNL 1 1.382 -4.029 3.126 1.00 0.00 O \nHETATM 26 N4 UNL 1 1.300 -3.023 2.154 1.00 0.00 N \nHETATM 27 C15 UNL 1 1.731 -3.672 1.032 1.00 0.00 C \nHETATM 28 H4 UNL 1 2.380 -6.874 0.436 1.00 0.00 H \nHETATM 29 H5 UNL 1 0.704 -6.526 -1.420 1.00 0.00 H \nHETATM 30 H6 UNL 1 1.144 -4.035 -2.291 1.00 0.00 H \nHETATM 31 C16 UNL 1 0.044 -0.371 0.685 1.00 0.00 C \nHETATM 32 C17 UNL 1 -1.352 -0.045 1.077 1.00 0.00 C \nHETATM 33 H7 UNL 1 -1.395 0.770 1.768 1.00 0.00 H \nHETATM 34 H8 UNL 1 -1.792 -0.941 1.582 1.00 0.00 H \nHETATM 35 H9 UNL 1 0.583 -1.035 1.393 1.00 0.00 H \nHETATM 36 H10 UNL 1 0.664 0.613 0.663 1.00 0.00 H \nHETATM 37 H11 UNL 1 -0.631 -0.267 -2.335 1.00 0.00 H \nHETATM 38 H12 UNL 1 -0.571 -2.046 -2.098 1.00 0.00 H \nHETATM 39 H13 UNL 1 -2.872 -0.992 -1.826 1.00 0.00 H \nHETATM 40 H14 UNL 1 -2.370 -1.924 -0.444 1.00 0.00 H \nHETATM 41 H15 UNL 1 -3.258 0.364 0.197 1.00 0.00 H \nHETATM 42 C18 UNL 1 0.276 2.337 -0.078 1.00 0.00 C \nHETATM 43 H16 UNL 1 0.514 1.371 0.252 1.00 0.00 H \nHETATM 44 H17 UNL 1 0.988 2.413 -0.949 1.00 0.00 H \nHETATM 45 H18 UNL 1 -1.349 3.451 -2.379 1.00 0.00 H \nHETATM 46 H19 UNL 1 -2.224 4.055 -0.958 1.00 0.00 H \nHETATM 47 H20 UNL 1 0.793 5.486 0.669 1.00 0.00 H \nHETATM 48 H21 UNL 1 -0.849 4.974 0.937 1.00 0.00 H \nHETATM 49 H22 UNL 1 1.667 3.431 1.070 1.00 0.00 H \nHETATM 50 H23 UNL 1 0.379 2.143 2.689 1.00 0.00 H \nHETATM 51 H24 UNL 1 -1.094 2.983 2.223 
"nbformat_minor": 0 + } + } + }, + "nbformat": 4, + "nbformat_minor": 0 } \ No newline at end of file diff --git a/examples/research_projects/gligen/demo.ipynb b/examples/research_projects/gligen/demo.ipynb index 571f1a0323a2..4930253ff66e 100644 --- a/examples/research_projects/gligen/demo.ipynb +++ b/examples/research_projects/gligen/demo.ipynb @@ -26,8 +26,7 @@ "%load_ext autoreload\n", "%autoreload 2\n", "\n", - "import torch\n", - "from diffusers import StableDiffusionGLIGENTextImagePipeline, StableDiffusionGLIGENPipeline" + "from diffusers import StableDiffusionGLIGENPipeline" ] }, { @@ -36,16 +35,17 @@ "metadata": {}, "outputs": [], "source": [ - "import os\n", + "from transformers import CLIPTextModel, CLIPTokenizer\n", + "\n", "import diffusers\n", "from diffusers import (\n", " AutoencoderKL,\n", " DDPMScheduler,\n", - " UNet2DConditionModel,\n", - " UniPCMultistepScheduler,\n", " EulerDiscreteScheduler,\n", + " UNet2DConditionModel,\n", ")\n", - "from transformers import CLIPImageProcessor, CLIPTextModel, CLIPTokenizer\n", + "\n", + "\n", "# pretrained_model_name_or_path = 'masterful/gligen-1-4-generation-text-box'\n", "\n", "pretrained_model_name_or_path = '/root/data/zhizhonghuang/checkpoints/models--masterful--gligen-1-4-generation-text-box/snapshots/d2820dc1e9ba6ca082051ce79cfd3eb468ae2c83'\n", @@ -122,6 +122,7 @@ "\n", "import numpy as np\n", "\n", + "\n", "boxes = np.array([x[1] for x in gen_boxes])\n", "boxes = boxes / 512\n", "boxes[:, 2] = boxes[:, 0] + boxes[:, 2]\n", diff --git a/scripts/convert_omnigen_to_diffusers.py b/scripts/convert_omnigen_to_diffusers.py index 4e8b284d2fb7..6893644f5ce2 100644 --- a/scripts/convert_omnigen_to_diffusers.py +++ b/scripts/convert_omnigen_to_diffusers.py @@ -6,7 +6,7 @@ from safetensors.torch import load_file from transformers import AutoTokenizer -from diffusers import AutoencoderKL, FlowMatchEulerDiscreteScheduler, OmniGenTransformer2DModel, OmniGenPipeline +from diffusers import AutoencoderKL, FlowMatchEulerDiscreteScheduler, OmniGenPipeline, OmniGenTransformer2DModel def main(args): diff --git a/src/diffusers/models/embeddings.py b/src/diffusers/models/embeddings.py index 63826baf6f78..5d97fbd61969 100644 --- a/src/diffusers/models/embeddings.py +++ b/src/diffusers/models/embeddings.py @@ -429,7 +429,7 @@ def forward(self, if isinstance(latent, list): if padding_latent is None: padding_latent = [None] * len(latent) - patched_latents, num_tokens, shapes = [], [], [] + patched_latents = [] for sub_latent, padding in zip(latent, padding_latent): height, width = sub_latent.shape[-2:] sub_latent = self.patch_embeddings(sub_latent, is_input_image) diff --git a/src/diffusers/models/transformers/transformer_omnigen.py b/src/diffusers/models/transformers/transformer_omnigen.py index 3a1072a0d349..cea926b0531e 100644 --- a/src/diffusers/models/transformers/transformer_omnigen.py +++ b/src/diffusers/models/transformers/transformer_omnigen.py @@ -13,12 +13,12 @@ # limitations under the License. 
from dataclasses import dataclass -from typing import Any, Dict, Optional, Tuple, Union, List +from typing import Any, Dict, List, Optional, Tuple, Union import torch import torch.utils.checkpoint from torch import nn -from transformers import Phi3Model, Phi3Config +from transformers import Phi3Config, Phi3Model from transformers.cache_utils import Cache, DynamicCache from transformers.modeling_outputs import BaseModelOutputWithPast @@ -120,7 +120,7 @@ def forward( ) use_cache = False - # kept for BC (non `Cache` `past_key_values` inputs) + # kept for BC (non `Cache` `past_key_values` inputs) return_legacy_cache = False if use_cache and not isinstance(past_key_values, Cache): return_legacy_cache = True diff --git a/src/diffusers/pipelines/omnigen/kvcache_omnigen.py b/src/diffusers/pipelines/omnigen/kvcache_omnigen.py index 4bf32ae6ae74..34d98d939df2 100644 --- a/src/diffusers/pipelines/omnigen/kvcache_omnigen.py +++ b/src/diffusers/pipelines/omnigen/kvcache_omnigen.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -from typing import Optional, Dict, Any, Tuple, List +from typing import Any, Dict, List, Optional, Tuple import torch from transformers.cache_utils import DynamicCache @@ -61,7 +61,6 @@ def __getitem__(self, layer_idx: int) -> List[Tuple[torch.Tensor]]: torch.cuda.current_stream().synchronize() self.evict_previous_layer(layer_idx) # Load current layer cache to its original device if not already there - original_device = self.original_device[layer_idx] # self.prefetch_stream.synchronize(original_device) torch.cuda.synchronize(self.prefetch_stream) key_tensor = self.key_cache[layer_idx] diff --git a/src/diffusers/pipelines/omnigen/pipeline_omnigen.py b/src/diffusers/pipelines/omnigen/pipeline_omnigen.py index 44fed3490843..0bb61e905bbf 100644 --- a/src/diffusers/pipelines/omnigen/pipeline_omnigen.py +++ b/src/diffusers/pipelines/omnigen/pipeline_omnigen.py @@ -15,8 +15,8 @@ import inspect from typing import Any, Callable, Dict, List, Optional, Union -import PIL import numpy as np +import PIL import torch from transformers import LlamaTokenizer @@ -35,6 +35,7 @@ from .kvcache_omnigen import OmniGenCache from .processor_omnigen import OmniGenMultiModalProcessor + if is_torch_xla_available(): XLA_AVAILABLE = True @@ -228,7 +229,7 @@ def check_inputs( if use_kv_cache and offload_kv_cache: if not torch.cuda.is_available(): raise ValueError( - f"Don't fine avaliable GPUs. `offload_kv_cache` can't be used when there is no GPU. please set it to False: `use_kv_cache=False, offload_kv_cache=False`" + "No available GPU found. `offload_kv_cache` cannot be used without a GPU. Please set `use_kv_cache=False, offload_kv_cache=False`." ) if callback_on_step_end_tensor_inputs is not None and not all( diff --git a/src/diffusers/pipelines/omnigen/processor_omnigen.py b/src/diffusers/pipelines/omnigen/processor_omnigen.py index af52b20b55db..b350933d995c 100644 --- a/src/diffusers/pipelines/omnigen/processor_omnigen.py +++ b/src/diffusers/pipelines/omnigen/processor_omnigen.py @@ -90,7 +90,7 @@ def process_multi_modal_prompt(self, text, input_images): image_tags = re.findall(pattern, text) image_ids = [int(s.split("|")[1].split("_")[-1]) for s in image_tags] - unique_image_ids = sorted(list(set(image_ids))) + unique_image_ids = sorted(set(image_ids)) assert unique_image_ids == list(range(1, len(unique_image_ids) + 1)), f"image_ids must start from 1, and must be continuous int, e.g.
[1, 2, 3], cannot be {unique_image_ids}" # total images must be the same as the number of image tags @@ -101,7 +101,6 @@ def process_multi_modal_prompt(self, text, input_images): all_input_ids = [] img_inx = [] - idx = 0 for i in range(len(prompt_chunks)): all_input_ids.extend(prompt_chunks[i]) if i != len(prompt_chunks) - 1: @@ -176,8 +175,7 @@ def create_position(self, attention_mask, num_tokens_for_output_images): img_length = max(num_tokens_for_output_images) for mask in attention_mask: temp_l = torch.sum(mask) - temp_position = [0] * (text_length - temp_l) + [i for i in range( - temp_l + img_length + 1)] # we add a time embedding into the sequence, so add one more token + temp_position = [0] * (text_length - temp_l) + list(range(temp_l + img_length + 1)) # we add a time embedding into the sequence, so add one more token position_ids.append(temp_position) return torch.LongTensor(position_ids) diff --git a/tests/pipelines/omnigen/test_pipeline_omnigen.py b/tests/pipelines/omnigen/test_pipeline_omnigen.py index 93870f9da31d..3d883fa9011d 100644 --- a/tests/pipelines/omnigen/test_pipeline_omnigen.py +++ b/tests/pipelines/omnigen/test_pipeline_omnigen.py @@ -3,9 +3,9 @@ import numpy as np import torch -from transformers import AutoTokenizer, GemmaConfig, GemmaForCausalLM +from transformers import AutoTokenizer -from diffusers import AutoencoderKL, FlowMatchEulerDiscreteScheduler, OmniGenTransformer2DModel, OmniGenPipeline +from diffusers import AutoencoderKL, FlowMatchEulerDiscreteScheduler, OmniGenPipeline, OmniGenTransformer2DModel from diffusers.utils.testing_utils import ( numpy_cosine_similarity_distance, require_torch_gpu, From 3bb092b145bc85bd2e1469a0ee477f45b68910bc Mon Sep 17 00:00:00 2001 From: shitao <2906698981@qq.com> Date: Sun, 8 Dec 2024 19:54:21 +0800 Subject: [PATCH 19/55] make style --- .../en/using-diffusers/multimodal2img.md | 4 +- docs/source/en/using-diffusers/omnigen.md | 8 +- scripts/convert_omnigen_to_diffusers.py | 42 ++-- src/diffusers/models/__init__.py | 2 +- src/diffusers/models/embeddings.py | 38 ++-- .../transformers/transformer_omnigen.py | 132 +++++++------ .../pipelines/omnigen/kvcache_omnigen.py | 22 +-- .../pipelines/omnigen/pipeline_omnigen.py | 183 +++++++++--------- .../pipelines/omnigen/processor_omnigen.py | 125 ++++++------ src/diffusers/utils/dummy_pt_objects.py | 1 + .../omnigen/test_pipeline_omnigen.py | 45 ++--- 11 files changed, 311 insertions(+), 291 deletions(-) diff --git a/docs/source/en/using-diffusers/multimodal2img.md b/docs/source/en/using-diffusers/multimodal2img.md index 1aabb99d5879..0c0bbc01ee9f 100644 --- a/docs/source/en/using-diffusers/multimodal2img.md +++ b/docs/source/en/using-diffusers/multimodal2img.md @@ -79,8 +79,8 @@ pipe.to("cuda") prompt="A woman is walking down the street, wearing a white long-sleeve blouse with lace details on the sleeves, paired with a blue pleated skirt. The woman is <|image_1|>. The long-sleeve blouse and a pleated skirt are <|image_2|>." 
-input_image_1 = load_image("/share/junjie/code/VISTA2/produce_data/laion_net/diffgpt/OmniGen/docs_img/emma.jpeg") -input_image_2 = load_image("/share/junjie/code/VISTA2/produce_data/laion_net/diffgpt/OmniGen/docs_img/dress.jpg") +input_image_1 = load_image("https://raw.githubusercontent.com/VectorSpaceLab/OmniGen/main/imgs/docs_img/emma.jpeg") +input_image_2 = load_image("https://raw.githubusercontent.com/VectorSpaceLab/OmniGen/main/imgs/docs_img/dress.jpg") input_images=[input_image_1, input_image_2] image = pipe( prompt=prompt, diff --git a/docs/source/en/using-diffusers/omnigen.md b/docs/source/en/using-diffusers/omnigen.md index 833199641644..df85fba74395 100644 --- a/docs/source/en/using-diffusers/omnigen.md +++ b/docs/source/en/using-diffusers/omnigen.md @@ -221,8 +221,8 @@ pipe = OmniGenPipeline.from_pretrained( pipe.to("cuda") prompt="A man and a woman are sitting at a classroom desk. The man is the man with yellow hair in <|image_1|>. The woman is the woman on the left of <|image_2|>" -input_image_1 = load_image("https://raw.githubusercontent.com/VectorSpaceLab/OmniGen/main/imgs/docs_img/3.jpg") -input_image_2 = load_image("https://raw.githubusercontent.com/VectorSpaceLab/OmniGen/main/imgs/docs_img/4.jpg") +input_image_1 = load_image("https://raw.githubusercontent.com/VectorSpaceLab/OmniGen/main/imgs/docs_img/3.png") +input_image_2 = load_image("https://raw.githubusercontent.com/VectorSpaceLab/OmniGen/main/imgs/docs_img/4.png") input_images=[input_image_1, input_image_2] image = pipe( prompt=prompt, @@ -263,8 +263,8 @@ pipe.to("cuda") prompt="A woman is walking down the street, wearing a white long-sleeve blouse with lace details on the sleeves, paired with a blue pleated skirt. The woman is <|image_1|>. The long-sleeve blouse and a pleated skirt are <|image_2|>." 
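The documentation hunks above pair each `<|image_i|>` placeholder in the prompt with the i-th entry of `input_images`. As a compact reference (not part of the patch itself), a minimal sketch of that convention; the checkpoint path and image files below are placeholders, and the call mirrors the documented examples:

```python
import torch

from diffusers import OmniGenPipeline
from diffusers.utils import load_image

# "OmniGen-v1-diffusers" stands in for a pipeline produced by the conversion script in this
# patch (or any hosted equivalent); swap in whichever checkpoint you actually use.
pipe = OmniGenPipeline.from_pretrained("OmniGen-v1-diffusers", torch_dtype=torch.bfloat16)
pipe.to("cuda")

# The i-th <|image_i|> tag in the prompt is resolved to the i-th entry of `input_images`.
prompt = "A person sitting on a bench. The person is <|image_1|>. The bench is <|image_2|>."
input_images = [
    load_image("person.png"),  # placeholder file names, not part of the patch
    load_image("bench.png"),
]

image = pipe(
    prompt=prompt,
    input_images=input_images,
    guidance_scale=2.5,
    img_guidance_scale=1.6,
).images[0]
image.save("output.png")
```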
-input_image_1 = load_image("/share/junjie/code/VISTA2/produce_data/laion_net/diffgpt/OmniGen/docs_img/emma.jpeg") -input_image_2 = load_image("/share/junjie/code/VISTA2/produce_data/laion_net/diffgpt/OmniGen/docs_img/dress.jpg") +input_image_1 = load_image("https://raw.githubusercontent.com/VectorSpaceLab/OmniGen/main/imgs/docs_img/emma.jpeg") +input_image_2 = load_image("https://raw.githubusercontent.com/VectorSpaceLab/OmniGen/main/imgs/docs_img/dress.jpg") input_images=[input_image_1, input_image_2] image = pipe( prompt=prompt, diff --git a/scripts/convert_omnigen_to_diffusers.py b/scripts/convert_omnigen_to_diffusers.py index 6893644f5ce2..cfa46c1afb0e 100644 --- a/scripts/convert_omnigen_to_diffusers.py +++ b/scripts/convert_omnigen_to_diffusers.py @@ -14,14 +14,15 @@ def main(args): if not os.path.exists(args.origin_ckpt_path): print("Model not found, downloading...") - cache_folder = os.getenv('HF_HUB_CACHE') - args.origin_ckpt_path = snapshot_download(repo_id=args.origin_ckpt_path, - cache_dir=cache_folder, - ignore_patterns=['flax_model.msgpack', 'rust_model.ot', 'tf_model.h5', - 'model.pt']) + cache_folder = os.getenv("HF_HUB_CACHE") + args.origin_ckpt_path = snapshot_download( + repo_id=args.origin_ckpt_path, + cache_dir=cache_folder, + ignore_patterns=["flax_model.msgpack", "rust_model.ot", "tf_model.h5", "model.pt"], + ) print(f"Downloaded model to {args.origin_ckpt_path}") - ckpt = os.path.join(args.origin_ckpt_path, 'model.safetensors') + ckpt = os.path.join(args.origin_ckpt_path, "model.safetensors") ckpt = load_file(ckpt, device="cpu") mapping_dict = { @@ -34,7 +35,6 @@ def main(args): "final_layer.adaLN_modulation.1.bias": "norm_out.linear.bias", "final_layer.linear.weight": "proj_out.weight", "final_layer.linear.bias": "proj_out.bias", - } converted_state_dict = {} @@ -50,9 +50,7 @@ def main(args): transformer_config = { "_name_or_path": "Phi-3-vision-128k-instruct", - "architectures": [ - "Phi3ForCausalLM" - ], + "architectures": ["Phi3ForCausalLM"], "attention_dropout": 0.0, "bos_token_id": 1, "eos_token_id": 2, @@ -116,7 +114,7 @@ def main(args): 64.760009765625, 64.80001068115234, 64.81001281738281, - 64.81001281738281 + 64.81001281738281, ], "short_factor": [ 1.05, @@ -166,9 +164,9 @@ def main(args): 2.9499999999999975, 3.049999999999997, 3.049999999999997, - 3.049999999999997 + 3.049999999999997, ], - "type": "su" + "type": "su", }, "rope_theta": 10000.0, "sliding_window": 131072, @@ -177,7 +175,7 @@ def main(args): "transformers_version": "4.38.1", "use_cache": True, "vocab_size": 32064, - "_attn_implementation": "sdpa" + "_attn_implementation": "sdpa", } transformer = OmniGenTransformer2DModel( transformer_config=transformer_config, @@ -197,9 +195,7 @@ def main(args): tokenizer = AutoTokenizer.from_pretrained(args.origin_ckpt_path) - pipeline = OmniGenPipeline( - tokenizer=tokenizer, transformer=transformer, vae=vae, scheduler=scheduler - ) + pipeline = OmniGenPipeline(tokenizer=tokenizer, transformer=transformer, vae=vae, scheduler=scheduler) pipeline.save_pretrained(args.dump_path) @@ -207,12 +203,16 @@ def main(args): parser = argparse.ArgumentParser() parser.add_argument( - "--origin_ckpt_path", default="Shitao/OmniGen-v1", type=str, required=False, - help="Path to the checkpoint to convert." 
+ "--origin_ckpt_path", + default="Shitao/OmniGen-v1", + type=str, + required=False, + help="Path to the checkpoint to convert.", ) - parser.add_argument("--dump_path", default="OmniGen-v1-diffusers", type=str, required=False, - help="Path to the output pipeline.") + parser.add_argument( + "--dump_path", default="OmniGen-v1-diffusers", type=str, required=False, help="Path to the output pipeline." + ) args = parser.parse_args() main(args) diff --git a/src/diffusers/models/__init__.py b/src/diffusers/models/__init__.py index d1c3070df5b4..6311f702ceed 100644 --- a/src/diffusers/models/__init__.py +++ b/src/diffusers/models/__init__.py @@ -65,9 +65,9 @@ _import_structure["transformers.transformer_cogview3plus"] = ["CogView3PlusTransformer2DModel"] _import_structure["transformers.transformer_flux"] = ["FluxTransformer2DModel"] _import_structure["transformers.transformer_mochi"] = ["MochiTransformer3DModel"] + _import_structure["transformers.transformer_omnigen"] = ["OmniGenTransformer2DModel"] _import_structure["transformers.transformer_sd3"] = ["SD3Transformer2DModel"] _import_structure["transformers.transformer_temporal"] = ["TransformerTemporalModel"] - _import_structure["transformers.transformer_omnigen"] = ["OmniGenTransformer2DModel"] _import_structure["unets.unet_1d"] = ["UNet1DModel"] _import_structure["unets.unet_2d"] = ["UNet2DModel"] _import_structure["unets.unet_2d_condition"] = ["UNet2DConditionModel"] diff --git a/src/diffusers/models/embeddings.py b/src/diffusers/models/embeddings.py index 5d97fbd61969..c5c65e509274 100644 --- a/src/diffusers/models/embeddings.py +++ b/src/diffusers/models/embeddings.py @@ -349,19 +349,18 @@ def forward(self, latent): return (latent + pos_embed).to(latent.dtype) - class OmniGenPatchEmbed(nn.Module): """2D Image to Patch Embedding with support for OmniGen.""" def __init__( self, - patch_size: int =2, - in_channels: int =4, - embed_dim: int =768, - bias: bool =True, - interpolation_scale: float =1, - pos_embed_max_size: int =192, - base_size: int =64, + patch_size: int = 2, + in_channels: int = 4, + embed_dim: int = 768, + bias: bool = True, + interpolation_scale: float = 1, + pos_embed_max_size: int = 192, + base_size: int = 64, ): super().__init__() @@ -412,16 +411,14 @@ def patch_embeddings(self, latent, is_input_image: bool): latent = latent.flatten(2).transpose(1, 2) return latent - def forward(self, - latent: torch.Tensor, - is_input_image: bool, - padding_latent: torch.Tensor = None - ): + def forward(self, latent: torch.Tensor, is_input_image: bool, padding_latent: torch.Tensor = None): """ Args: latent: encoded image latents is_input_image: use input_image_proj or output_image_proj - padding_latent: When sizes of target images are inconsistent, use `padding_latent` to maintain consistent sequence length. + padding_latent: + When sizes of target images are inconsistent, use `padding_latent` to maintain consistent sequence + length. Returns: torch.Tensor @@ -1155,18 +1152,16 @@ def __init__(self, hidden_size, frequency_embedding_size=256): @staticmethod def timestep_embedding(t, dim, max_period=10000): """ - Create sinusoidal timestep embeddings. - :param t: a 1-D Tensor of N indices, one per batch element. + Create sinusoidal timestep embeddings. :param t: a 1-D Tensor of N indices, one per batch element. These may be fractional. - :param dim: the dimension of the output. - :param max_period: controls the minimum frequency of the embeddings. + :param dim: the dimension of the output. 
:param max_period: controls the minimum frequency of the embeddings. :return: an (N, D) Tensor of positional embeddings. """ # https://github.com/openai/glide-text2im/blob/main/glide_text2im/nn.py half = dim // 2 - freqs = torch.exp( - -math.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32) / half - ).to(device=t.device) + freqs = torch.exp(-math.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32) / half).to( + device=t.device + ) args = t[:, None].float() * freqs[None] embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1) if dim % 2: @@ -1179,7 +1174,6 @@ def forward(self, t, dtype=torch.float32): return t_emb - class GaussianFourierProjection(nn.Module): """Gaussian Fourier embeddings for noise levels.""" diff --git a/src/diffusers/models/transformers/transformer_omnigen.py b/src/diffusers/models/transformers/transformer_omnigen.py index cea926b0531e..209d6b27aaa6 100644 --- a/src/diffusers/models/transformers/transformer_omnigen.py +++ b/src/diffusers/models/transformers/transformer_omnigen.py @@ -53,8 +53,8 @@ class OmniGen2DModelOutput(Transformer2DModelOutput): class OmniGenBaseTransformer(Phi3Model): """ - Transformer used in OmniGen. The transformer block is from Ph3, and only modify the attention mask. - References: [OmniGen](https://arxiv.org/pdf/2409.11340) + Transformer used in OmniGen. The transformer block is from Ph3, and only modify the attention mask. References: + [OmniGen](https://arxiv.org/pdf/2409.11340) Parameters: config: Phi3Config @@ -89,18 +89,18 @@ def get_offload_layer(self, layer_idx: int, device: torch.device): self.prefetch_layer((layer_idx + 1) % len(self.layers), device) def forward( - self, - input_ids: torch.LongTensor = None, - attention_mask: Optional[torch.Tensor] = None, - position_ids: Optional[torch.LongTensor] = None, - past_key_values: Optional[List[torch.FloatTensor]] = None, - inputs_embeds: Optional[torch.FloatTensor] = None, - use_cache: Optional[bool] = None, - output_attentions: Optional[bool] = None, - output_hidden_states: Optional[bool] = None, - return_dict: Optional[bool] = None, - cache_position: Optional[torch.LongTensor] = None, - offload_transformer_block: Optional[bool] = False, + self, + input_ids: torch.LongTensor = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[List[torch.FloatTensor]] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + cache_position: Optional[torch.LongTensor] = None, + offload_transformer_block: Optional[bool] = False, ) -> Union[Tuple, BaseModelOutputWithPast]: output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions output_hidden_states = ( @@ -225,15 +225,16 @@ class OmniGenTransformer2DModel(ModelMixin, ConfigMixin, PeftAdapterMixin): in_channels (`int`, *optional*, defaults to 4): The number of channels in the input. pos_embed_max_size (`int`, *optional*, defaults to 192): The max size of pos emb. 
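The `embeddings.py` hunk above only reflows `OmniGenTimestepEmbed.timestep_embedding`; the math is unchanged. For reference, a self-contained sketch of the same sinusoidal formula in plain PyTorch (no diffusers imports assumed):

```python
import math

import torch


def sinusoidal_timestep_embedding(t: torch.Tensor, dim: int, max_period: int = 10000) -> torch.Tensor:
    # Half of the channels carry cosine terms, the other half sine terms.
    half = dim // 2
    freqs = torch.exp(-math.log(max_period) * torch.arange(half, dtype=torch.float32) / half).to(t.device)
    args = t[:, None].float() * freqs[None]
    emb = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
    if dim % 2:  # pad one zero channel when dim is odd
        emb = torch.cat([emb, torch.zeros_like(emb[:, :1])], dim=-1)
    return emb


print(sinusoidal_timestep_embedding(torch.tensor([0.0, 250.0, 999.0]), dim=8).shape)  # torch.Size([3, 8])
```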
""" + _supports_gradient_checkpointing = True @register_to_config def __init__( - self, - transformer_config: Dict, - patch_size=2, - in_channels=4, - pos_embed_max_size: int = 192, + self, + transformer_config: Dict, + patch_size=2, + in_channels=4, + pos_embed_max_size: int = 192, ): super().__init__() self.in_channels = in_channels @@ -244,10 +245,12 @@ def __init__( transformer_config = Phi3Config(**transformer_config) hidden_size = transformer_config.hidden_size - self.patch_embedding = OmniGenPatchEmbed(patch_size=patch_size, - in_channels=in_channels, - embed_dim=hidden_size, - pos_embed_max_size=pos_embed_max_size) + self.patch_embedding = OmniGenPatchEmbed( + patch_size=patch_size, + in_channels=in_channels, + embed_dim=hidden_size, + pos_embed_max_size=pos_embed_max_size, + ) self.time_token = OmniGenTimestepEmbed(hidden_size) self.t_embedder = OmniGenTimestepEmbed(hidden_size) @@ -260,14 +263,14 @@ def __init__( def unpatchify(self, x, h, w): """ - x: (N, T, patch_size**2 * C) - imgs: (N, H, W, C) + x: (N, T, patch_size**2 * C) imgs: (N, H, W, C) """ c = self.out_channels x = x.reshape( - shape=(x.shape[0], h // self.patch_size, w // self.patch_size, self.patch_size, self.patch_size, c)) - x = torch.einsum('nhwpqc->nchpwq', x) + shape=(x.shape[0], h // self.patch_size, w // self.patch_size, self.patch_size, self.patch_size, c) + ) + x = torch.einsum("nhwpqc->nchpwq", x) imgs = x.reshape(shape=(x.shape[0], c, h, w)) return imgs @@ -335,13 +338,15 @@ def _set_gradient_checkpointing(self, module, value=False): if hasattr(module, "gradient_checkpointing"): module.gradient_checkpointing = value - def get_multimodal_embeddings(self, - input_ids: torch.Tensor, - input_img_latents: List[torch.Tensor], - input_image_sizes: Dict, - ): + def get_multimodal_embeddings( + self, + input_ids: torch.Tensor, + input_img_latents: List[torch.Tensor], + input_image_sizes: Dict, + ): """ get the multi-modal conditional embeddings + Args: input_ids: a sequence of text id input_img_latents: continues embedding of input images @@ -356,31 +361,32 @@ def get_multimodal_embeddings(self, condition_tokens = self.llm.embed_tokens(input_ids) input_img_inx = 0 if input_img_latents is not None: - input_image_tokens = self.patch_embedding(input_img_latents, - is_input_image=True) + input_image_tokens = self.patch_embedding(input_img_latents, is_input_image=True) for b_inx in input_image_sizes.keys(): for start_inx, end_inx in input_image_sizes[b_inx]: # replace the placeholder in text tokens with the image embedding. 
- condition_tokens[b_inx, start_inx: end_inx] = input_image_tokens[input_img_inx].to( - condition_tokens.dtype) + condition_tokens[b_inx, start_inx:end_inx] = input_image_tokens[input_img_inx].to( + condition_tokens.dtype + ) input_img_inx += 1 return condition_tokens - def forward(self, - hidden_states: torch.Tensor, - timestep: Union[int, float, torch.LongTensor], - input_ids: torch.Tensor, - input_img_latents: List[torch.Tensor], - input_image_sizes: Dict[int, List[int]], - attention_mask: torch.Tensor, - position_ids: torch.Tensor, - past_key_values: DynamicCache = None, - offload_transformer_block: bool = False, - attention_kwargs: Optional[Dict[str, Any]] = None, - return_dict: bool = True, - ): + def forward( + self, + hidden_states: torch.Tensor, + timestep: Union[int, float, torch.LongTensor], + input_ids: torch.Tensor, + input_img_latents: List[torch.Tensor], + input_image_sizes: Dict[int, List[int]], + attention_mask: torch.Tensor, + position_ids: torch.Tensor, + past_key_values: DynamicCache = None, + offload_transformer_block: bool = False, + attention_kwargs: Optional[Dict[str, Any]] = None, + return_dict: bool = True, + ): """ The [`OmniGenTransformer2DModel`] forward method. @@ -408,12 +414,11 @@ def forward(self, `self.processor` in [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py). return_dict (`bool`, *optional*, defaults to `True`): - Whether or not to return a [`OmniGen2DModelOutput`] instead of a plain - tuple. + Whether or not to return a [`OmniGen2DModelOutput`] instead of a plain tuple. Returns: - If `return_dict` is True, an [`OmniGen2DModelOutput`] is returned, otherwise a - `tuple` where the first element is the sample tensor. + If `return_dict` is True, an [`OmniGen2DModelOutput`] is returned, otherwise a `tuple` where the first + element is the sample tensor. 
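The transformer changes above spell out how the forward pass builds its input: multimodal condition tokens (text embeddings with the `<|image_i|>` positions overwritten by patched input-image latents), a single timestep token, and finally the patched noisy output latents, with only the trailing image positions read back out after the Phi-3 backbone. A toy illustration with random tensors (all sizes are made up and the LLM call is stubbed out):

```python
import torch

# Toy dimensions standing in for the real model.
batch, hidden_size, num_cond, num_img_tokens = 1, 16, 5, 4

condition_tokens = torch.randn(batch, num_cond, hidden_size)     # text + input-image embeddings
time_token = torch.randn(batch, 1, hidden_size)                  # embedded timestep
latent_tokens = torch.randn(batch, num_img_tokens, hidden_size)  # patched noisy latents

# One long sequence: [conditions, timestep token, output-image tokens].
input_emb = torch.cat([condition_tokens, time_token, latent_tokens], dim=1)

# Stub for `llm(inputs_embeds=input_emb).last_hidden_state`.
output = input_emb

# Only the trailing output-image positions feed the final norm/projection.
image_hidden = output[:, -num_img_tokens:]
print(image_hidden.shape)  # torch.Size([1, 4, 16])
```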
""" @@ -437,19 +442,22 @@ def forward(self, time_token = self.time_token(timestep, dtype=hidden_states.dtype).unsqueeze(1) - condition_tokens = self.get_multimodal_embeddings(input_ids=input_ids, - input_img_latents=input_img_latents, - input_image_sizes=input_image_sizes, - ) + condition_tokens = self.get_multimodal_embeddings( + input_ids=input_ids, + input_img_latents=input_img_latents, + input_image_sizes=input_image_sizes, + ) if condition_tokens is not None: input_emb = torch.cat([condition_tokens, time_token, hidden_states], dim=1) else: input_emb = torch.cat([time_token, hidden_states], dim=1) - output = self.llm(inputs_embeds=input_emb, - attention_mask=attention_mask, - position_ids=position_ids, - past_key_values=past_key_values, - offload_transformer_block=offload_transformer_block) + output = self.llm( + inputs_embeds=input_emb, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_values=past_key_values, + offload_transformer_block=offload_transformer_block, + ) output, past_key_values = output.last_hidden_state, output.past_key_values image_embedding = output[:, -num_tokens_for_output_image:] diff --git a/src/diffusers/pipelines/omnigen/kvcache_omnigen.py b/src/diffusers/pipelines/omnigen/kvcache_omnigen.py index 34d98d939df2..ef2ca19e4455 100644 --- a/src/diffusers/pipelines/omnigen/kvcache_omnigen.py +++ b/src/diffusers/pipelines/omnigen/kvcache_omnigen.py @@ -19,14 +19,13 @@ class OmniGenCache(DynamicCache): - def __init__(self, - num_tokens_for_img: int, - offload_kv_cache: bool = False) -> None: + def __init__(self, num_tokens_for_img: int, offload_kv_cache: bool = False) -> None: if not torch.cuda.is_available(): # print("No avaliable GPU, offload_kv_cache wiil be set to False, which will result in large memory usage and time cost when input multiple images!!!") # offload_kv_cache = False raise RuntimeError( - "OffloadedCache can only be used with a GPU. If there is no GPU, you need to set use_kv_cache=False, which will result in longer inference time!") + "OffloadedCache can only be used with a GPU. If there is no GPU, you need to set use_kv_cache=False, which will result in longer inference time!" + ) super().__init__() self.original_device = [] self.prefetch_stream = torch.cuda.Stream() @@ -76,14 +75,15 @@ def __getitem__(self, layer_idx: int) -> List[Tuple[torch.Tensor]]: raise KeyError(f"Cache only has {len(self)} layers, attempted to access layer with index {layer_idx}") def update( - self, - key_states: torch.Tensor, - value_states: torch.Tensor, - layer_idx: int, - cache_kwargs: Optional[Dict[str, Any]] = None, + self, + key_states: torch.Tensor, + value_states: torch.Tensor, + layer_idx: int, + cache_kwargs: Optional[Dict[str, Any]] = None, ) -> Tuple[torch.Tensor, torch.Tensor]: """ Updates the cache with the new `key_states` and `value_states` for the layer `layer_idx`. + Parameters: key_states (`torch.Tensor`): The new key states to cache. @@ -101,8 +101,8 @@ def update( raise ValueError("OffloadedCache does not support model usage where layers are skipped. 
Use DynamicCache.") elif len(self.key_cache) == layer_idx: # only cache the states for condition tokens - key_states = key_states[..., :-(self.num_tokens_for_img + 1), :] - value_states = value_states[..., :-(self.num_tokens_for_img + 1), :] + key_states = key_states[..., : -(self.num_tokens_for_img + 1), :] + value_states = value_states[..., : -(self.num_tokens_for_img + 1), :] # Update the number of seen tokens if layer_idx == 0: diff --git a/src/diffusers/pipelines/omnigen/pipeline_omnigen.py b/src/diffusers/pipelines/omnigen/pipeline_omnigen.py index 0bb61e905bbf..f7f8827b2b7c 100644 --- a/src/diffusers/pipelines/omnigen/pipeline_omnigen.py +++ b/src/diffusers/pipelines/omnigen/pipeline_omnigen.py @@ -37,7 +37,6 @@ if is_torch_xla_available(): - XLA_AVAILABLE = True else: XLA_AVAILABLE = False @@ -63,12 +62,12 @@ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps def retrieve_timesteps( - scheduler, - num_inference_steps: Optional[int] = None, - device: Optional[Union[str, torch.device]] = None, - timesteps: Optional[List[int]] = None, - sigmas: Optional[List[float]] = None, - **kwargs, + scheduler, + num_inference_steps: Optional[int] = None, + device: Optional[Union[str, torch.device]] = None, + timesteps: Optional[List[int]] = None, + sigmas: Optional[List[float]] = None, + **kwargs, ): r""" Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles @@ -148,11 +147,11 @@ class OmniGenPipeline( _callback_tensor_inputs = ["latents", "input_images_latents"] def __init__( - self, - transformer: OmniGenTransformer2DModel, - scheduler: FlowMatchEulerDiscreteScheduler, - vae: AutoencoderKL, - tokenizer: LlamaTokenizer, + self, + transformer: OmniGenTransformer2DModel, + scheduler: FlowMatchEulerDiscreteScheduler, + vae: AutoencoderKL, + tokenizer: LlamaTokenizer, ): super().__init__() @@ -176,13 +175,14 @@ def __init__( self.default_sample_size = 128 def encod_input_iamges( - self, - input_pixel_values: List[torch.Tensor], - device: Optional[torch.device] = None, - dtype: Optional[torch.dtype] = None, + self, + input_pixel_values: List[torch.Tensor], + device: Optional[torch.device] = None, + dtype: Optional[torch.dtype] = None, ): """ get the continues embedding of input images by VAE + Args: input_pixel_values: normlized pixel of input images device: @@ -198,17 +198,16 @@ def encod_input_iamges( return input_img_latents def check_inputs( - self, - prompt, - input_images, - height, - width, - use_kv_cache, - offload_kv_cache, - callback_on_step_end_tensor_inputs=None, - max_sequence_length=None, + self, + prompt, + input_images, + height, + width, + use_kv_cache, + offload_kv_cache, + callback_on_step_end_tensor_inputs=None, + max_sequence_length=None, ): - if input_images is not None: if len(input_images) != len(prompt): raise ValueError( @@ -233,7 +232,7 @@ def check_inputs( ) if callback_on_step_end_tensor_inputs is not None and not all( - k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs + k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs ): raise ValueError( f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}" @@ -283,15 +282,15 @@ def disable_vae_tiling(self): self.vae.disable_tiling() def prepare_latents( - self, - batch_size, - num_channels_latents, - height, - width, - dtype, - device, - generator, 
- latents=None, + self, + batch_size, + num_channels_latents, + height, + width, + dtype, + device, + generator, + latents=None, ): if latents is not None: return latents.to(device=device, dtype=dtype) @@ -328,7 +327,7 @@ def interrupt(self): def enable_transformer_block_cpu_offload(self, device: Union[torch.device, str] = "cuda"): torch_device = torch.device(device) for name, param in self.transformer.named_parameters(): - if 'layers' in name and 'layers.0' not in name: + if "layers" in name and "layers.0" not in name: param.data = param.data.cpu() else: param.data = param.data.to(torch_device) @@ -340,38 +339,39 @@ def enable_transformer_block_cpu_offload(self, device: Union[torch.device, str] @torch.no_grad() @replace_example_docstring(EXAMPLE_DOC_STRING) def __call__( - self, - prompt: Union[str, List[str]], - input_images: Optional[ - Union[List[str], List[PIL.Image.Image], List[List[str]], List[List[PIL.Image.Image]]]] = None, - height: Optional[int] = None, - width: Optional[int] = None, - num_inference_steps: int = 50, - max_input_image_size: int = 1024, - timesteps: List[int] = None, - guidance_scale: float = 2.5, - img_guidance_scale: float = 1.6, - use_kv_cache: bool = True, - offload_kv_cache: bool = True, - offload_transformer_block: bool = False, - use_input_image_size_as_output: bool = False, - num_images_per_prompt: Optional[int] = 1, - generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, - latents: Optional[torch.FloatTensor] = None, - output_type: Optional[str] = "pil", - return_dict: bool = True, - attention_kwargs: Optional[Dict[str, Any]] = None, - callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None, - callback_on_step_end_tensor_inputs: List[str] = ["latents"], - max_sequence_length: int = 120000, + self, + prompt: Union[str, List[str]], + input_images: Optional[ + Union[List[str], List[PIL.Image.Image], List[List[str]], List[List[PIL.Image.Image]]] + ] = None, + height: Optional[int] = None, + width: Optional[int] = None, + num_inference_steps: int = 50, + max_input_image_size: int = 1024, + timesteps: List[int] = None, + guidance_scale: float = 2.5, + img_guidance_scale: float = 1.6, + use_kv_cache: bool = True, + offload_kv_cache: bool = True, + offload_transformer_block: bool = False, + use_input_image_size_as_output: bool = False, + num_images_per_prompt: Optional[int] = 1, + generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, + latents: Optional[torch.FloatTensor] = None, + output_type: Optional[str] = "pil", + return_dict: bool = True, + attention_kwargs: Optional[Dict[str, Any]] = None, + callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None, + callback_on_step_end_tensor_inputs: List[str] = ["latents"], + max_sequence_length: int = 120000, ): r""" Function invoked when calling the pipeline for generation. Args: prompt (`str` or `List[str]`, *optional*): - The prompt or prompts to guide the image generation. - If the input includes images, need to add placeholders `<|image_i|>` in the prompt to indicate the position of the i-th images. + The prompt or prompts to guide the image generation. If the input includes images, need to add + placeholders `<|image_i|>` in the prompt to indicate the position of the i-th images. input_images (`List[str]` or `List[List[str]]`, *optional*): The list of input images. We will replace the "<|image_i|>" in prompt with the i-th image in list. 
height (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor): @@ -402,7 +402,8 @@ def __call__( offload_transformer_block (`bool`, *optional*, defaults to False): offload the transformer layers to cpu, which can save memory but slow down the generation use_input_image_size_as_output (bool, defaults to False): - whether to use the input image size as the output image size, which can be used for single-image input, e.g., image editing task + whether to use the input image size as the output image size, which can be used for single-image input, + e.g., image editing task num_images_per_prompt (`int`, *optional*, defaults to 1): The number of images to generate per prompt. generator (`torch.Generator` or `List[torch.Generator]`, *optional*): @@ -476,18 +477,20 @@ def __call__( # 3. process multi-modal instructions if max_input_image_size != self.multimodal_processor.max_image_size: self.multimodal_processor = OmniGenMultiModalProcessor(self.tokenizer, max_image_size=max_input_image_size) - processed_data = self.multimodal_processor(prompt, - input_images, - height=height, - width=width, - use_img_cfg=use_img_cfg, - use_input_image_size_as_output=use_input_image_size_as_output) - processed_data['input_ids'] = processed_data['input_ids'].to(device) - processed_data['attention_mask'] = processed_data['attention_mask'].to(device) - processed_data['position_ids'] = processed_data['position_ids'].to(device) + processed_data = self.multimodal_processor( + prompt, + input_images, + height=height, + width=width, + use_img_cfg=use_img_cfg, + use_input_image_size_as_output=use_input_image_size_as_output, + ) + processed_data["input_ids"] = processed_data["input_ids"].to(device) + processed_data["attention_mask"] = processed_data["attention_mask"].to(device) + processed_data["position_ids"] = processed_data["position_ids"].to(device) # 4. Encode input images - input_img_latents = self.encod_input_iamges(processed_data['input_pixel_values'], device=device) + input_img_latents = self.encod_input_iamges(processed_data["input_pixel_values"], device=device) # 5. Prepare timesteps sigmas = np.linspace(1, 0, num_inference_steps + 1)[:num_inference_steps] @@ -497,7 +500,7 @@ def __call__( # 6. Prepare latents. if use_input_image_size_as_output: - height, width = processed_data['input_pixel_values'][0].shape[-2:] + height, width = processed_data["input_pixel_values"][0].shape[-2:] latent_channels = self.transformer.config.in_channels latents = self.prepare_latents( batch_size * num_images_per_prompt, @@ -511,7 +514,7 @@ def __call__( ) # 7. 
Prepare OmniGenCache - num_tokens_for_output_img = latents.size(-1) * latents.size(-2) // (self.transformer.patch_size ** 2) + num_tokens_for_output_img = latents.size(-1) * latents.size(-2) // (self.transformer.patch_size**2) cache = OmniGenCache(num_tokens_for_output_img, offload_kv_cache) if use_kv_cache else None self.transformer.llm.config.use_cache = use_kv_cache @@ -527,27 +530,29 @@ def __call__( noise_pred, cache = self.transformer( hidden_states=latent_model_input, timestep=timestep, - input_ids=processed_data['input_ids'], + input_ids=processed_data["input_ids"], input_img_latents=input_img_latents, - input_image_sizes=processed_data['input_image_sizes'], - attention_mask=processed_data['attention_mask'], - position_ids=processed_data['position_ids'], + input_image_sizes=processed_data["input_image_sizes"], + attention_mask=processed_data["attention_mask"], + position_ids=processed_data["position_ids"], attention_kwargs=attention_kwargs, past_key_values=cache, - offload_transformer_block=self.offload_transformer_block if hasattr(self, - 'offload_transformer_block') else offload_transformer_block, + offload_transformer_block=self.offload_transformer_block + if hasattr(self, "offload_transformer_block") + else offload_transformer_block, return_dict=False, ) # if use kv cache, don't need attention mask and position ids of condition tokens for next step if use_kv_cache: - if processed_data['input_ids'] is not None: - processed_data['input_ids'] = None - processed_data['attention_mask'] = processed_data['attention_mask'][..., - -(num_tokens_for_output_img + 1):, - :] # +1 is for the timestep token - processed_data['position_ids'] = processed_data['position_ids'][:, - -(num_tokens_for_output_img + 1):] + if processed_data["input_ids"] is not None: + processed_data["input_ids"] = None + processed_data["attention_mask"] = processed_data["attention_mask"][ + ..., -(num_tokens_for_output_img + 1) :, : + ] # +1 is for the timestep token + processed_data["position_ids"] = processed_data["position_ids"][ + :, -(num_tokens_for_output_img + 1) : + ] if num_cfg == 2: cond, uncond, img_cond = torch.split(noise_pred, len(noise_pred) // 3, dim=0) diff --git a/src/diffusers/pipelines/omnigen/processor_omnigen.py b/src/diffusers/pipelines/omnigen/processor_omnigen.py index b350933d995c..d13e7a742379 100644 --- a/src/diffusers/pipelines/omnigen/processor_omnigen.py +++ b/src/diffusers/pipelines/omnigen/processor_omnigen.py @@ -23,25 +23,19 @@ def crop_image(pil_image, max_image_size): """ - Crop the image so that its height and width does not exceed `max_image_size`, - while ensuring both the height and width are multiples of 16. + Crop the image so that its height and width does not exceed `max_image_size`, while ensuring both the height and + width are multiples of 16. 
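The pipeline hunk above derives the number of output-image tokens from the latent grid and, once the KV cache holds the condition tokens, keeps only the trailing `num_tokens_for_output_img + 1` positions of the attention mask and position ids (the extra position is the timestep token). A small sketch of that arithmetic with illustrative sizes:

```python
import torch

# Illustrative sizes: 1024x1024 output, 8x VAE downsampling, patch_size=2.
height, width, vae_scale_factor, patch_size = 1024, 1024, 8, 2
latent_h, latent_w = height // vae_scale_factor, width // vae_scale_factor
num_tokens_for_output_img = latent_h * latent_w // patch_size**2  # 4096

# Pretend the first step saw 120 condition tokens ahead of the timestep and image tokens.
seq_len = 120 + 1 + num_tokens_for_output_img
attention_mask = torch.ones(1, seq_len, seq_len)
position_ids = torch.arange(seq_len).unsqueeze(0)

# With the condition keys/values cached, later steps only need the tail of the sequence.
attention_mask = attention_mask[..., -(num_tokens_for_output_img + 1):, :]
position_ids = position_ids[:, -(num_tokens_for_output_img + 1):]
print(attention_mask.shape, position_ids.shape)  # torch.Size([1, 4097, 4217]) torch.Size([1, 4097])
```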
""" while min(*pil_image.size) >= 2 * max_image_size: - pil_image = pil_image.resize( - tuple(x // 2 for x in pil_image.size), resample=Image.BOX - ) + pil_image = pil_image.resize(tuple(x // 2 for x in pil_image.size), resample=Image.BOX) if max(*pil_image.size) > max_image_size: scale = max_image_size / max(*pil_image.size) - pil_image = pil_image.resize( - tuple(round(x * scale) for x in pil_image.size), resample=Image.BICUBIC - ) + pil_image = pil_image.resize(tuple(round(x * scale) for x in pil_image.size), resample=Image.BICUBIC) if min(*pil_image.size) < 16: scale = 16 / min(*pil_image.size) - pil_image = pil_image.resize( - tuple(round(x * scale) for x in pil_image.size), resample=Image.BICUBIC - ) + pil_image = pil_image.resize(tuple(round(x * scale) for x in pil_image.size), resample=Image.BICUBIC) arr = np.array(pil_image) crop_y1 = (arr.shape[0] % 16) // 2 @@ -50,28 +44,28 @@ def crop_image(pil_image, max_image_size): crop_x1 = (arr.shape[1] % 16) // 2 crop_x2 = arr.shape[1] % 16 - crop_x1 - arr = arr[crop_y1:arr.shape[0] - crop_y2, crop_x1:arr.shape[1] - crop_x2] + arr = arr[crop_y1 : arr.shape[0] - crop_y2, crop_x1 : arr.shape[1] - crop_x2] return Image.fromarray(arr) class OmniGenMultiModalProcessor: - def __init__(self, - text_tokenizer, - max_image_size: int = 1024): + def __init__(self, text_tokenizer, max_image_size: int = 1024): self.text_tokenizer = text_tokenizer self.max_image_size = max_image_size - self.image_transform = transforms.Compose([ - transforms.Lambda(lambda pil_image: crop_image(pil_image, max_image_size)), - transforms.ToTensor(), - transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5], inplace=True) - ]) + self.image_transform = transforms.Compose( + [ + transforms.Lambda(lambda pil_image: crop_image(pil_image, max_image_size)), + transforms.ToTensor(), + transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5], inplace=True), + ] + ) self.collator = OmniGenCollator() def process_image(self, image): if isinstance(image, str): - image = Image.open(image).convert('RGB') + image = Image.open(image).convert("RGB") return self.image_transform(image) def process_multi_modal_prompt(self, text, input_images): @@ -91,11 +85,13 @@ def process_multi_modal_prompt(self, text, input_images): image_ids = [int(s.split("|")[1].split("_")[-1]) for s in image_tags] unique_image_ids = sorted(set(image_ids)) - assert unique_image_ids == list(range(1, - len(unique_image_ids) + 1)), f"image_ids must start from 1, and must be continuous int, e.g. [1, 2, 3], cannot be {unique_image_ids}" + assert unique_image_ids == list( + range(1, len(unique_image_ids) + 1) + ), f"image_ids must start from 1, and must be continuous int, e.g. 
[1, 2, 3], cannot be {unique_image_ids}" # total images must be the same as the number of image tags - assert len(unique_image_ids) == len( - input_images), f"total images must be the same as the number of image tags, got {len(unique_image_ids)} image tags and {len(input_images)} images" + assert ( + len(unique_image_ids) == len(input_images) + ), f"total images must be the same as the number of image tags, got {len(unique_image_ids)} image tags and {len(input_images)} images" input_images = [input_images[x - 1] for x in image_ids] @@ -112,24 +108,24 @@ def process_multi_modal_prompt(self, text, input_images): return {"input_ids": all_input_ids, "pixel_values": input_images, "image_sizes": img_inx} def add_prefix_instruction(self, prompt): - user_prompt = '<|user|>\n' - generation_prompt = 'Generate an image according to the following instructions\n' - assistant_prompt = '<|assistant|>\n<|diffusion|>' + user_prompt = "<|user|>\n" + generation_prompt = "Generate an image according to the following instructions\n" + assistant_prompt = "<|assistant|>\n<|diffusion|>" prompt_suffix = "<|end|>\n" prompt = f"{user_prompt}{generation_prompt}{prompt}{prompt_suffix}{assistant_prompt}" return prompt - def __call__(self, - instructions: List[str], - input_images: List[List[str]] = None, - height: int = 1024, - width: int = 1024, - negative_prompt: str = "low quality, jpeg artifacts, ugly, duplicate, morbid, mutilated, extra fingers, mutated hands, poorly drawn hands, poorly drawn face, mutation, deformed, blurry, dehydrated, bad anatomy, bad proportions, extra limbs, cloned face, disfigured, gross proportions, malformed limbs, missing arms, missing legs, extra arms, extra legs, fused fingers, too many fingers.", - use_img_cfg: bool = True, - separate_cfg_input: bool = False, - use_input_image_size_as_output: bool = False, - ) -> Dict: - + def __call__( + self, + instructions: List[str], + input_images: List[List[str]] = None, + height: int = 1024, + width: int = 1024, + negative_prompt: str = "low quality, jpeg artifacts, ugly, duplicate, morbid, mutilated, extra fingers, mutated hands, poorly drawn hands, poorly drawn face, mutation, deformed, blurry, dehydrated, bad anatomy, bad proportions, extra limbs, cloned face, disfigured, gross proportions, malformed limbs, missing arms, missing legs, extra arms, extra legs, fused fingers, too many fingers.", + use_img_cfg: bool = True, + separate_cfg_input: bool = False, + use_input_image_size_as_output: bool = False, + ) -> Dict: if isinstance(instructions, str): instructions = [instructions] input_images = [input_images] @@ -156,8 +152,14 @@ def __call__(self, img_cfg_mllm_input = neg_mllm_input if use_input_image_size_as_output: - input_data.append((mllm_input, neg_mllm_input, img_cfg_mllm_input, - [mllm_input['pixel_values'][0].size(-2), mllm_input['pixel_values'][0].size(-1)])) + input_data.append( + ( + mllm_input, + neg_mllm_input, + img_cfg_mllm_input, + [mllm_input["pixel_values"][0].size(-2), mllm_input["pixel_values"][0].size(-1)], + ) + ) else: input_data.append((mllm_input, neg_mllm_input, img_cfg_mllm_input, [height, width])) @@ -175,14 +177,16 @@ def create_position(self, attention_mask, num_tokens_for_output_images): img_length = max(num_tokens_for_output_images) for mask in attention_mask: temp_l = torch.sum(mask) - temp_position = [0] * (text_length - temp_l) + list(range(temp_l + img_length + 1)) # we add a time embedding into the sequence, so add one more token + temp_position = [0] * (text_length - temp_l) + list( + range(temp_l + 
img_length + 1) + ) # we add a time embedding into the sequence, so add one more token position_ids.append(temp_position) return torch.LongTensor(position_ids) def create_mask(self, attention_mask, num_tokens_for_output_images): """ - OmniGen applies causal attention to each element in the sequence, but applies bidirectional attention within each image sequence - References: [OmniGen](https://arxiv.org/pdf/2409.11340) + OmniGen applies causal attention to each element in the sequence, but applies bidirectional attention within + each image sequence References: [OmniGen](https://arxiv.org/pdf/2409.11340) """ extended_mask = [] padding_images = [] @@ -261,9 +265,9 @@ def process_mllm_input(self, mllm_inputs, target_img_size): pixel_values, image_sizes = [], {} b_inx = 0 for x in mllm_inputs: - if x['pixel_values'] is not None: - pixel_values.extend(x['pixel_values']) - for size in x['image_sizes']: + if x["pixel_values"] is not None: + pixel_values.extend(x["pixel_values"]) + for size in x["image_sizes"]: if b_inx not in image_sizes: image_sizes[b_inx] = [size] else: @@ -271,7 +275,7 @@ def process_mllm_input(self, mllm_inputs, target_img_size): b_inx += 1 pixel_values = [x.unsqueeze(0) for x in pixel_values] - input_ids = [x['input_ids'] for x in mllm_inputs] + input_ids = [x["input_ids"] for x in mllm_inputs] padded_input_ids, attention_mask, image_sizes = self.pad_input_ids(input_ids, image_sizes) position_ids = self.create_position(attention_mask, num_tokens_for_output_images) attention_mask, padding_images = self.create_mask(attention_mask, num_tokens_for_output_images) @@ -292,13 +296,20 @@ def __call__(self, features): mllm_inputs = mllm_inputs + cfg_mllm_inputs target_img_size = target_img_size + target_img_size - all_padded_input_ids, all_position_ids, all_attention_mask, all_padding_images, all_pixel_values, all_image_sizes = self.process_mllm_input( - mllm_inputs, target_img_size) - - data = {"input_ids": all_padded_input_ids, - "attention_mask": all_attention_mask, - "position_ids": all_position_ids, - "input_pixel_values": all_pixel_values, - "input_image_sizes": all_image_sizes, - } + ( + all_padded_input_ids, + all_position_ids, + all_attention_mask, + all_padding_images, + all_pixel_values, + all_image_sizes, + ) = self.process_mllm_input(mllm_inputs, target_img_size) + + data = { + "input_ids": all_padded_input_ids, + "attention_mask": all_attention_mask, + "position_ids": all_position_ids, + "input_pixel_values": all_pixel_values, + "input_image_sizes": all_image_sizes, + } return data diff --git a/src/diffusers/utils/dummy_pt_objects.py b/src/diffusers/utils/dummy_pt_objects.py index d8e2f34443b5..1bd24da62f2f 100644 --- a/src/diffusers/utils/dummy_pt_objects.py +++ b/src/diffusers/utils/dummy_pt_objects.py @@ -481,6 +481,7 @@ def from_config(cls, *args, **kwargs): def from_pretrained(cls, *args, **kwargs): requires_backends(cls, ["torch"]) + class OmniGenTransformer2DModel(metaclass=DummyObject): _backends = ["torch"] diff --git a/tests/pipelines/omnigen/test_pipeline_omnigen.py b/tests/pipelines/omnigen/test_pipeline_omnigen.py index 3d883fa9011d..3edaf9cf3110 100644 --- a/tests/pipelines/omnigen/test_pipeline_omnigen.py +++ b/tests/pipelines/omnigen/test_pipeline_omnigen.py @@ -24,16 +24,18 @@ class OmniGenPipelineFastTests(unittest.TestCase, PipelineTesterMixin): "guidance_scale", ] ) - batch_params = frozenset(["prompt", ]) + batch_params = frozenset( + [ + "prompt", + ] + ) def get_dummy_components(self): torch.manual_seed(0) transformer_config = { "_name_or_path": 
"Phi-3-vision-128k-instruct", - "architectures": [ - "Phi3ForCausalLM" - ], + "architectures": ["Phi3ForCausalLM"], "attention_dropout": 0.0, "bos_token_id": 1, "eos_token_id": 2, @@ -97,7 +99,7 @@ def get_dummy_components(self): 64.760009765625, 64.80001068115234, 64.81001281738281, - 64.81001281738281 + 64.81001281738281, ], "short_factor": [ 1.05, @@ -147,9 +149,9 @@ def get_dummy_components(self): 2.9499999999999975, 3.049999999999997, 3.049999999999997, - 3.049999999999997 + 3.049999999999997, ], - "type": "su" + "type": "su", }, "rope_theta": 10000.0, "sliding_window": 131072, @@ -158,7 +160,7 @@ def get_dummy_components(self): "transformers_version": "4.38.1", "use_cache": True, "vocab_size": 32064, - "_attn_implementation": "sdpa" + "_attn_implementation": "sdpa", } transformer = OmniGenTransformer2DModel( transformer_config=transformer_config, @@ -167,7 +169,6 @@ def get_dummy_components(self): pos_embed_max_size=192, ) - torch.manual_seed(0) vae = AutoencoderKL( sample_size=32, @@ -177,7 +178,7 @@ def get_dummy_components(self): layers_per_block=1, latent_channels=4, norm_num_groups=1, - up_block_types = ["UpDecoderBlock2D", "UpDecoderBlock2D", "UpDecoderBlock2D", "UpDecoderBlock2D"], + up_block_types=["UpDecoderBlock2D", "UpDecoderBlock2D", "UpDecoderBlock2D", "UpDecoderBlock2D"], ) scheduler = FlowMatchEulerDiscreteScheduler() @@ -220,8 +221,6 @@ def test_inference(self): self.assertEqual(generated_image.shape, (16, 16, 3)) - - @slow @require_torch_gpu class OmniGenPipelineSlowTests(unittest.TestCase): @@ -262,16 +261,18 @@ def test_omnigen_inference(self): image_slice = image[0, :10, :10] expected_slice = np.array( - [[0.25806782, 0.28012177, 0.27807158], - [0.25740036, 0.2677201, 0.26857468], - [0.258638, 0.27035138, 0.26633185], - [0.2541029, 0.2636156, 0.26373306], - [0.24975497, 0.2608987, 0.2617477 ], - [0.25102, 0.26192215, 0.262023 ], - [0.24452701, 0.25664824, 0.259144 ], - [0.2419573, 0.2574909, 0.25996095], - [0.23953134, 0.25292695, 0.25652167], - [0.23523712, 0.24710432, 0.25460982]], + [ + [0.25806782, 0.28012177, 0.27807158], + [0.25740036, 0.2677201, 0.26857468], + [0.258638, 0.27035138, 0.26633185], + [0.2541029, 0.2636156, 0.26373306], + [0.24975497, 0.2608987, 0.2617477], + [0.25102, 0.26192215, 0.262023], + [0.24452701, 0.25664824, 0.259144], + [0.2419573, 0.2574909, 0.25996095], + [0.23953134, 0.25292695, 0.25652167], + [0.23523712, 0.24710432, 0.25460982], + ], dtype=np.float32, ) From 5925cb9f7b2f0dea222ad65602c4fe663d2497ba Mon Sep 17 00:00:00 2001 From: Shitao Xiao <2906698981@qq.com> Date: Tue, 10 Dec 2024 14:23:32 +0800 Subject: [PATCH 20/55] Update docs/source/en/api/models/omnigen_transformer.md Co-authored-by: Steven Liu <59462357+stevhliu@users.noreply.github.com> --- docs/source/en/api/models/omnigen_transformer.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docs/source/en/api/models/omnigen_transformer.md b/docs/source/en/api/models/omnigen_transformer.md index d2df6c55e68b..ee700a04bdae 100644 --- a/docs/source/en/api/models/omnigen_transformer.md +++ b/docs/source/en/api/models/omnigen_transformer.md @@ -12,7 +12,7 @@ specific language governing permissions and limitations under the License. # OmniGenTransformer2DModel -A Transformer model accept multi-modal instruction to generate image from [OmniGen](https://github.com/VectorSpaceLab/OmniGen/). +A Transformer model that accepts multimodal instructions to generate images for [OmniGen](https://github.com/VectorSpaceLab/OmniGen/). 
## OmniGenTransformer2DModel From 56aa8211747c026a9eef0aa9400f66b651dc4cc5 Mon Sep 17 00:00:00 2001 From: Shitao Xiao <2906698981@qq.com> Date: Tue, 10 Dec 2024 14:24:40 +0800 Subject: [PATCH 21/55] Update docs/source/en/using-diffusers/omnigen.md Co-authored-by: Steven Liu <59462357+stevhliu@users.noreply.github.com> --- docs/source/en/using-diffusers/omnigen.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docs/source/en/using-diffusers/omnigen.md b/docs/source/en/using-diffusers/omnigen.md index df85fba74395..723e6e3109cc 100644 --- a/docs/source/en/using-diffusers/omnigen.md +++ b/docs/source/en/using-diffusers/omnigen.md @@ -295,7 +295,7 @@ image ## Optimization when inputting multiple images -For text-to-image task, OmniGen requires minimal memory and time costs (9G memory and 31s for a 1024*1024 image on A800 GPU). +For text-to-image task, OmniGen requires minimal memory and time costs (9GB memory and 31s for a 1024*1024 image on A800 GPU). However, when using input images, the computational cost increases. Here are some guidelines to help you reduce computational costs when input multiple images. The experiments are conducted on A800 GPU and input two images to OmniGen. From 1e33ca835647fc959677d37c71ffb2723b7de8f9 Mon Sep 17 00:00:00 2001 From: Shitao Xiao <2906698981@qq.com> Date: Tue, 10 Dec 2024 14:24:49 +0800 Subject: [PATCH 22/55] Update docs/source/en/using-diffusers/omnigen.md Co-authored-by: Steven Liu <59462357+stevhliu@users.noreply.github.com> --- docs/source/en/using-diffusers/omnigen.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docs/source/en/using-diffusers/omnigen.md b/docs/source/en/using-diffusers/omnigen.md index 723e6e3109cc..c238cb55a90a 100644 --- a/docs/source/en/using-diffusers/omnigen.md +++ b/docs/source/en/using-diffusers/omnigen.md @@ -301,7 +301,7 @@ However, when using input images, the computational cost increases. Here are some guidelines to help you reduce computational costs when input multiple images. The experiments are conducted on A800 GPU and input two images to OmniGen. -### inference speed +### Inference speed - `use_kv_cache=True`: `use_kv_cache` will store key and value states of the input conditions to compute attention without redundant computations. From c81a84dc7655ba5b4080e4a208e27a28862110ca Mon Sep 17 00:00:00 2001 From: shitao <2906698981@qq.com> Date: Tue, 10 Dec 2024 15:31:40 +0800 Subject: [PATCH 23/55] update docs --- docs/source/en/_toctree.yml | 2 - docs/source/en/api/pipelines/omnigen.md | 2 - .../en/using-diffusers/multimodal2img.md | 115 ------------------ docs/source/en/using-diffusers/omnigen.md | 27 +++- 4 files changed, 24 insertions(+), 122 deletions(-) delete mode 100644 docs/source/en/using-diffusers/multimodal2img.md diff --git a/docs/source/en/_toctree.yml b/docs/source/en/_toctree.yml index 7689283e509f..3e614a370c13 100644 --- a/docs/source/en/_toctree.yml +++ b/docs/source/en/_toctree.yml @@ -51,8 +51,6 @@ title: Text or image-to-video - local: using-diffusers/depth2img title: Depth-to-image - - local: using-diffusers/multimodal2img - title: Multimodal-to-image title: Generative tasks - sections: - local: using-diffusers/overview_techniques diff --git a/docs/source/en/api/pipelines/omnigen.md b/docs/source/en/api/pipelines/omnigen.md index 52bcf8d2c3b2..f6b45b1b014d 100644 --- a/docs/source/en/api/pipelines/omnigen.md +++ b/docs/source/en/api/pipelines/omnigen.md @@ -52,8 +52,6 @@ This pipeline was contributed by [staoxiao](https://github.com/staoxiao). 
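The speed and memory options discussed above can be combined in a single call. A minimal sketch, assuming the call-level parameter names documented above (`use_kv_cache`, `offload_kv_cache`, `max_input_image_size`) and reusing an input image from the earlier examples:

```py
import torch

from diffusers import OmniGenPipeline
from diffusers.utils import load_image

pipe = OmniGenPipeline.from_pretrained("Shitao/OmniGen-v1-diffusers", torch_dtype=torch.bfloat16)
pipe.enable_model_cpu_offload()  # lower peak memory at a small speed cost

image = pipe(
    prompt="A man and a woman are sitting at a classroom desk. The man is the man with yellow hair in <|image_1|>.",
    input_images=[load_image("https://raw.githubusercontent.com/VectorSpaceLab/OmniGen/main/imgs/docs_img/3.jpg")],
    guidance_scale=2.5,
    img_guidance_scale=1.6,
    use_kv_cache=True,  # cache key/value states of the condition tokens
    offload_kv_cache=True,  # keep that cache on the CPU to save GPU memory
    max_input_image_size=512,  # smaller condition images -> fewer tokens per step
).images[0]
image.save("output.png")
```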
The or ## Inference -Use [`torch.compile`](https://huggingface.co/docs/diffusers/main/en/tutorials/fast_diffusion#torchcompile) to reduce the inference latency. - First, load the pipeline: ```python diff --git a/docs/source/en/using-diffusers/multimodal2img.md b/docs/source/en/using-diffusers/multimodal2img.md deleted file mode 100644 index 0c0bbc01ee9f..000000000000 --- a/docs/source/en/using-diffusers/multimodal2img.md +++ /dev/null @@ -1,115 +0,0 @@ - - -# Multi-modal instruction to image - -[[open-in-colab]] - - - -Multimodal instructions mean you can input arbitrarily interleaved text and image inputs as conditions to guide image generation. -You can input multiple images and use prompts to describe the desired output. -This approach is more flexible than using only text or images. - -## Examples - - -Take `OmniGenPipeline` as an example: the input can be a text-image sequence to create new images, he input can be a text-image sequence, with images inserted into the text prompt via special placeholder `<|image_i|>`. - -```py -import torch -from diffusers import OmniGenPipeline -from diffusers.utils import load_image - -pipe = OmniGenPipeline.from_pretrained( - "Shitao/OmniGen-v1-diffusers", - torch_dtype=torch.bfloat16 -) -pipe.to("cuda") - -prompt="A man and a woman are sitting at a classroom desk. The man is the man with yellow hair in <|image_1|>. The woman is the woman on the left of <|image_2|>" -input_image_1 = load_image("https://raw.githubusercontent.com/VectorSpaceLab/OmniGen/main/imgs/docs_img/3.jpg") -input_image_2 = load_image("https://raw.githubusercontent.com/VectorSpaceLab/OmniGen/main/imgs/docs_img/4.jpg") -input_images=[input_image_1, input_image_2] -image = pipe( - prompt=prompt, - input_images=input_images, - height=1024, - width=1024, - guidance_scale=2.5, - img_guidance_scale=1.6, - generator=torch.Generator(device="cpu").manual_seed(666)).images[0] -image -``` -
-<!-- image grid: input_image_1 | input_image_2 | generated image -->
- - -```py -import torch -from diffusers import OmniGenPipeline -from diffusers.utils import load_image - -pipe = OmniGenPipeline.from_pretrained( - "Shitao/OmniGen-v1-diffusers", - torch_dtype=torch.bfloat16 -) -pipe.to("cuda") - - -prompt="A woman is walking down the street, wearing a white long-sleeve blouse with lace details on the sleeves, paired with a blue pleated skirt. The woman is <|image_1|>. The long-sleeve blouse and a pleated skirt are <|image_2|>." -input_image_1 = load_image("https://raw.githubusercontent.com/VectorSpaceLab/OmniGen/main/imgs/docs_img/emma.jpeg") -input_image_2 = load_image("https://raw.githubusercontent.com/VectorSpaceLab/OmniGen/main/imgs/docs_img/dress.jpg") -input_images=[input_image_1, input_image_2] -image = pipe( - prompt=prompt, - input_images=input_images, - height=1024, - width=1024, - guidance_scale=2.5, - img_guidance_scale=1.6, - generator=torch.Generator(device="cpu").manual_seed(666)).images[0] -``` - -
-<!-- image grid: person image | clothe image | generated image -->
- - -The output image is a [`PIL.Image`](https://pillow.readthedocs.io/en/stable/reference/Image.html?highlight=image#the-image-class) object that can be saved: - -```py -image.save("generated_image.png") -``` \ No newline at end of file diff --git a/docs/source/en/using-diffusers/omnigen.md b/docs/source/en/using-diffusers/omnigen.md index c238cb55a90a..6742fef24b8d 100644 --- a/docs/source/en/using-diffusers/omnigen.md +++ b/docs/source/en/using-diffusers/omnigen.md @@ -303,6 +303,13 @@ Here are some guidelines to help you reduce computational costs when input multi ### Inference speed +| Parameter | Inference Time | +|--------------------------|----------------| +| use_kv_cache=True | 90s | +| use_kv_cache=False | 221s | +| max_input_image_size=1024| 90s | +| max_input_image_size=512 | 58s | + - `use_kv_cache=True`: `use_kv_cache` will store key and value states of the input conditions to compute attention without redundant computations. The default value is True, and OmniGen will offload the kv cache to cpu default. @@ -314,17 +321,31 @@ Here are some guidelines to help you reduce computational costs when input multi - `max_input_image_size=1024`: the inference time is 1m30s. - `max_input_image_size=512`: the inference time is 58s. + + + + + ### Memory + +| Method | Memory Usage | +|---------------------------------------------|--------------| +| pipe.to("cuda") | 31GB | +| pipe.enable_model_cpu_offload() | 28GB | +| pipe.enable_transformer_block_cpu_offload() | 25GB | +| pipe.enable_sequential_cpu_offload() | 11GB | + - `pipe.enable_model_cpu_offload()`: - Without enabling cpu offloading, memory usage is `31 GB` - With enabling cpu offloading, memory usage is `28 GB` -- `offload_transformer_block=True`: - - offload transformer block to reduce memory usage +- `pipe.enable_transformer_block_cpu_offload()`: + - Offload transformer block to reduce memory usage - When enabled, memory usage is under `25 GB` - `pipe.enable_sequential_cpu_offload()`: - - significantly reduce memory usage at the cost of slow inference + - Significantly reduce memory usage at the cost of slow inference - When enabled, memory usage is under `11 GB` + From 3867830c8594deead608e6baaaeb820912816f29 Mon Sep 17 00:00:00 2001 From: Aryan Date: Thu, 12 Dec 2024 00:02:43 +0100 Subject: [PATCH 24/55] revert changes to examples/ --- .../train_cogvideox_image_to_video_lora.py | 3 +- examples/cogvideo/train_cogvideox_lora.py | 3 +- .../community/pipeline_flux_rf_inversion.py | 1061 +++ .../geodiff_molecule_conformation.ipynb | 7230 ++++++++--------- examples/research_projects/gligen/demo.ipynb | 13 +- 5 files changed, 4680 insertions(+), 3630 deletions(-) create mode 100644 examples/community/pipeline_flux_rf_inversion.py diff --git a/examples/cogvideo/train_cogvideox_image_to_video_lora.py b/examples/cogvideo/train_cogvideox_image_to_video_lora.py index 1f055bcecbed..65dcf050fceb 100644 --- a/examples/cogvideo/train_cogvideox_image_to_video_lora.py +++ b/examples/cogvideo/train_cogvideox_image_to_video_lora.py @@ -872,10 +872,9 @@ def prepare_rotary_positional_embeddings( crops_coords=grid_crops_coords, grid_size=(grid_height, grid_width), temporal_size=num_frames, + device=device, ) - freqs_cos = freqs_cos.to(device=device) - freqs_sin = freqs_sin.to(device=device) return freqs_cos, freqs_sin diff --git a/examples/cogvideo/train_cogvideox_lora.py b/examples/cogvideo/train_cogvideox_lora.py index e591e0ee5900..f1b2dff53cb2 100644 --- a/examples/cogvideo/train_cogvideox_lora.py +++ 
b/examples/cogvideo/train_cogvideox_lora.py @@ -894,10 +894,9 @@ def prepare_rotary_positional_embeddings( crops_coords=grid_crops_coords, grid_size=(grid_height, grid_width), temporal_size=num_frames, + device=device, ) - freqs_cos = freqs_cos.to(device=device) - freqs_sin = freqs_sin.to(device=device) return freqs_cos, freqs_sin diff --git a/examples/community/pipeline_flux_rf_inversion.py b/examples/community/pipeline_flux_rf_inversion.py new file mode 100644 index 000000000000..f09160c4571d --- /dev/null +++ b/examples/community/pipeline_flux_rf_inversion.py @@ -0,0 +1,1061 @@ +# Copyright 2024 Black Forest Labs and The HuggingFace Team. All rights reserved. +# modeled after RF Inversion: https://rf-inversion.github.io/, authored by Litu Rout, Yujia Chen, Nataniel Ruiz, +# Constantine Caramanis, Sanjay Shakkottai and Wen-Sheng Chu. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import inspect +from typing import Any, Callable, Dict, List, Optional, Union + +import numpy as np +import torch +from transformers import CLIPTextModel, CLIPTokenizer, T5EncoderModel, T5TokenizerFast + +from diffusers.image_processor import PipelineImageInput, VaeImageProcessor +from diffusers.loaders import FluxLoraLoaderMixin, FromSingleFileMixin, TextualInversionLoaderMixin +from diffusers.models.autoencoders import AutoencoderKL +from diffusers.models.transformers import FluxTransformer2DModel +from diffusers.pipelines.flux.pipeline_output import FluxPipelineOutput +from diffusers.pipelines.pipeline_utils import DiffusionPipeline +from diffusers.schedulers import FlowMatchEulerDiscreteScheduler +from diffusers.utils import ( + USE_PEFT_BACKEND, + is_torch_xla_available, + logging, + replace_example_docstring, + scale_lora_layers, + unscale_lora_layers, +) +from diffusers.utils.torch_utils import randn_tensor + + +if is_torch_xla_available(): + import torch_xla.core.xla_model as xm + + XLA_AVAILABLE = True +else: + XLA_AVAILABLE = False + + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + +EXAMPLE_DOC_STRING = """ + Examples: + ```py + >>> import torch + >>> import requests + >>> import PIL + >>> from io import BytesIO + >>> from diffusers import DiffusionPipeline + + >>> pipe = DiffusionPipeline.from_pretrained( + ... "black-forest-labs/FLUX.1-dev", + ... torch_dtype=torch.bfloat16, + ... custom_pipeline="pipeline_flux_rf_inversion") + >>> pipe.to("cuda") + + >>> def download_image(url): + ... response = requests.get(url) + ... return PIL.Image.open(BytesIO(response.content)).convert("RGB") + + + >>> img_url = "https://www.aiml.informatik.tu-darmstadt.de/people/mbrack/tennis.jpg" + >>> image = download_image(img_url) + + >>> inverted_latents, image_latents, latent_image_ids = pipe.invert(image=image, num_inversion_steps=28, gamma=0.5) + + >>> edited_image = pipe( + ... prompt="a tomato", + ... inverted_latents=inverted_latents, + ... image_latents=image_latents, + ... latent_image_ids=latent_image_ids, + ... start_timestep=0, + ... stop_timestep=.25, + ... 
num_inference_steps=28, + ... eta=0.9, + ... ).images[0] + ``` +""" + + +# Copied from diffusers.pipelines.flux.pipeline_flux.calculate_shift +def calculate_shift( + image_seq_len, + base_seq_len: int = 256, + max_seq_len: int = 4096, + base_shift: float = 0.5, + max_shift: float = 1.16, +): + m = (max_shift - base_shift) / (max_seq_len - base_seq_len) + b = base_shift - m * base_seq_len + mu = image_seq_len * m + b + return mu + + +# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps +def retrieve_timesteps( + scheduler, + num_inference_steps: Optional[int] = None, + device: Optional[Union[str, torch.device]] = None, + timesteps: Optional[List[int]] = None, + sigmas: Optional[List[float]] = None, + **kwargs, +): + r""" + Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles + custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`. + + Args: + scheduler (`SchedulerMixin`): + The scheduler to get timesteps from. + num_inference_steps (`int`): + The number of diffusion steps used when generating samples with a pre-trained model. If used, `timesteps` + must be `None`. + device (`str` or `torch.device`, *optional*): + The device to which the timesteps should be moved to. If `None`, the timesteps are not moved. + timesteps (`List[int]`, *optional*): + Custom timesteps used to override the timestep spacing strategy of the scheduler. If `timesteps` is passed, + `num_inference_steps` and `sigmas` must be `None`. + sigmas (`List[float]`, *optional*): + Custom sigmas used to override the timestep spacing strategy of the scheduler. If `sigmas` is passed, + `num_inference_steps` and `timesteps` must be `None`. + + Returns: + `Tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the + second element is the number of inference steps. + """ + if timesteps is not None and sigmas is not None: + raise ValueError("Only one of `timesteps` or `sigmas` can be passed. Please choose one to set custom values") + if timesteps is not None: + accepts_timesteps = "timesteps" in set(inspect.signature(scheduler.set_timesteps).parameters.keys()) + if not accepts_timesteps: + raise ValueError( + f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom" + f" timestep schedules. Please check whether you are using the correct scheduler." + ) + scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs) + timesteps = scheduler.timesteps + num_inference_steps = len(timesteps) + elif sigmas is not None: + accept_sigmas = "sigmas" in set(inspect.signature(scheduler.set_timesteps).parameters.keys()) + if not accept_sigmas: + raise ValueError( + f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom" + f" sigmas schedules. Please check whether you are using the correct scheduler." + ) + scheduler.set_timesteps(sigmas=sigmas, device=device, **kwargs) + timesteps = scheduler.timesteps + num_inference_steps = len(timesteps) + else: + scheduler.set_timesteps(num_inference_steps, device=device, **kwargs) + timesteps = scheduler.timesteps + return timesteps, num_inference_steps + + +class RFInversionFluxPipeline( + DiffusionPipeline, + FluxLoraLoaderMixin, + FromSingleFileMixin, + TextualInversionLoaderMixin, +): + r""" + The Flux pipeline for text-to-image generation. 
+ + Reference: https://blackforestlabs.ai/announcing-black-forest-labs/ + + Args: + transformer ([`FluxTransformer2DModel`]): + Conditional Transformer (MMDiT) architecture to denoise the encoded image latents. + scheduler ([`FlowMatchEulerDiscreteScheduler`]): + A scheduler to be used in combination with `transformer` to denoise the encoded image latents. + vae ([`AutoencoderKL`]): + Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations. + text_encoder ([`CLIPTextModel`]): + [CLIP](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPTextModel), specifically + the [clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14) variant. + text_encoder_2 ([`T5EncoderModel`]): + [T5](https://huggingface.co/docs/transformers/en/model_doc/t5#transformers.T5EncoderModel), specifically + the [google/t5-v1_1-xxl](https://huggingface.co/google/t5-v1_1-xxl) variant. + tokenizer (`CLIPTokenizer`): + Tokenizer of class + [CLIPTokenizer](https://huggingface.co/docs/transformers/en/model_doc/clip#transformers.CLIPTokenizer). + tokenizer_2 (`T5TokenizerFast`): + Second Tokenizer of class + [T5TokenizerFast](https://huggingface.co/docs/transformers/en/model_doc/t5#transformers.T5TokenizerFast). + """ + + model_cpu_offload_seq = "text_encoder->text_encoder_2->transformer->vae" + _optional_components = [] + _callback_tensor_inputs = ["latents", "prompt_embeds"] + + def __init__( + self, + scheduler: FlowMatchEulerDiscreteScheduler, + vae: AutoencoderKL, + text_encoder: CLIPTextModel, + tokenizer: CLIPTokenizer, + text_encoder_2: T5EncoderModel, + tokenizer_2: T5TokenizerFast, + transformer: FluxTransformer2DModel, + ): + super().__init__() + + self.register_modules( + vae=vae, + text_encoder=text_encoder, + text_encoder_2=text_encoder_2, + tokenizer=tokenizer, + tokenizer_2=tokenizer_2, + transformer=transformer, + scheduler=scheduler, + ) + self.vae_scale_factor = ( + 2 ** (len(self.vae.config.block_out_channels) - 1) if hasattr(self, "vae") and self.vae is not None else 8 + ) + self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor) + self.tokenizer_max_length = ( + self.tokenizer.model_max_length if hasattr(self, "tokenizer") and self.tokenizer is not None else 77 + ) + self.default_sample_size = 128 + + # Copied from diffusers.pipelines.flux.pipeline_flux.FluxPipeline._get_t5_prompt_embeds + def _get_t5_prompt_embeds( + self, + prompt: Union[str, List[str]] = None, + num_images_per_prompt: int = 1, + max_sequence_length: int = 512, + device: Optional[torch.device] = None, + dtype: Optional[torch.dtype] = None, + ): + device = device or self._execution_device + dtype = dtype or self.text_encoder.dtype + + prompt = [prompt] if isinstance(prompt, str) else prompt + batch_size = len(prompt) + + if isinstance(self, TextualInversionLoaderMixin): + prompt = self.maybe_convert_prompt(prompt, self.tokenizer_2) + + text_inputs = self.tokenizer_2( + prompt, + padding="max_length", + max_length=max_sequence_length, + truncation=True, + return_length=False, + return_overflowing_tokens=False, + return_tensors="pt", + ) + text_input_ids = text_inputs.input_ids + untruncated_ids = self.tokenizer_2(prompt, padding="longest", return_tensors="pt").input_ids + + if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(text_input_ids, untruncated_ids): + removed_text = self.tokenizer_2.batch_decode(untruncated_ids[:, self.tokenizer_max_length - 1 : -1]) + logger.warning( + "The following part of 
your input was truncated because `max_sequence_length` is set to " + f" {max_sequence_length} tokens: {removed_text}" + ) + + prompt_embeds = self.text_encoder_2(text_input_ids.to(device), output_hidden_states=False)[0] + + dtype = self.text_encoder_2.dtype + prompt_embeds = prompt_embeds.to(dtype=dtype, device=device) + + _, seq_len, _ = prompt_embeds.shape + + # duplicate text embeddings and attention mask for each generation per prompt, using mps friendly method + prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1) + prompt_embeds = prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1) + + return prompt_embeds + + # Copied from diffusers.pipelines.flux.pipeline_flux.FluxPipeline._get_clip_prompt_embeds + def _get_clip_prompt_embeds( + self, + prompt: Union[str, List[str]], + num_images_per_prompt: int = 1, + device: Optional[torch.device] = None, + ): + device = device or self._execution_device + + prompt = [prompt] if isinstance(prompt, str) else prompt + batch_size = len(prompt) + + if isinstance(self, TextualInversionLoaderMixin): + prompt = self.maybe_convert_prompt(prompt, self.tokenizer) + + text_inputs = self.tokenizer( + prompt, + padding="max_length", + max_length=self.tokenizer_max_length, + truncation=True, + return_overflowing_tokens=False, + return_length=False, + return_tensors="pt", + ) + + text_input_ids = text_inputs.input_ids + untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids + if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(text_input_ids, untruncated_ids): + removed_text = self.tokenizer.batch_decode(untruncated_ids[:, self.tokenizer_max_length - 1 : -1]) + logger.warning( + "The following part of your input was truncated because CLIP can only handle sequences up to" + f" {self.tokenizer_max_length} tokens: {removed_text}" + ) + prompt_embeds = self.text_encoder(text_input_ids.to(device), output_hidden_states=False) + + # Use pooled output of CLIPTextModel + prompt_embeds = prompt_embeds.pooler_output + prompt_embeds = prompt_embeds.to(dtype=self.text_encoder.dtype, device=device) + + # duplicate text embeddings for each generation per prompt, using mps friendly method + prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt) + prompt_embeds = prompt_embeds.view(batch_size * num_images_per_prompt, -1) + + return prompt_embeds + + # Copied from diffusers.pipelines.flux.pipeline_flux.FluxPipeline.encode_prompt + def encode_prompt( + self, + prompt: Union[str, List[str]], + prompt_2: Union[str, List[str]], + device: Optional[torch.device] = None, + num_images_per_prompt: int = 1, + prompt_embeds: Optional[torch.FloatTensor] = None, + pooled_prompt_embeds: Optional[torch.FloatTensor] = None, + max_sequence_length: int = 512, + lora_scale: Optional[float] = None, + ): + r""" + + Args: + prompt (`str` or `List[str]`, *optional*): + prompt to be encoded + prompt_2 (`str` or `List[str]`, *optional*): + The prompt or prompts to be sent to the `tokenizer_2` and `text_encoder_2`. If not defined, `prompt` is + used in all text-encoders + device: (`torch.device`): + torch device + num_images_per_prompt (`int`): + number of images that should be generated per prompt + prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not + provided, text embeddings will be generated from `prompt` input argument. 
+ pooled_prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. + If not provided, pooled text embeddings will be generated from `prompt` input argument. + lora_scale (`float`, *optional*): + A lora scale that will be applied to all LoRA layers of the text encoder if LoRA layers are loaded. + """ + device = device or self._execution_device + + # set lora scale so that monkey patched LoRA + # function of text encoder can correctly access it + if lora_scale is not None and isinstance(self, FluxLoraLoaderMixin): + self._lora_scale = lora_scale + + # dynamically adjust the LoRA scale + if self.text_encoder is not None and USE_PEFT_BACKEND: + scale_lora_layers(self.text_encoder, lora_scale) + if self.text_encoder_2 is not None and USE_PEFT_BACKEND: + scale_lora_layers(self.text_encoder_2, lora_scale) + + prompt = [prompt] if isinstance(prompt, str) else prompt + + if prompt_embeds is None: + prompt_2 = prompt_2 or prompt + prompt_2 = [prompt_2] if isinstance(prompt_2, str) else prompt_2 + + # We only use the pooled prompt output from the CLIPTextModel + pooled_prompt_embeds = self._get_clip_prompt_embeds( + prompt=prompt, + device=device, + num_images_per_prompt=num_images_per_prompt, + ) + prompt_embeds = self._get_t5_prompt_embeds( + prompt=prompt_2, + num_images_per_prompt=num_images_per_prompt, + max_sequence_length=max_sequence_length, + device=device, + ) + + if self.text_encoder is not None: + if isinstance(self, FluxLoraLoaderMixin) and USE_PEFT_BACKEND: + # Retrieve the original scale by scaling back the LoRA layers + unscale_lora_layers(self.text_encoder, lora_scale) + + if self.text_encoder_2 is not None: + if isinstance(self, FluxLoraLoaderMixin) and USE_PEFT_BACKEND: + # Retrieve the original scale by scaling back the LoRA layers + unscale_lora_layers(self.text_encoder_2, lora_scale) + + dtype = self.text_encoder.dtype if self.text_encoder is not None else self.transformer.dtype + text_ids = torch.zeros(prompt_embeds.shape[1], 3).to(device=device, dtype=dtype) + + return prompt_embeds, pooled_prompt_embeds, text_ids + + @torch.no_grad() + # Modified from diffusers.pipelines.ledits_pp.pipeline_leditspp_stable_diffusion.LEditsPPPipelineStableDiffusion.encode_image + def encode_image(self, image, dtype=None, height=None, width=None, resize_mode="default", crops_coords=None): + image = self.image_processor.preprocess( + image=image, height=height, width=width, resize_mode=resize_mode, crops_coords=crops_coords + ) + resized = self.image_processor.postprocess(image=image, output_type="pil") + + if max(image.shape[-2:]) > self.vae.config["sample_size"] * 1.5: + logger.warning( + "Your input images far exceed the default resolution of the underlying diffusion model. " + "The output images may contain severe artifacts! 
" + "Consider down-sampling the input using the `height` and `width` parameters" + ) + image = image.to(dtype) + + x0 = self.vae.encode(image.to(self.device)).latent_dist.sample() + x0 = (x0 - self.vae.config.shift_factor) * self.vae.config.scaling_factor + x0 = x0.to(dtype) + return x0, resized + + def check_inputs( + self, + prompt, + prompt_2, + inverted_latents, + image_latents, + latent_image_ids, + height, + width, + start_timestep, + stop_timestep, + prompt_embeds=None, + pooled_prompt_embeds=None, + callback_on_step_end_tensor_inputs=None, + max_sequence_length=None, + ): + if height % self.vae_scale_factor != 0 or width % self.vae_scale_factor != 0: + raise ValueError( + f"`height` and `width` have to be divisible by {self.vae_scale_factor} but are {height} and {width}." + ) + + if callback_on_step_end_tensor_inputs is not None and not all( + k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs + ): + raise ValueError( + f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}" + ) + + if prompt is not None and prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to" + " only forward one of the two." + ) + elif prompt_2 is not None and prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `prompt_2`: {prompt_2} and `prompt_embeds`: {prompt_embeds}. Please make sure to" + " only forward one of the two." + ) + elif prompt is None and prompt_embeds is None: + raise ValueError( + "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined." + ) + elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)): + raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}") + elif prompt_2 is not None and (not isinstance(prompt_2, str) and not isinstance(prompt_2, list)): + raise ValueError(f"`prompt_2` has to be of type `str` or `list` but is {type(prompt_2)}") + + if prompt_embeds is not None and pooled_prompt_embeds is None: + raise ValueError( + "If `prompt_embeds` are provided, `pooled_prompt_embeds` also have to be passed. Make sure to generate `pooled_prompt_embeds` from the same text encoder that was used to generate `prompt_embeds`." + ) + + if max_sequence_length is not None and max_sequence_length > 512: + raise ValueError(f"`max_sequence_length` cannot be greater than 512 but is {max_sequence_length}") + + if inverted_latents is not None and (image_latents is None or latent_image_ids is None): + raise ValueError( + "If `inverted_latents` are provided, `image_latents` and `latent_image_ids` also have to be passed. 
" + ) + # check start_timestep and stop_timestep + if start_timestep < 0 or start_timestep > stop_timestep: + raise ValueError(f"`start_timestep` should be in [0, stop_timestep] but is {start_timestep}") + + @staticmethod + def _prepare_latent_image_ids(batch_size, height, width, device, dtype): + latent_image_ids = torch.zeros(height, width, 3) + latent_image_ids[..., 1] = latent_image_ids[..., 1] + torch.arange(height)[:, None] + latent_image_ids[..., 2] = latent_image_ids[..., 2] + torch.arange(width)[None, :] + + latent_image_id_height, latent_image_id_width, latent_image_id_channels = latent_image_ids.shape + + latent_image_ids = latent_image_ids.reshape( + latent_image_id_height * latent_image_id_width, latent_image_id_channels + ) + + return latent_image_ids.to(device=device, dtype=dtype) + + @staticmethod + def _pack_latents(latents, batch_size, num_channels_latents, height, width): + latents = latents.view(batch_size, num_channels_latents, height // 2, 2, width // 2, 2) + latents = latents.permute(0, 2, 4, 1, 3, 5) + latents = latents.reshape(batch_size, (height // 2) * (width // 2), num_channels_latents * 4) + + return latents + + @staticmethod + def _unpack_latents(latents, height, width, vae_scale_factor): + batch_size, num_patches, channels = latents.shape + + height = height // vae_scale_factor + width = width // vae_scale_factor + + latents = latents.view(batch_size, height // 2, width // 2, channels // 4, 2, 2) + latents = latents.permute(0, 3, 1, 4, 2, 5) + + latents = latents.reshape(batch_size, channels // (2 * 2), height, width) + + return latents + + def enable_vae_slicing(self): + r""" + Enable sliced VAE decoding. When this option is enabled, the VAE will split the input tensor in slices to + compute decoding in several steps. This is useful to save some memory and allow larger batch sizes. + """ + self.vae.enable_slicing() + + def disable_vae_slicing(self): + r""" + Disable sliced VAE decoding. If `enable_vae_slicing` was previously enabled, this method will go back to + computing decoding in one step. + """ + self.vae.disable_slicing() + + def enable_vae_tiling(self): + r""" + Enable tiled VAE decoding. When this option is enabled, the VAE will split the input tensor into tiles to + compute decoding and encoding in several steps. This is useful for saving a large amount of memory and to allow + processing larger images. + """ + self.vae.enable_tiling() + + def disable_vae_tiling(self): + r""" + Disable tiled VAE decoding. If `enable_vae_tiling` was previously enabled, this method will go back to + computing decoding in one step. + """ + self.vae.disable_tiling() + + def prepare_latents_inversion( + self, + batch_size, + num_channels_latents, + height, + width, + dtype, + device, + image_latents, + ): + height = int(height) // self.vae_scale_factor + width = int(width) // self.vae_scale_factor + + latents = self._pack_latents(image_latents, batch_size, num_channels_latents, height, width) + + latent_image_ids = self._prepare_latent_image_ids(batch_size, height // 2, width // 2, device, dtype) + + return latents, latent_image_ids + + # Copied from diffusers.pipelines.flux.pipeline_flux.FluxPipeline.prepare_latents + def prepare_latents( + self, + batch_size, + num_channels_latents, + height, + width, + dtype, + device, + generator, + latents=None, + ): + # VAE applies 8x compression on images but we must also account for packing which requires + # latent height and width to be divisible by 2. 
+ height = 2 * (int(height) // (self.vae_scale_factor * 2)) + width = 2 * (int(width) // (self.vae_scale_factor * 2)) + + shape = (batch_size, num_channels_latents, height, width) + + if latents is not None: + latent_image_ids = self._prepare_latent_image_ids(batch_size, height // 2, width // 2, device, dtype) + return latents.to(device=device, dtype=dtype), latent_image_ids + + if isinstance(generator, list) and len(generator) != batch_size: + raise ValueError( + f"You have passed a list of generators of length {len(generator)}, but requested an effective batch" + f" size of {batch_size}. Make sure the batch size matches the length of the generators." + ) + + latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype) + latents = self._pack_latents(latents, batch_size, num_channels_latents, height, width) + + latent_image_ids = self._prepare_latent_image_ids(batch_size, height // 2, width // 2, device, dtype) + + return latents, latent_image_ids + + # Copied from diffusers.pipelines.stable_diffusion_3.pipeline_stable_diffusion_3_img2img.StableDiffusion3Img2ImgPipeline.get_timesteps + def get_timesteps(self, num_inference_steps, strength=1.0): + # get the original timestep using init_timestep + init_timestep = min(num_inference_steps * strength, num_inference_steps) + + t_start = int(max(num_inference_steps - init_timestep, 0)) + timesteps = self.scheduler.timesteps[t_start * self.scheduler.order :] + sigmas = self.scheduler.sigmas[t_start * self.scheduler.order :] + if hasattr(self.scheduler, "set_begin_index"): + self.scheduler.set_begin_index(t_start * self.scheduler.order) + + return timesteps, sigmas, num_inference_steps - t_start + + @property + def guidance_scale(self): + return self._guidance_scale + + @property + def joint_attention_kwargs(self): + return self._joint_attention_kwargs + + @property + def num_timesteps(self): + return self._num_timesteps + + @property + def interrupt(self): + return self._interrupt + + @torch.no_grad() + @replace_example_docstring(EXAMPLE_DOC_STRING) + def __call__( + self, + prompt: Union[str, List[str]] = None, + prompt_2: Optional[Union[str, List[str]]] = None, + inverted_latents: Optional[torch.FloatTensor] = None, + image_latents: Optional[torch.FloatTensor] = None, + latent_image_ids: Optional[torch.FloatTensor] = None, + height: Optional[int] = None, + width: Optional[int] = None, + eta: float = 1.0, + strength: float = 1.0, + start_timestep: float = 0, + stop_timestep: float = 0.25, + num_inference_steps: int = 28, + sigmas: Optional[List[float]] = None, + timesteps: List[int] = None, + guidance_scale: float = 3.5, + num_images_per_prompt: Optional[int] = 1, + generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, + latents: Optional[torch.FloatTensor] = None, + prompt_embeds: Optional[torch.FloatTensor] = None, + pooled_prompt_embeds: Optional[torch.FloatTensor] = None, + output_type: Optional[str] = "pil", + return_dict: bool = True, + joint_attention_kwargs: Optional[Dict[str, Any]] = None, + callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None, + callback_on_step_end_tensor_inputs: List[str] = ["latents"], + max_sequence_length: int = 512, + ): + r""" + Function invoked when calling the pipeline for generation. + + Args: + prompt (`str` or `List[str]`, *optional*): + The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`. + instead. 
+ prompt_2 (`str` or `List[str]`, *optional*): + The prompt or prompts to be sent to `tokenizer_2` and `text_encoder_2`. If not defined, `prompt` is + will be used instead + inverted_latents (`torch.Tensor`, *optional*): + The inverted latents from `pipe.invert`. + image_latents (`torch.Tensor`, *optional*): + The image latents from `pipe.invert`. + latent_image_ids (`torch.Tensor`, *optional*): + The latent image ids from `pipe.invert`. + height (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor): + The height in pixels of the generated image. This is set to 1024 by default for the best results. + width (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor): + The width in pixels of the generated image. This is set to 1024 by default for the best results. + eta (`float`, *optional*, defaults to 1.0): + The controller guidance, balancing faithfulness & editability: + higher eta - better faithfullness, less editability. For more significant edits, lower the value of eta. + num_inference_steps (`int`, *optional*, defaults to 50): + The number of denoising steps. More denoising steps usually lead to a higher quality image at the + expense of slower inference. + timesteps (`List[int]`, *optional*): + Custom timesteps to use for the denoising process with schedulers which support a `timesteps` argument + in their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is + passed will be used. Must be in descending order. + guidance_scale (`float`, *optional*, defaults to 7.0): + Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598). + `guidance_scale` is defined as `w` of equation 2. of [Imagen + Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale > + 1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`, + usually at the expense of lower image quality. + num_images_per_prompt (`int`, *optional*, defaults to 1): + The number of images to generate per prompt. + generator (`torch.Generator` or `List[torch.Generator]`, *optional*): + One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html) + to make generation deterministic. + latents (`torch.FloatTensor`, *optional*): + Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image + generation. Can be used to tweak the same generation with different prompts. If not provided, a latents + tensor will ge generated by sampling using the supplied random `generator`. + prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not + provided, text embeddings will be generated from `prompt` input argument. + pooled_prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. + If not provided, pooled text embeddings will be generated from `prompt` input argument. + output_type (`str`, *optional*, defaults to `"pil"`): + The output format of the generate image. Choose between + [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`. + return_dict (`bool`, *optional*, defaults to `True`): + Whether to return a [`~pipelines.flux.FluxPipelineOutput`] instead of a plain tuple. 
+ joint_attention_kwargs (`dict`, *optional*): + A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under + `self.processor` in + [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py). + callback_on_step_end (`Callable`, *optional*): + A function that calls at the end of each denoising steps during the inference. The function is called + with the following arguments: `callback_on_step_end(self: DiffusionPipeline, step: int, timestep: int, + callback_kwargs: Dict)`. `callback_kwargs` will include a list of all tensors as specified by + `callback_on_step_end_tensor_inputs`. + callback_on_step_end_tensor_inputs (`List`, *optional*): + The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list + will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the + `._callback_tensor_inputs` attribute of your pipeline class. + max_sequence_length (`int` defaults to 512): Maximum sequence length to use with the `prompt`. + + Examples: + + Returns: + [`~pipelines.flux.FluxPipelineOutput`] or `tuple`: [`~pipelines.flux.FluxPipelineOutput`] if `return_dict` + is True, otherwise a `tuple`. When returning a tuple, the first element is a list with the generated + images. + """ + + height = height or self.default_sample_size * self.vae_scale_factor + width = width or self.default_sample_size * self.vae_scale_factor + + # 1. Check inputs. Raise error if not correct + self.check_inputs( + prompt, + prompt_2, + inverted_latents, + image_latents, + latent_image_ids, + height, + width, + start_timestep, + stop_timestep, + prompt_embeds=prompt_embeds, + pooled_prompt_embeds=pooled_prompt_embeds, + callback_on_step_end_tensor_inputs=callback_on_step_end_tensor_inputs, + max_sequence_length=max_sequence_length, + ) + + self._guidance_scale = guidance_scale + self._joint_attention_kwargs = joint_attention_kwargs + self._interrupt = False + do_rf_inversion = inverted_latents is not None + + # 2. Define call parameters + if prompt is not None and isinstance(prompt, str): + batch_size = 1 + elif prompt is not None and isinstance(prompt, list): + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + device = self._execution_device + + lora_scale = ( + self.joint_attention_kwargs.get("scale", None) if self.joint_attention_kwargs is not None else None + ) + ( + prompt_embeds, + pooled_prompt_embeds, + text_ids, + ) = self.encode_prompt( + prompt=prompt, + prompt_2=prompt_2, + prompt_embeds=prompt_embeds, + pooled_prompt_embeds=pooled_prompt_embeds, + device=device, + num_images_per_prompt=num_images_per_prompt, + max_sequence_length=max_sequence_length, + lora_scale=lora_scale, + ) + + # 4. Prepare latent variables + num_channels_latents = self.transformer.config.in_channels // 4 + if do_rf_inversion: + latents = inverted_latents + else: + latents, latent_image_ids = self.prepare_latents( + batch_size * num_images_per_prompt, + num_channels_latents, + height, + width, + prompt_embeds.dtype, + device, + generator, + latents, + ) + + # 5. 
Prepare timesteps + sigmas = np.linspace(1.0, 1 / num_inference_steps, num_inference_steps) if sigmas is None else sigmas + image_seq_len = (int(height) // self.vae_scale_factor // 2) * (int(width) // self.vae_scale_factor // 2) + mu = calculate_shift( + image_seq_len, + self.scheduler.config.base_image_seq_len, + self.scheduler.config.max_image_seq_len, + self.scheduler.config.base_shift, + self.scheduler.config.max_shift, + ) + timesteps, num_inference_steps = retrieve_timesteps( + self.scheduler, + num_inference_steps, + device, + timesteps, + sigmas, + mu=mu, + ) + if do_rf_inversion: + start_timestep = int(start_timestep * num_inference_steps) + stop_timestep = min(int(stop_timestep * num_inference_steps), num_inference_steps) + timesteps, sigmas, num_inference_steps = self.get_timesteps(num_inference_steps, strength) + num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0) + self._num_timesteps = len(timesteps) + + # handle guidance + if self.transformer.config.guidance_embeds: + guidance = torch.full([1], guidance_scale, device=device, dtype=torch.float32) + guidance = guidance.expand(latents.shape[0]) + else: + guidance = None + + if do_rf_inversion: + y_0 = image_latents.clone() + # 6. Denoising loop / Controlled Reverse ODE, Algorithm 2 from: https://arxiv.org/pdf/2410.10792 + with self.progress_bar(total=num_inference_steps) as progress_bar: + for i, t in enumerate(timesteps): + if do_rf_inversion: + # ti (current timestep) as annotated in algorithm 2 - i/num_inference_steps. + t_i = 1 - t / 1000 + dt = torch.tensor(1 / (len(timesteps) - 1), device=device) + + if self.interrupt: + continue + + # broadcast to batch dimension in a way that's compatible with ONNX/Core ML + timestep = t.expand(latents.shape[0]).to(latents.dtype) + + noise_pred = self.transformer( + hidden_states=latents, + timestep=timestep / 1000, + guidance=guidance, + pooled_projections=pooled_prompt_embeds, + encoder_hidden_states=prompt_embeds, + txt_ids=text_ids, + img_ids=latent_image_ids, + joint_attention_kwargs=self.joint_attention_kwargs, + return_dict=False, + )[0] + + latents_dtype = latents.dtype + if do_rf_inversion: + v_t = -noise_pred + v_t_cond = (y_0 - latents) / (1 - t_i) + eta_t = eta if start_timestep <= i < stop_timestep else 0.0 + if start_timestep <= i < stop_timestep: + # controlled vector field + v_hat_t = v_t + eta * (v_t_cond - v_t) + + else: + v_hat_t = v_t + + # SDE Eq: 17 from https://arxiv.org/pdf/2410.10792 + latents = latents + v_hat_t * (sigmas[i] - sigmas[i + 1]) + else: + # compute the previous noisy sample x_t -> x_t-1 + latents = self.scheduler.step(noise_pred, t, latents, return_dict=False)[0] + + if latents.dtype != latents_dtype: + if torch.backends.mps.is_available(): + # some platforms (eg. 
apple mps) misbehave due to a pytorch bug: https://github.com/pytorch/pytorch/pull/99272 + latents = latents.to(latents_dtype) + + if callback_on_step_end is not None: + callback_kwargs = {} + for k in callback_on_step_end_tensor_inputs: + callback_kwargs[k] = locals()[k] + callback_outputs = callback_on_step_end(self, i, t, callback_kwargs) + + latents = callback_outputs.pop("latents", latents) + prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds) + + # call the callback, if provided + if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0): + progress_bar.update() + + if XLA_AVAILABLE: + xm.mark_step() + + if output_type == "latent": + image = latents + + else: + latents = self._unpack_latents(latents, height, width, self.vae_scale_factor) + latents = (latents / self.vae.config.scaling_factor) + self.vae.config.shift_factor + image = self.vae.decode(latents, return_dict=False)[0] + image = self.image_processor.postprocess(image, output_type=output_type) + + # Offload all models + self.maybe_free_model_hooks() + + if not return_dict: + return (image,) + + return FluxPipelineOutput(images=image) + + @torch.no_grad() + def invert( + self, + image: PipelineImageInput, + source_prompt: str = "", + source_guidance_scale=0.0, + num_inversion_steps: int = 28, + strength: float = 1.0, + gamma: float = 0.5, + height: Optional[int] = None, + width: Optional[int] = None, + timesteps: List[int] = None, + dtype: Optional[torch.dtype] = None, + joint_attention_kwargs: Optional[Dict[str, Any]] = None, + ): + r""" + Performs Algorithm 1: Controlled Forward ODE from https://arxiv.org/pdf/2410.10792 + Args: + image (`PipelineImageInput`): + Input for the image(s) that are to be edited. Multiple input images have to default to the same aspect + ratio. + source_prompt (`str` or `List[str]`, *optional* defaults to an empty prompt as done in the original paper): + The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`. + instead. + source_guidance_scale (`float`, *optional*, defaults to 0.0): + Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598). + `guidance_scale` is defined as `w` of equation 2. of [Imagen + Paper](https://arxiv.org/pdf/2205.11487.pdf). For this algorithm, it's better to keep it 0. + num_inversion_steps (`int`, *optional*, defaults to 28): + The number of discretization steps. + height (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor): + The height in pixels of the generated image. This is set to 1024 by default for the best results. + width (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor): + The width in pixels of the generated image. This is set to 1024 by default for the best results. + gamma (`float`, *optional*, defaults to 0.5): + The controller guidance for the forward ODE, balancing faithfulness & editability: + higher eta - better faithfullness, less editability. For more significant edits, lower the value of eta. + timesteps (`List[int]`, *optional*): + Custom timesteps to use for the denoising process with schedulers which support a `timesteps` argument + in their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is + passed will be used. Must be in descending order. 
+ """ + dtype = dtype or self.text_encoder.dtype + batch_size = 1 + self._joint_attention_kwargs = joint_attention_kwargs + num_channels_latents = self.transformer.config.in_channels // 4 + + height = height or self.default_sample_size * self.vae_scale_factor + width = width or self.default_sample_size * self.vae_scale_factor + device = self._execution_device + + # 1. prepare image + image_latents, _ = self.encode_image(image, height=height, width=width, dtype=dtype) + image_latents, latent_image_ids = self.prepare_latents_inversion( + batch_size, num_channels_latents, height, width, dtype, device, image_latents + ) + + # 2. prepare timesteps + sigmas = np.linspace(1.0, 1 / num_inversion_steps, num_inversion_steps) + image_seq_len = (int(height) // self.vae_scale_factor // 2) * (int(width) // self.vae_scale_factor // 2) + mu = calculate_shift( + image_seq_len, + self.scheduler.config.base_image_seq_len, + self.scheduler.config.max_image_seq_len, + self.scheduler.config.base_shift, + self.scheduler.config.max_shift, + ) + timesteps, num_inversion_steps = retrieve_timesteps( + self.scheduler, + num_inversion_steps, + device, + timesteps, + sigmas, + mu=mu, + ) + timesteps, sigmas, num_inversion_steps = self.get_timesteps(num_inversion_steps, strength) + + # 3. prepare text embeddings + ( + prompt_embeds, + pooled_prompt_embeds, + text_ids, + ) = self.encode_prompt( + prompt=source_prompt, + prompt_2=source_prompt, + device=device, + ) + # 4. handle guidance + if self.transformer.config.guidance_embeds: + guidance = torch.full([1], source_guidance_scale, device=device, dtype=torch.float32) + else: + guidance = None + + # Eq 8 dY_t = [u_t(Y_t) + γ(u_t(Y_t|y_1) - u_t(Y_t))]dt + Y_t = image_latents + y_1 = torch.randn_like(Y_t) + N = len(sigmas) + + # forward ODE loop + with self.progress_bar(total=N - 1) as progress_bar: + for i in range(N - 1): + t_i = torch.tensor(i / (N), dtype=Y_t.dtype, device=device) + timestep = torch.tensor(t_i, dtype=Y_t.dtype, device=device).repeat(batch_size) + + # get the unconditional vector field + u_t_i = self.transformer( + hidden_states=Y_t, + timestep=timestep, + guidance=guidance, + pooled_projections=pooled_prompt_embeds, + encoder_hidden_states=prompt_embeds, + txt_ids=text_ids, + img_ids=latent_image_ids, + joint_attention_kwargs=self.joint_attention_kwargs, + return_dict=False, + )[0] + + # get the conditional vector field + u_t_i_cond = (y_1 - Y_t) / (1 - t_i) + + # controlled vector field + # Eq 8 dY_t = [u_t(Y_t) + γ(u_t(Y_t|y_1) - u_t(Y_t))]dt + u_hat_t_i = u_t_i + gamma * (u_t_i_cond - u_t_i) + Y_t = Y_t + u_hat_t_i * (sigmas[i] - sigmas[i + 1]) + progress_bar.update() + + # return the inverted latents (start point for the denoising loop), encoded image & latent image ids + return Y_t, image_latents, latent_image_ids diff --git a/examples/research_projects/geodiff/geodiff_molecule_conformation.ipynb b/examples/research_projects/geodiff/geodiff_molecule_conformation.ipynb index 03f58f1f2f63..bde093802a5d 100644 --- a/examples/research_projects/geodiff/geodiff_molecule_conformation.ipynb +++ b/examples/research_projects/geodiff/geodiff_molecule_conformation.ipynb @@ -1,3660 +1,3652 @@ { - "cells": [ - { - "cell_type": "markdown", - "metadata": { - "id": "F88mignPnalS" - }, - "source": [ - "# Introduction\n", - "\n", - "This colab is design to run the pretrained models from [GeoDiff](https://github.com/MinkaiXu/GeoDiff).\n", - "The visualization code is inspired by this PyMol 
[colab](https://colab.research.google.com/gist/iwatobipen/2ec7faeafe5974501e69fcc98c122922/pymol.ipynb#scrollTo=Hm4kY7CaZSlw).\n", - "\n", - "The goal is to generate physically accurate molecules. Given the input of a molecule graph (atom and bond structures with their connectivity -- in the form of a 2d graph). What we want to generate is a stable 3d structure of the molecule.\n", - "\n", - "This colab uses GEOM datasets that have multiple 3d targets per configuration, which provide more compelling targets for generative methods.\n", - "\n", - "> Colab made by [natolambert](https://twitter.com/natolambert).\n", - "\n", - "![diffusers_library](https://github.com/huggingface/diffusers/raw/main/docs/source/imgs/diffusers_library.jpg)\n" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "7cnwXMocnuzB" - }, - "source": [ - "## Installations\n", - "\n" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "ff9SxWnaNId9" - }, - "source": [ - "### Install Conda" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "1g_6zOabItDk" - }, - "source": [ - "Here we check the `cuda` version of colab. When this was built, the version was always 11.1, which impacts some installation decisions below." - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "colab": { - "base_uri": "https://localhost:8080/" - }, - "id": "K0ofXobG5Y-X", - "outputId": "572c3d25-6f19-4c1e-83f5-a1d084a3207f" - }, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "nvcc: NVIDIA (R) Cuda compiler driver\n", - "Copyright (c) 2005-2021 NVIDIA Corporation\n", - "Built on Sun_Feb_14_21:12:58_PST_2021\n", - "Cuda compilation tools, release 11.2, V11.2.152\n", - "Build cuda_11.2.r11.2/compiler.29618528_0\n" - ] - } - ], - "source": [ - "!nvcc --version" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "VfthW90vI0nw" - }, - "source": [ - "Install Conda for some more complex dependencies for geometric networks." - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "colab": { - "base_uri": "https://localhost:8080/" - }, - "id": "2WNFzSnbiE0k", - "outputId": "690d0d4d-9d0a-4ead-c6dc-086f113f532f" - }, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "\u001b[33mWARNING: Running pip as the 'root' user can result in broken permissions and conflicting behaviour with the system package manager. 
It is recommended to use a virtual environment instead: https://pip.pypa.io/warnings/venv\u001b[0m\u001b[33m\n", - "\u001b[0m" - ] - } - ], - "source": [ - "!pip install -q condacolab" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "NUsbWYCUI7Km" - }, - "source": [ - "Setup Conda" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "colab": { - "base_uri": "https://localhost:8080/" - }, - "id": "FZelreINdmd0", - "outputId": "635f0cb8-0af4-499f-e0a4-b3790cb12e9f" - }, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "✨🍰✨ Everything looks OK!\n" - ] - } - ], - "source": [ - "import condacolab\n", - "\n", - "\n", - "condacolab.install()" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "JzDHaPU7I9Sn" - }, - "source": [ - "Install pytorch requirements (this takes a few minutes, go grab yourself a coffee 🤗)" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "colab": { - "base_uri": "https://localhost:8080/" - }, - "id": "JMxRjHhL7w8V", - "outputId": "6ed511b3-9262-49e8-b340-08e76b05ebd8" - }, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "Collecting package metadata (current_repodata.json): - \b\b\\ \b\b| \b\b/ \b\b- \b\b\\ \b\b| \b\b/ \b\b- \b\b\\ \b\b| \b\b/ \b\b- \b\b\\ \b\b| \b\b/ \b\b- \b\b\\ \b\b| \b\b/ \b\b- \b\b\\ \b\b| \b\b/ \b\b- \b\b\\ \b\b| \b\b/ \b\b- \b\b\\ \b\b| \b\b/ \b\bdone\n", - "Solving environment: \\ \b\b| \b\b/ \b\b- \b\b\\ \b\b| \b\b/ \b\b- \b\b\\ \b\b| \b\b/ \b\b- \b\b\\ \b\b| \b\b/ \b\b- \b\b\\ \b\b| \b\b/ \b\b- \b\b\\ \b\b| \b\b/ \b\b- \b\b\\ \b\b| \b\b/ \b\b- \b\b\\ \b\b| \b\b/ \b\b- \b\b\\ \b\bdone\n", - "\n", - "## Package Plan ##\n", - "\n", - " environment location: /usr/local\n", - "\n", - " added / updated specs:\n", - " - cudatoolkit=11.1\n", - " - pytorch\n", - " - torchaudio\n", - " - torchvision\n", - "\n", - "\n", - "The following packages will be downloaded:\n", - "\n", - " package | build\n", - " ---------------------------|-----------------\n", - " conda-22.9.0 | py37h89c1867_1 960 KB conda-forge\n", - " ------------------------------------------------------------\n", - " Total: 960 KB\n", - "\n", - "The following packages will be UPDATED:\n", - "\n", - " conda 4.14.0-py37h89c1867_0 --> 22.9.0-py37h89c1867_1\n", - "\n", - "\n", - "\n", - "Downloading and Extracting Packages\n", - "conda-22.9.0 | 960 KB | : 100% 1.0/1 [00:00<00:00, 4.15it/s]\n", - "Preparing transaction: / \b\bdone\n", - "Verifying transaction: \\ \b\bdone\n", - "Executing transaction: / \b\bdone\n", - "Retrieving notices: ...working... done\n" - ] - } - ], - "source": [ - "!conda install pytorch torchvision torchaudio cudatoolkit=11.1 -c pytorch-lts -c nvidia\n", - "# !conda install pytorch==1.8.0 torchvision==0.9.0 torchaudio==0.8.0 cudatoolkit=11.1 -c pytorch -c conda-forge" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "QDS6FPZ0Tu5b" - }, - "source": [ - "Need to remove a pathspec for colab that specifies the incorrect cuda version." 
- ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "colab": { - "base_uri": "https://localhost:8080/" - }, - "id": "dq1lxR10TtrR", - "outputId": "ed9c5a71-b449-418f-abb7-072b74e7f6c8" - }, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "rm: cannot remove '/usr/local/conda-meta/pinned': No such file or directory\n" - ] - } - ], - "source": [ - "!rm /usr/local/conda-meta/pinned" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "Z1L3DdZOJB30" - }, - "source": [ - "Install torch geometric (used in the model later)" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "colab": { - "base_uri": "https://localhost:8080/" - }, - "id": "D5ukfCOWfjzK", - "outputId": "8437485a-5aa6-4d53-8f7f-23517ac1ace6" - }, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "Collecting package metadata (current_repodata.json): - \b\b\\ \b\b| \b\b/ \b\b- \b\b\\ \b\b| \b\b/ \b\b- \b\b\\ \b\b| \b\b/ \b\b- \b\b\\ \b\b| \b\b/ \b\b- \b\b\\ \b\b| \b\b/ \b\b- \b\b\\ \b\b| \b\b/ \b\b- \b\b\\ \b\b| \b\b/ \b\b- \b\bdone\n", - "Solving environment: | \b\b/ \b\b- \b\b\\ \b\b| \b\b/ \b\b- \b\b\\ \b\b| \b\b/ \b\b- \b\b\\ \b\b| \b\b/ \b\b- \b\b\\ \b\b| \b\b/ \b\b- \b\b\\ \b\b| \b\b/ \b\b- \b\b\\ \b\b| \b\b/ \b\b- \b\b\\ \b\b| \b\b/ \b\b- \b\b\\ \b\b| \b\b/ \b\b- \b\b\\ \b\b| \b\b/ \b\b- \b\b\\ \b\b| \b\b/ \b\b- \b\b\\ \b\b| \b\b/ \b\b- \b\b\\ \b\b| \b\b/ \b\b- \b\b\\ \b\b| \b\b/ \b\b- \b\b\\ \b\b| \b\b/ \b\b- \b\b\\ \b\b| \b\b/ \b\b- \b\b\\ \b\b| \b\b/ \b\b- \b\b\\ \b\b| \b\b/ \b\b- \b\b\\ \b\b| \b\b/ \b\b- \b\b\\ \b\b| \b\b/ \b\b- \b\b\\ \b\b| \b\b/ \b\b- \b\b\\ \b\b| \b\b/ \b\b- \b\b\\ \b\b| \b\b/ \b\b- \b\b\\ \b\b| \b\b/ \b\b- \b\b\\ \b\b| \b\b/ \b\b- \b\b\\ \b\b| \b\b/ \b\b- \b\b\\ \b\b| \b\b/ \b\b- \b\b\\ \b\b| \b\b/ \b\b- \b\b\\ \b\b| \b\b/ \b\b- \b\b\\ \b\b| \b\b/ \b\b- \b\b\\ \b\b| \b\b/ \b\b- \b\b\\ \b\b| \b\b/ \b\b- \b\b\\ \b\b| \b\b/ \b\b- \b\b\\ \b\bdone\n", - "\n", - "## Package Plan ##\n", - "\n", - " environment location: /usr/local\n", - "\n", - " added / updated specs:\n", - " - pytorch-geometric=1.7.2\n", - "\n", - "\n", - "The following packages will be downloaded:\n", - "\n", - " package | build\n", - " ---------------------------|-----------------\n", - " decorator-4.4.2 | py_0 11 KB conda-forge\n", - " googledrivedownloader-0.4 | pyhd3deb0d_1 7 KB conda-forge\n", - " jinja2-3.1.2 | pyhd8ed1ab_1 99 KB conda-forge\n", - " joblib-1.2.0 | pyhd8ed1ab_0 205 KB conda-forge\n", - " markupsafe-2.1.1 | py37h540881e_1 22 KB conda-forge\n", - " networkx-2.5.1 | pyhd8ed1ab_0 1.2 MB conda-forge\n", - " pandas-1.2.3 | py37hdc94413_0 11.8 MB conda-forge\n", - " pyparsing-3.0.9 | pyhd8ed1ab_0 79 KB conda-forge\n", - " python-dateutil-2.8.2 | pyhd8ed1ab_0 240 KB conda-forge\n", - " python-louvain-0.15 | pyhd8ed1ab_1 13 KB conda-forge\n", - " pytorch-cluster-1.5.9 |py37_torch_1.8.0_cu111 1.2 MB rusty1s\n", - " pytorch-geometric-1.7.2 |py37_torch_1.8.0_cu111 445 KB rusty1s\n", - " pytorch-scatter-2.0.8 |py37_torch_1.8.0_cu111 6.1 MB rusty1s\n", - " pytorch-sparse-0.6.12 |py37_torch_1.8.0_cu111 2.9 MB rusty1s\n", - " pytorch-spline-conv-1.2.1 |py37_torch_1.8.0_cu111 736 KB rusty1s\n", - " pytz-2022.4 | pyhd8ed1ab_0 232 KB conda-forge\n", - " scikit-learn-1.0.2 | py37hf9e9bfc_0 7.8 MB conda-forge\n", - " scipy-1.7.3 | py37hf2a6cf1_0 21.8 MB conda-forge\n", - " setuptools-59.8.0 | py37h89c1867_1 1.0 MB conda-forge\n", - " threadpoolctl-3.1.0 | pyh8a188c0_0 18 KB conda-forge\n", - " 
------------------------------------------------------------\n", - " Total: 55.9 MB\n", - "\n", - "The following NEW packages will be INSTALLED:\n", - "\n", - " decorator conda-forge/noarch::decorator-4.4.2-py_0 None\n", - " googledrivedownlo~ conda-forge/noarch::googledrivedownloader-0.4-pyhd3deb0d_1 None\n", - " jinja2 conda-forge/noarch::jinja2-3.1.2-pyhd8ed1ab_1 None\n", - " joblib conda-forge/noarch::joblib-1.2.0-pyhd8ed1ab_0 None\n", - " markupsafe conda-forge/linux-64::markupsafe-2.1.1-py37h540881e_1 None\n", - " networkx conda-forge/noarch::networkx-2.5.1-pyhd8ed1ab_0 None\n", - " pandas conda-forge/linux-64::pandas-1.2.3-py37hdc94413_0 None\n", - " pyparsing conda-forge/noarch::pyparsing-3.0.9-pyhd8ed1ab_0 None\n", - " python-dateutil conda-forge/noarch::python-dateutil-2.8.2-pyhd8ed1ab_0 None\n", - " python-louvain conda-forge/noarch::python-louvain-0.15-pyhd8ed1ab_1 None\n", - " pytorch-cluster rusty1s/linux-64::pytorch-cluster-1.5.9-py37_torch_1.8.0_cu111 None\n", - " pytorch-geometric rusty1s/linux-64::pytorch-geometric-1.7.2-py37_torch_1.8.0_cu111 None\n", - " pytorch-scatter rusty1s/linux-64::pytorch-scatter-2.0.8-py37_torch_1.8.0_cu111 None\n", - " pytorch-sparse rusty1s/linux-64::pytorch-sparse-0.6.12-py37_torch_1.8.0_cu111 None\n", - " pytorch-spline-co~ rusty1s/linux-64::pytorch-spline-conv-1.2.1-py37_torch_1.8.0_cu111 None\n", - " pytz conda-forge/noarch::pytz-2022.4-pyhd8ed1ab_0 None\n", - " scikit-learn conda-forge/linux-64::scikit-learn-1.0.2-py37hf9e9bfc_0 None\n", - " scipy conda-forge/linux-64::scipy-1.7.3-py37hf2a6cf1_0 None\n", - " threadpoolctl conda-forge/noarch::threadpoolctl-3.1.0-pyh8a188c0_0 None\n", - "\n", - "The following packages will be DOWNGRADED:\n", - "\n", - " setuptools 65.3.0-py37h89c1867_0 --> 59.8.0-py37h89c1867_1 None\n", - "\n", - "\n", - "\n", - "Downloading and Extracting Packages\n", - "scikit-learn-1.0.2 | 7.8 MB | : 100% 1.0/1 [00:01<00:00, 1.37s/it] \n", - "pytorch-scatter-2.0. | 6.1 MB | : 100% 1.0/1 [00:06<00:00, 6.18s/it]\n", - "pytorch-geometric-1. | 445 KB | : 100% 1.0/1 [00:02<00:00, 2.53s/it]\n", - "scipy-1.7.3 | 21.8 MB | : 100% 1.0/1 [00:03<00:00, 3.06s/it]\n", - "python-dateutil-2.8. | 240 KB | : 100% 1.0/1 [00:00<00:00, 21.48it/s]\n", - "pytorch-spline-conv- | 736 KB | : 100% 1.0/1 [00:01<00:00, 1.00s/it]\n", - "pytorch-sparse-0.6.1 | 2.9 MB | : 100% 1.0/1 [00:07<00:00, 7.51s/it]\n", - "pyparsing-3.0.9 | 79 KB | : 100% 1.0/1 [00:00<00:00, 26.32it/s]\n", - "pytorch-cluster-1.5. 
| 1.2 MB | : 100% 1.0/1 [00:02<00:00, 2.78s/it]\n", - "jinja2-3.1.2 | 99 KB | : 100% 1.0/1 [00:00<00:00, 20.28it/s]\n", - "decorator-4.4.2 | 11 KB | : 100% 1.0/1 [00:00<00:00, 21.57it/s]\n", - "joblib-1.2.0 | 205 KB | : 100% 1.0/1 [00:00<00:00, 15.04it/s]\n", - "pytz-2022.4 | 232 KB | : 100% 1.0/1 [00:00<00:00, 10.21it/s]\n", - "python-louvain-0.15 | 13 KB | : 100% 1.0/1 [00:00<00:00, 3.34it/s]\n", - "googledrivedownloade | 7 KB | : 100% 1.0/1 [00:00<00:00, 3.33it/s]\n", - "threadpoolctl-3.1.0 | 18 KB | : 100% 1.0/1 [00:00<00:00, 29.40it/s]\n", - "markupsafe-2.1.1 | 22 KB | : 100% 1.0/1 [00:00<00:00, 28.62it/s]\n", - "pandas-1.2.3 | 11.8 MB | : 100% 1.0/1 [00:02<00:00, 2.08s/it] \n", - "networkx-2.5.1 | 1.2 MB | : 100% 1.0/1 [00:01<00:00, 1.39s/it]\n", - "setuptools-59.8.0 | 1.0 MB | : 100% 1.0/1 [00:00<00:00, 4.25it/s]\n", - "Preparing transaction: / \b\b- \b\b\\ \b\bdone\n", - "Verifying transaction: / \b\b- \b\b\\ \b\b| \b\b/ \b\b- \b\b\\ \b\b| \b\b/ \b\b- \b\b\\ \b\b| \b\b/ \b\b- \b\b\\ \b\bdone\n", - "Executing transaction: / \b\b- \b\b\\ \b\b| \b\b/ \b\b- \b\b\\ \b\b| \b\b/ \b\b- \b\b\\ \b\b| \b\b/ \b\b- \b\b\\ \b\b| \b\b/ \b\b- \b\b\\ \b\b| \b\b/ \b\b- \b\bdone\n", - "Retrieving notices: ...working... done\n" - ] - } - ], - "source": [ - "!conda install -c rusty1s pytorch-geometric=1.7.2" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "ppxv6Mdkalbc" - }, - "source": [ - "### Install Diffusers" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "colab": { - "base_uri": "https://localhost:8080/" - }, - "id": "mgQA_XN-XGY2", - "outputId": "85392615-b6a4-4052-9d2a-79604be62c94" - }, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "/content\n", - "Cloning into 'diffusers'...\n", - "remote: Enumerating objects: 9298, done.\u001b[K\n", - "remote: Counting objects: 100% (40/40), done.\u001b[K\n", - "remote: Compressing objects: 100% (23/23), done.\u001b[K\n", - "remote: Total 9298 (delta 17), reused 23 (delta 11), pack-reused 9258\u001b[K\n", - "Receiving objects: 100% (9298/9298), 7.38 MiB | 5.28 MiB/s, done.\n", - "Resolving deltas: 100% (6168/6168), done.\n", - " Installing build dependencies ... \u001b[?25l\u001b[?25hdone\n", - " Getting requirements to build wheel ... \u001b[?25l\u001b[?25hdone\n", - " Preparing metadata (pyproject.toml) ... \u001b[?25l\u001b[?25hdone\n", - "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m757.0/757.0 kB\u001b[0m \u001b[31m52.8 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", - "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m163.5/163.5 kB\u001b[0m \u001b[31m21.9 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", - "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m40.8/40.8 kB\u001b[0m \u001b[31m5.5 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", - "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m596.3/596.3 kB\u001b[0m \u001b[31m51.7 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", - "\u001b[?25h Building wheel for diffusers (pyproject.toml) ... \u001b[?25l\u001b[?25hdone\n", - "\u001b[33mWARNING: Running pip as the 'root' user can result in broken permissions and conflicting behaviour with the system package manager. 
It is recommended to use a virtual environment instead: https://pip.pypa.io/warnings/venv\u001b[0m\u001b[33m\n", - "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m432.7/432.7 kB\u001b[0m \u001b[31m36.8 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", - "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m5.3/5.3 MB\u001b[0m \u001b[31m90.9 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", - "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m35.3/35.3 MB\u001b[0m \u001b[31m39.7 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", - "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m115.1/115.1 kB\u001b[0m \u001b[31m16.3 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", - "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m948.0/948.0 kB\u001b[0m \u001b[31m63.6 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", - "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m212.2/212.2 kB\u001b[0m \u001b[31m21.5 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", - "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m95.8/95.8 kB\u001b[0m \u001b[31m12.8 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", - "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m140.8/140.8 kB\u001b[0m \u001b[31m18.8 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", - "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m7.6/7.6 MB\u001b[0m \u001b[31m104.3 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", - "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m148.0/148.0 kB\u001b[0m \u001b[31m20.8 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", - "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m231.3/231.3 kB\u001b[0m \u001b[31m30.0 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", - "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m94.8/94.8 kB\u001b[0m \u001b[31m14.0 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", - "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m58.8/58.8 kB\u001b[0m \u001b[31m8.0 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", - "\u001b[?25h\u001b[33mWARNING: Running pip as the 'root' user can result in broken permissions and conflicting behaviour with the system package manager. 
It is recommended to use a virtual environment instead: https://pip.pypa.io/warnings/venv\u001b[0m\u001b[33m\n", - "\u001b[0m" - ] - } - ], - "source": [ - "%cd /content\n", - "\n", - "# install latest HF diffusers (will update to the release once added)\n", - "!git clone https://github.com/huggingface/diffusers.git\n", - "!pip install -q /content/diffusers\n", - "\n", - "# dependencies for diffusers\n", - "!pip install -q datasets transformers" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "LZO6AJKuJKO8" - }, - "source": [ - "Check that torch is installed correctly and utilizing the GPU in the colab" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "colab": { - "base_uri": "https://localhost:8080/", - "height": 53 - }, - "id": "gZt7BNi1e1PA", - "outputId": "a0e1832c-9c02-49aa-cff8-1339e6cdc889" - }, - "outputs": [ + "cells": [ { - "name": "stdout", - "output_type": "stream", - "text": [ - "True\n" - ] + "cell_type": "markdown", + "metadata": { + "id": "F88mignPnalS" + }, + "source": [ + "# Introduction\n", + "\n", + "This colab is design to run the pretrained models from [GeoDiff](https://github.com/MinkaiXu/GeoDiff).\n", + "The visualization code is inspired by this PyMol [colab](https://colab.research.google.com/gist/iwatobipen/2ec7faeafe5974501e69fcc98c122922/pymol.ipynb#scrollTo=Hm4kY7CaZSlw).\n", + "\n", + "The goal is to generate physically accurate molecules. Given the input of a molecule graph (atom and bond structures with their connectivity -- in the form of a 2d graph). What we want to generate is a stable 3d structure of the molecule.\n", + "\n", + "This colab uses GEOM datasets that have multiple 3d targets per configuration, which provide more compelling targets for generative methods.\n", + "\n", + "> Colab made by [natolambert](https://twitter.com/natolambert).\n", + "\n", + "![diffusers_library](https://github.com/huggingface/diffusers/raw/main/docs/source/imgs/diffusers_library.jpg)\n" + ] }, { - "data": { - "application/vnd.google.colaboratory.intrinsic+json": { - "type": "string" + "cell_type": "markdown", + "metadata": { + "id": "7cnwXMocnuzB" }, - "text/plain": [ - "'1.8.2'" + "source": [ + "## Installations\n", + "\n" ] - }, - "execution_count": 8, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "import torch\n", - "\n", - "\n", - "print(torch.cuda.is_available())\n", - "torch.__version__" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "KLE7CqlfJNUO" - }, - "source": [ - "### Install Chemistry-specific Dependencies\n", - "\n", - "Install RDKit, a tool for working with and visualizing chemsitry in python (you use this to visualize the generate models later)." 
- ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "colab": { - "base_uri": "https://localhost:8080/" - }, - "id": "0CPv_NvehRz3", - "outputId": "6ee0ae4e-4511-4816-de29-22b1c21d49bc" - }, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/\n", - "Collecting rdkit\n", - " Downloading rdkit-2022.3.5-cp37-cp37m-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (36.8 MB)\n", - "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m36.8/36.8 MB\u001b[0m \u001b[31m34.6 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", - "\u001b[?25hRequirement already satisfied: Pillow in /usr/local/lib/python3.7/site-packages (from rdkit) (9.2.0)\n", - "Requirement already satisfied: numpy in /usr/local/lib/python3.7/site-packages (from rdkit) (1.21.6)\n", - "Installing collected packages: rdkit\n", - "Successfully installed rdkit-2022.3.5\n", - "\u001b[33mWARNING: Running pip as the 'root' user can result in broken permissions and conflicting behaviour with the system package manager. It is recommended to use a virtual environment instead: https://pip.pypa.io/warnings/venv\u001b[0m\u001b[33m\n", - "\u001b[0m" - ] - } - ], - "source": [ - "!pip install rdkit" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "88GaDbDPxJ5I" - }, - "source": [ - "### Get viewer from nglview\n", - "\n", - "The model you will use outputs a position matrix tensor. This pytorch geometric data object will have many features (positions, known features, edge features -- all tensors).\n", - "The data we give to the model will also have a rdmol object (which can extract features to geometric if needed).\n", - "The rdmol in this object is a source of ground truth for the generated molecules.\n", - "\n", - "You will use one rendering function from nglviewer later!\n", - "\n" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "colab": { - "base_uri": "https://localhost:8080/", - "height": 1000 - }, - "id": "jcl8GCS2mz6t", - "outputId": "99b5cc40-67bb-4d8e-faa0-47d7cb33e98f" - }, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/\n", - "Collecting nglview\n", - " Downloading nglview-3.0.3.tar.gz (5.7 MB)\n", - "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m5.7/5.7 MB\u001b[0m \u001b[31m91.2 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", - "\u001b[?25h Installing build dependencies ... \u001b[?25l\u001b[?25hdone\n", - " Getting requirements to build wheel ... \u001b[?25l\u001b[?25hdone\n", - " Preparing metadata (pyproject.toml) ... 
\u001b[?25l\u001b[?25hdone\n", - "Requirement already satisfied: numpy in /usr/local/lib/python3.7/site-packages (from nglview) (1.21.6)\n", - "Collecting jupyterlab-widgets\n", - " Downloading jupyterlab_widgets-3.0.3-py3-none-any.whl (384 kB)\n", - "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m384.1/384.1 kB\u001b[0m \u001b[31m40.6 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", - "\u001b[?25hCollecting ipywidgets>=7\n", - " Downloading ipywidgets-8.0.2-py3-none-any.whl (134 kB)\n", - "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m134.4/134.4 kB\u001b[0m \u001b[31m21.2 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", - "\u001b[?25hCollecting widgetsnbextension~=4.0\n", - " Downloading widgetsnbextension-4.0.3-py3-none-any.whl (2.0 MB)\n", - "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m2.0/2.0 MB\u001b[0m \u001b[31m84.4 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", - "\u001b[?25hCollecting ipython>=6.1.0\n", - " Downloading ipython-7.34.0-py3-none-any.whl (793 kB)\n", - "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m793.8/793.8 kB\u001b[0m \u001b[31m60.7 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", - "\u001b[?25hCollecting ipykernel>=4.5.1\n", - " Downloading ipykernel-6.16.0-py3-none-any.whl (138 kB)\n", - "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m138.4/138.4 kB\u001b[0m \u001b[31m20.9 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", - "\u001b[?25hCollecting traitlets>=4.3.1\n", - " Downloading traitlets-5.4.0-py3-none-any.whl (107 kB)\n", - "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m107.1/107.1 kB\u001b[0m \u001b[31m17.0 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", - "\u001b[?25hRequirement already satisfied: packaging in /usr/local/lib/python3.7/site-packages (from ipykernel>=4.5.1->ipywidgets>=7->nglview) (21.3)\n", - "Collecting pyzmq>=17\n", - " Downloading pyzmq-24.0.1-cp37-cp37m-manylinux_2_5_x86_64.manylinux1_x86_64.whl (1.1 MB)\n", - "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m1.1/1.1 MB\u001b[0m \u001b[31m68.4 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", - "\u001b[?25hCollecting matplotlib-inline>=0.1\n", - " Downloading matplotlib_inline-0.1.6-py3-none-any.whl (9.4 kB)\n", - "Collecting tornado>=6.1\n", - " Downloading tornado-6.2-cp37-abi3-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl (423 kB)\n", - "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m424.0/424.0 kB\u001b[0m \u001b[31m41.2 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", - "\u001b[?25hCollecting nest-asyncio\n", - " Downloading nest_asyncio-1.5.6-py3-none-any.whl (5.2 kB)\n", - "Collecting debugpy>=1.0\n", - " Downloading debugpy-1.6.3-cp37-cp37m-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_12_x86_64.manylinux2010_x86_64.whl (1.8 MB)\n", - "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m1.8/1.8 MB\u001b[0m \u001b[31m83.4 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", - "\u001b[?25hCollecting psutil\n", - " Downloading psutil-5.9.2-cp37-cp37m-manylinux_2_12_x86_64.manylinux2010_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl (281 kB)\n", - "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m281.3/281.3 kB\u001b[0m \u001b[31m33.1 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", - "\u001b[?25hCollecting 
jupyter-client>=6.1.12\n", - " Downloading jupyter_client-7.4.2-py3-none-any.whl (132 kB)\n", - "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m132.2/132.2 kB\u001b[0m \u001b[31m19.7 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", - "\u001b[?25hCollecting pickleshare\n", - " Downloading pickleshare-0.7.5-py2.py3-none-any.whl (6.9 kB)\n", - "Requirement already satisfied: setuptools>=18.5 in /usr/local/lib/python3.7/site-packages (from ipython>=6.1.0->ipywidgets>=7->nglview) (59.8.0)\n", - "Collecting backcall\n", - " Downloading backcall-0.2.0-py2.py3-none-any.whl (11 kB)\n", - "Collecting pexpect>4.3\n", - " Downloading pexpect-4.8.0-py2.py3-none-any.whl (59 kB)\n", - "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m59.0/59.0 kB\u001b[0m \u001b[31m7.3 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", - "\u001b[?25hCollecting pygments\n", - " Downloading Pygments-2.13.0-py3-none-any.whl (1.1 MB)\n", - "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m1.1/1.1 MB\u001b[0m \u001b[31m70.9 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", - "\u001b[?25hCollecting jedi>=0.16\n", - " Downloading jedi-0.18.1-py2.py3-none-any.whl (1.6 MB)\n", - "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m1.6/1.6 MB\u001b[0m \u001b[31m83.5 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", - "\u001b[?25hCollecting prompt-toolkit!=3.0.0,!=3.0.1,<3.1.0,>=2.0.0\n", - " Downloading prompt_toolkit-3.0.31-py3-none-any.whl (382 kB)\n", - "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m382.3/382.3 kB\u001b[0m \u001b[31m40.0 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", - "\u001b[?25hRequirement already satisfied: decorator in /usr/local/lib/python3.7/site-packages (from ipython>=6.1.0->ipywidgets>=7->nglview) (4.4.2)\n", - "Collecting parso<0.9.0,>=0.8.0\n", - " Downloading parso-0.8.3-py2.py3-none-any.whl (100 kB)\n", - "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m100.8/100.8 kB\u001b[0m \u001b[31m14.7 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", - "\u001b[?25hRequirement already satisfied: python-dateutil>=2.8.2 in /usr/local/lib/python3.7/site-packages (from jupyter-client>=6.1.12->ipykernel>=4.5.1->ipywidgets>=7->nglview) (2.8.2)\n", - "Collecting entrypoints\n", - " Downloading entrypoints-0.4-py3-none-any.whl (5.3 kB)\n", - "Collecting jupyter-core>=4.9.2\n", - " Downloading jupyter_core-4.11.1-py3-none-any.whl (88 kB)\n", - "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m88.4/88.4 kB\u001b[0m \u001b[31m14.0 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", - "\u001b[?25hCollecting ptyprocess>=0.5\n", - " Downloading ptyprocess-0.7.0-py2.py3-none-any.whl (13 kB)\n", - "Collecting wcwidth\n", - " Downloading wcwidth-0.2.5-py2.py3-none-any.whl (30 kB)\n", - "Requirement already satisfied: pyparsing!=3.0.5,>=2.0.2 in /usr/local/lib/python3.7/site-packages (from packaging->ipykernel>=4.5.1->ipywidgets>=7->nglview) (3.0.9)\n", - "Requirement already satisfied: six>=1.5 in /usr/local/lib/python3.7/site-packages (from python-dateutil>=2.8.2->jupyter-client>=6.1.12->ipykernel>=4.5.1->ipywidgets>=7->nglview) (1.16.0)\n", - "Building wheels for collected packages: nglview\n", - " Building wheel for nglview (pyproject.toml) ... 
\u001b[?25l\u001b[?25hdone\n", - " Created wheel for nglview: filename=nglview-3.0.3-py3-none-any.whl size=8057538 sha256=b7e1071bb91822e48515bf27f4e6b197c6e85e06b90912b3439edc8be1e29514\n", - " Stored in directory: /root/.cache/pip/wheels/01/0c/49/c6f79d8edba8fe89752bf20de2d99040bfa57db0548975c5d5\n", - "Successfully built nglview\n", - "Installing collected packages: wcwidth, ptyprocess, pickleshare, backcall, widgetsnbextension, traitlets, tornado, pyzmq, pygments, psutil, prompt-toolkit, pexpect, parso, nest-asyncio, jupyterlab-widgets, entrypoints, debugpy, matplotlib-inline, jupyter-core, jedi, jupyter-client, ipython, ipykernel, ipywidgets, nglview\n", - "Successfully installed backcall-0.2.0 debugpy-1.6.3 entrypoints-0.4 ipykernel-6.16.0 ipython-7.34.0 ipywidgets-8.0.2 jedi-0.18.1 jupyter-client-7.4.2 jupyter-core-4.11.1 jupyterlab-widgets-3.0.3 matplotlib-inline-0.1.6 nest-asyncio-1.5.6 nglview-3.0.3 parso-0.8.3 pexpect-4.8.0 pickleshare-0.7.5 prompt-toolkit-3.0.31 psutil-5.9.2 ptyprocess-0.7.0 pygments-2.13.0 pyzmq-24.0.1 tornado-6.2 traitlets-5.4.0 wcwidth-0.2.5 widgetsnbextension-4.0.3\n", - "\u001b[33mWARNING: Running pip as the 'root' user can result in broken permissions and conflicting behaviour with the system package manager. It is recommended to use a virtual environment instead: https://pip.pypa.io/warnings/venv\u001b[0m\u001b[33m\n", - "\u001b[0m" - ] - }, - { - "data": { - "application/vnd.colab-display-data+json": { - "pip_warning": { - "packages": [ - "pexpect", - "pickleshare", - "wcwidth" - ] - } + }, + { + "cell_type": "markdown", + "source": [ + "### Install Conda" + ], + "metadata": { + "id": "ff9SxWnaNId9" } - }, - "metadata": {}, - "output_type": "display_data" - } - ], - "source": [ - "!pip install nglview" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "8t8_e_uVLdKB" - }, - "source": [ - "## Create a diffusion model" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "G0rMncVtNSqU" - }, - "source": [ - "### Model class(es)" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "L5FEXz5oXkzt" - }, - "source": [ - "Imports" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "id": "-3-P4w5sXkRU" - }, - "outputs": [], - "source": [ - "# Model adapted from GeoDiff https://github.com/MinkaiXu/GeoDiff\n", - "# Model inspired by https://github.com/DeepGraphLearning/torchdrug/tree/master/torchdrug/models\n", - "from dataclasses import dataclass\n", - "from typing import Callable, Tuple, Union\n", - "\n", - "import numpy as np\n", - "import torch\n", - "import torch.nn.functional as F\n", - "from torch import Tensor, nn\n", - "from torch.nn import Embedding, Linear, Module, ModuleList, Sequential\n", - "from torch_geometric.nn import MessagePassing, radius, radius_graph\n", - "from torch_geometric.typing import Adj, OptPairTensor, OptTensor, Size\n", - "from torch_geometric.utils import dense_to_sparse, to_dense_adj\n", - "from torch_scatter import scatter_add\n", - "from torch_sparse import SparseTensor, coalesce\n", - "\n", - "from diffusers.configuration_utils import ConfigMixin, register_to_config\n", - "from diffusers.modeling_utils import ModelMixin\n", - "from diffusers.utils import BaseOutput\n" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "EzJQXPN_XrMX" - }, - "source": [ - "Helper classes" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "id": "oR1Y56QiLY90" - }, - "outputs": [], - "source": [ - "@dataclass\n", - "class 
MoleculeGNNOutput(BaseOutput):\n", - " \"\"\"\n", - " Args:\n", - " sample (`torch.Tensor` of shape `(batch_size, num_channels, height, width)`):\n", - " Hidden states output. Output of last layer of model.\n", - " \"\"\"\n", - "\n", - " sample: torch.Tensor\n", - "\n", - "\n", - "class MultiLayerPerceptron(nn.Module):\n", - " \"\"\"\n", - " Multi-layer Perceptron. Note there is no activation or dropout in the last layer.\n", - " Args:\n", - " input_dim (int): input dimension\n", - " hidden_dim (list of int): hidden dimensions\n", - " activation (str or function, optional): activation function\n", - " dropout (float, optional): dropout rate\n", - " \"\"\"\n", - "\n", - " def __init__(self, input_dim, hidden_dims, activation=\"relu\", dropout=0):\n", - " super(MultiLayerPerceptron, self).__init__()\n", - "\n", - " self.dims = [input_dim] + hidden_dims\n", - " if isinstance(activation, str):\n", - " self.activation = getattr(F, activation)\n", - " else:\n", - " print(f\"Warning, activation passed {activation} is not string and ignored\")\n", - " self.activation = None\n", - " if dropout > 0:\n", - " self.dropout = nn.Dropout(dropout)\n", - " else:\n", - " self.dropout = None\n", - "\n", - " self.layers = nn.ModuleList()\n", - " for i in range(len(self.dims) - 1):\n", - " self.layers.append(nn.Linear(self.dims[i], self.dims[i + 1]))\n", - "\n", - " def forward(self, x):\n", - " \"\"\"\"\"\"\n", - " for i, layer in enumerate(self.layers):\n", - " x = layer(x)\n", - " if i < len(self.layers) - 1:\n", - " if self.activation:\n", - " x = self.activation(x)\n", - " if self.dropout:\n", - " x = self.dropout(x)\n", - " return x\n", - "\n", - "\n", - "class ShiftedSoftplus(torch.nn.Module):\n", - " def __init__(self):\n", - " super(ShiftedSoftplus, self).__init__()\n", - " self.shift = torch.log(torch.tensor(2.0)).item()\n", - "\n", - " def forward(self, x):\n", - " return F.softplus(x) - self.shift\n", - "\n", - "\n", - "class CFConv(MessagePassing):\n", - " def __init__(self, in_channels, out_channels, num_filters, mlp, cutoff, smooth):\n", - " super(CFConv, self).__init__(aggr=\"add\")\n", - " self.lin1 = Linear(in_channels, num_filters, bias=False)\n", - " self.lin2 = Linear(num_filters, out_channels)\n", - " self.nn = mlp\n", - " self.cutoff = cutoff\n", - " self.smooth = smooth\n", - "\n", - " self.reset_parameters()\n", - "\n", - " def reset_parameters(self):\n", - " torch.nn.init.xavier_uniform_(self.lin1.weight)\n", - " torch.nn.init.xavier_uniform_(self.lin2.weight)\n", - " self.lin2.bias.data.fill_(0)\n", - "\n", - " def forward(self, x, edge_index, edge_length, edge_attr):\n", - " if self.smooth:\n", - " C = 0.5 * (torch.cos(edge_length * np.pi / self.cutoff) + 1.0)\n", - " C = C * (edge_length <= self.cutoff) * (edge_length >= 0.0) # Modification: cutoff\n", - " else:\n", - " C = (edge_length <= self.cutoff).float()\n", - " W = self.nn(edge_attr) * C.view(-1, 1)\n", - "\n", - " x = self.lin1(x)\n", - " x = self.propagate(edge_index, x=x, W=W)\n", - " x = self.lin2(x)\n", - " return x\n", - "\n", - " def message(self, x_j: torch.Tensor, W) -> torch.Tensor:\n", - " return x_j * W\n", - "\n", - "\n", - "class InteractionBlock(torch.nn.Module):\n", - " def __init__(self, hidden_channels, num_gaussians, num_filters, cutoff, smooth):\n", - " super(InteractionBlock, self).__init__()\n", - " mlp = Sequential(\n", - " Linear(num_gaussians, num_filters),\n", - " ShiftedSoftplus(),\n", - " Linear(num_filters, num_filters),\n", - " )\n", - " self.conv = CFConv(hidden_channels, hidden_channels, 
num_filters, mlp, cutoff, smooth)\n", - " self.act = ShiftedSoftplus()\n", - " self.lin = Linear(hidden_channels, hidden_channels)\n", - "\n", - " def forward(self, x, edge_index, edge_length, edge_attr):\n", - " x = self.conv(x, edge_index, edge_length, edge_attr)\n", - " x = self.act(x)\n", - " x = self.lin(x)\n", - " return x\n", - "\n", - "\n", - "class SchNetEncoder(Module):\n", - " def __init__(\n", - " self, hidden_channels=128, num_filters=128, num_interactions=6, edge_channels=100, cutoff=10.0, smooth=False\n", - " ):\n", - " super().__init__()\n", - "\n", - " self.hidden_channels = hidden_channels\n", - " self.num_filters = num_filters\n", - " self.num_interactions = num_interactions\n", - " self.cutoff = cutoff\n", - "\n", - " self.embedding = Embedding(100, hidden_channels, max_norm=10.0)\n", - "\n", - " self.interactions = ModuleList()\n", - " for _ in range(num_interactions):\n", - " block = InteractionBlock(hidden_channels, edge_channels, num_filters, cutoff, smooth)\n", - " self.interactions.append(block)\n", - "\n", - " def forward(self, z, edge_index, edge_length, edge_attr, embed_node=True):\n", - " if embed_node:\n", - " assert z.dim() == 1 and z.dtype == torch.long\n", - " h = self.embedding(z)\n", - " else:\n", - " h = z\n", - " for interaction in self.interactions:\n", - " h = h + interaction(h, edge_index, edge_length, edge_attr)\n", - "\n", - " return h\n", - "\n", - "\n", - "class GINEConv(MessagePassing):\n", - " \"\"\"\n", - " Custom class of the graph isomorphism operator from the \"How Powerful are Graph Neural Networks?\n", - " https://arxiv.org/abs/1810.00826 paper. Note that this implementation has the added option of a custom activation.\n", - " \"\"\"\n", - "\n", - " def __init__(self, mlp: Callable, eps: float = 0.0, train_eps: bool = False, activation=\"softplus\", **kwargs):\n", - " super(GINEConv, self).__init__(aggr=\"add\", **kwargs)\n", - " self.nn = mlp\n", - " self.initial_eps = eps\n", - "\n", - " if isinstance(activation, str):\n", - " self.activation = getattr(F, activation)\n", - " else:\n", - " self.activation = None\n", - "\n", - " if train_eps:\n", - " self.eps = torch.nn.Parameter(torch.Tensor([eps]))\n", - " else:\n", - " self.register_buffer(\"eps\", torch.Tensor([eps]))\n", - "\n", - " def forward(\n", - " self, x: Union[Tensor, OptPairTensor], edge_index: Adj, edge_attr: OptTensor = None, size: Size = None\n", - " ) -> torch.Tensor:\n", - " \"\"\"\"\"\"\n", - " if isinstance(x, torch.Tensor):\n", - " x: OptPairTensor = (x, x)\n", - "\n", - " # Node and edge feature dimensionalites need to match.\n", - " if isinstance(edge_index, torch.Tensor):\n", - " assert edge_attr is not None\n", - " assert x[0].size(-1) == edge_attr.size(-1)\n", - " elif isinstance(edge_index, SparseTensor):\n", - " assert x[0].size(-1) == edge_index.size(-1)\n", - "\n", - " # propagate_type: (x: OptPairTensor, edge_attr: OptTensor)\n", - " out = self.propagate(edge_index, x=x, edge_attr=edge_attr, size=size)\n", - "\n", - " x_r = x[1]\n", - " if x_r is not None:\n", - " out += (1 + self.eps) * x_r\n", - "\n", - " return self.nn(out)\n", - "\n", - " def message(self, x_j: torch.Tensor, edge_attr: torch.Tensor) -> torch.Tensor:\n", - " if self.activation:\n", - " return self.activation(x_j + edge_attr)\n", - " else:\n", - " return x_j + edge_attr\n", - "\n", - " def __repr__(self):\n", - " return \"{}(nn={})\".format(self.__class__.__name__, self.nn)\n", - "\n", - "\n", - "class GINEncoder(torch.nn.Module):\n", - " def __init__(self, hidden_dim, num_convs=3, 
activation=\"relu\", short_cut=True, concat_hidden=False):\n", - " super().__init__()\n", - "\n", - " self.hidden_dim = hidden_dim\n", - " self.num_convs = num_convs\n", - " self.short_cut = short_cut\n", - " self.concat_hidden = concat_hidden\n", - " self.node_emb = nn.Embedding(100, hidden_dim)\n", - "\n", - " if isinstance(activation, str):\n", - " self.activation = getattr(F, activation)\n", - " else:\n", - " self.activation = None\n", - "\n", - " self.convs = nn.ModuleList()\n", - " for i in range(self.num_convs):\n", - " self.convs.append(\n", - " GINEConv(\n", - " MultiLayerPerceptron(hidden_dim, [hidden_dim, hidden_dim], activation=activation),\n", - " activation=activation,\n", - " )\n", - " )\n", - "\n", - " def forward(self, z, edge_index, edge_attr):\n", - " \"\"\"\n", - " Input:\n", - " data: (torch_geometric.data.Data): batched graph edge_index: bond indices of the original graph (num_node,\n", - " hidden) edge_attr: edge feature tensor with shape (num_edge, hidden)\n", - " Output:\n", - " node_feature: graph feature\n", - " \"\"\"\n", - "\n", - " node_attr = self.node_emb(z) # (num_node, hidden)\n", - "\n", - " hiddens = []\n", - " conv_input = node_attr # (num_node, hidden)\n", - "\n", - " for conv_idx, conv in enumerate(self.convs):\n", - " hidden = conv(conv_input, edge_index, edge_attr)\n", - " if conv_idx < len(self.convs) - 1 and self.activation is not None:\n", - " hidden = self.activation(hidden)\n", - " assert hidden.shape == conv_input.shape\n", - " if self.short_cut and hidden.shape == conv_input.shape:\n", - " hidden += conv_input\n", - "\n", - " hiddens.append(hidden)\n", - " conv_input = hidden\n", - "\n", - " if self.concat_hidden:\n", - " node_feature = torch.cat(hiddens, dim=-1)\n", - " else:\n", - " node_feature = hiddens[-1]\n", - "\n", - " return node_feature\n", - "\n", - "\n", - "class MLPEdgeEncoder(Module):\n", - " def __init__(self, hidden_dim=100, activation=\"relu\"):\n", - " super().__init__()\n", - " self.hidden_dim = hidden_dim\n", - " self.bond_emb = Embedding(100, embedding_dim=self.hidden_dim)\n", - " self.mlp = MultiLayerPerceptron(1, [self.hidden_dim, self.hidden_dim], activation=activation)\n", - "\n", - " @property\n", - " def out_channels(self):\n", - " return self.hidden_dim\n", - "\n", - " def forward(self, edge_length, edge_type):\n", - " \"\"\"\n", - " Input:\n", - " edge_length: The length of edges, shape=(E, 1). edge_type: The type pf edges, shape=(E,)\n", - " Returns:\n", - " edge_attr: The representation of edges. (E, 2 * num_gaussians)\n", - " \"\"\"\n", - " d_emb = self.mlp(edge_length) # (num_edge, hidden_dim)\n", - " edge_attr = self.bond_emb(edge_type) # (num_edge, hidden_dim)\n", - " return d_emb * edge_attr # (num_edge, hidden)\n", - "\n", - "\n", - "def assemble_atom_pair_feature(node_attr, edge_index, edge_attr):\n", - " h_row, h_col = node_attr[edge_index[0]], node_attr[edge_index[1]]\n", - " h_pair = torch.cat([h_row * h_col, edge_attr], dim=-1) # (E, 2H)\n", - " return h_pair\n", - "\n", - "\n", - "def _extend_graph_order(num_nodes, edge_index, edge_type, order=3):\n", - " \"\"\"\n", - " Args:\n", - " num_nodes: Number of atoms.\n", - " edge_index: Bond indices of the original graph.\n", - " edge_type: Bond types of the original graph.\n", - " order: Extension order.\n", - " Returns:\n", - " new_edge_index: Extended edge indices. 
new_edge_type: Extended edge types.\n", - " \"\"\"\n", - "\n", - " def binarize(x):\n", - " return torch.where(x > 0, torch.ones_like(x), torch.zeros_like(x))\n", - "\n", - " def get_higher_order_adj_matrix(adj, order):\n", - " \"\"\"\n", - " Args:\n", - " adj: (N, N)\n", - " type_mat: (N, N)\n", - " Returns:\n", - " Following attributes will be updated:\n", - " - edge_index\n", - " - edge_type\n", - " Following attributes will be added to the data object:\n", - " - bond_edge_index: Original edge_index.\n", - " \"\"\"\n", - " adj_mats = [\n", - " torch.eye(adj.size(0), dtype=torch.long, device=adj.device),\n", - " binarize(adj + torch.eye(adj.size(0), dtype=torch.long, device=adj.device)),\n", - " ]\n", - "\n", - " for i in range(2, order + 1):\n", - " adj_mats.append(binarize(adj_mats[i - 1] @ adj_mats[1]))\n", - " order_mat = torch.zeros_like(adj)\n", - "\n", - " for i in range(1, order + 1):\n", - " order_mat += (adj_mats[i] - adj_mats[i - 1]) * i\n", - "\n", - " return order_mat\n", - "\n", - " num_types = 22\n", - " # given from len(BOND_TYPES), where BOND_TYPES = {t: i for i, t in enumerate(BT.names.values())}\n", - " # from rdkit.Chem.rdchem import BondType as BT\n", - " N = num_nodes\n", - " adj = to_dense_adj(edge_index).squeeze(0)\n", - " adj_order = get_higher_order_adj_matrix(adj, order) # (N, N)\n", - "\n", - " type_mat = to_dense_adj(edge_index, edge_attr=edge_type).squeeze(0) # (N, N)\n", - " type_highorder = torch.where(adj_order > 1, num_types + adj_order - 1, torch.zeros_like(adj_order))\n", - " assert (type_mat * type_highorder == 0).all()\n", - " type_new = type_mat + type_highorder\n", - "\n", - " new_edge_index, new_edge_type = dense_to_sparse(type_new)\n", - " _, edge_order = dense_to_sparse(adj_order)\n", - "\n", - " # data.bond_edge_index = data.edge_index # Save original edges\n", - " new_edge_index, new_edge_type = coalesce(new_edge_index, new_edge_type.long(), N, N) # modify data\n", - "\n", - " return new_edge_index, new_edge_type\n", - "\n", - "\n", - "def _extend_to_radius_graph(pos, edge_index, edge_type, cutoff, batch, unspecified_type_number=0, is_sidechain=None):\n", - " assert edge_type.dim() == 1\n", - " N = pos.size(0)\n", - "\n", - " bgraph_adj = torch.sparse.LongTensor(edge_index, edge_type, torch.Size([N, N]))\n", - "\n", - " if is_sidechain is None:\n", - " rgraph_edge_index = radius_graph(pos, r=cutoff, batch=batch) # (2, E_r)\n", - " else:\n", - " # fetch sidechain and its batch index\n", - " is_sidechain = is_sidechain.bool()\n", - " dummy_index = torch.arange(pos.size(0), device=pos.device)\n", - " sidechain_pos = pos[is_sidechain]\n", - " sidechain_index = dummy_index[is_sidechain]\n", - " sidechain_batch = batch[is_sidechain]\n", - "\n", - " assign_index = radius(x=pos, y=sidechain_pos, r=cutoff, batch_x=batch, batch_y=sidechain_batch)\n", - " r_edge_index_x = assign_index[1]\n", - " r_edge_index_y = assign_index[0]\n", - " r_edge_index_y = sidechain_index[r_edge_index_y]\n", - "\n", - " rgraph_edge_index1 = torch.stack((r_edge_index_x, r_edge_index_y)) # (2, E)\n", - " rgraph_edge_index2 = torch.stack((r_edge_index_y, r_edge_index_x)) # (2, E)\n", - " rgraph_edge_index = torch.cat((rgraph_edge_index1, rgraph_edge_index2), dim=-1) # (2, 2E)\n", - " # delete self loop\n", - " rgraph_edge_index = rgraph_edge_index[:, (rgraph_edge_index[0] != rgraph_edge_index[1])]\n", - "\n", - " rgraph_adj = torch.sparse.LongTensor(\n", - " rgraph_edge_index,\n", - " torch.ones(rgraph_edge_index.size(1)).long().to(pos.device) * unspecified_type_number,\n", - " 
torch.Size([N, N]),\n", - " )\n", - "\n", - " composed_adj = (bgraph_adj + rgraph_adj).coalesce() # Sparse (N, N, T)\n", - "\n", - " new_edge_index = composed_adj.indices()\n", - " new_edge_type = composed_adj.values().long()\n", - "\n", - " return new_edge_index, new_edge_type\n", - "\n", - "\n", - "def extend_graph_order_radius(\n", - " num_nodes,\n", - " pos,\n", - " edge_index,\n", - " edge_type,\n", - " batch,\n", - " order=3,\n", - " cutoff=10.0,\n", - " extend_order=True,\n", - " extend_radius=True,\n", - " is_sidechain=None,\n", - "):\n", - " if extend_order:\n", - " edge_index, edge_type = _extend_graph_order(\n", - " num_nodes=num_nodes, edge_index=edge_index, edge_type=edge_type, order=order\n", - " )\n", - "\n", - " if extend_radius:\n", - " edge_index, edge_type = _extend_to_radius_graph(\n", - " pos=pos, edge_index=edge_index, edge_type=edge_type, cutoff=cutoff, batch=batch, is_sidechain=is_sidechain\n", - " )\n", - "\n", - " return edge_index, edge_type\n", - "\n", - "\n", - "def get_distance(pos, edge_index):\n", - " return (pos[edge_index[0]] - pos[edge_index[1]]).norm(dim=-1)\n", - "\n", - "\n", - "def graph_field_network(score_d, pos, edge_index, edge_length):\n", - " \"\"\"\n", - " Transformation to make the epsilon predicted from the diffusion model roto-translational equivariant. See equations\n", - " 5-7 of the GeoDiff Paper https://arxiv.org/pdf/2203.02923.pdf\n", - " \"\"\"\n", - " N = pos.size(0)\n", - " dd_dr = (1.0 / edge_length) * (pos[edge_index[0]] - pos[edge_index[1]]) # (E, 3)\n", - " score_pos = scatter_add(dd_dr * score_d, edge_index[0], dim=0, dim_size=N) + scatter_add(\n", - " -dd_dr * score_d, edge_index[1], dim=0, dim_size=N\n", - " ) # (N, 3)\n", - " return score_pos\n", - "\n", - "\n", - "def clip_norm(vec, limit, p=2):\n", - " norm = torch.norm(vec, dim=-1, p=2, keepdim=True)\n", - " denom = torch.where(norm > limit, limit / norm, torch.ones_like(norm))\n", - " return vec * denom\n", - "\n", - "\n", - "def is_local_edge(edge_type):\n", - " return edge_type > 0\n" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "QWrHJFcYXyUB" - }, - "source": [ - "Main model class!" 
- ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "id": "MCeZA1qQXzoK" - }, - "outputs": [], - "source": [ - "class MoleculeGNN(ModelMixin, ConfigMixin):\n", - " @register_to_config\n", - " def __init__(\n", - " self,\n", - " hidden_dim=128,\n", - " num_convs=6,\n", - " num_convs_local=4,\n", - " cutoff=10.0,\n", - " mlp_act=\"relu\",\n", - " edge_order=3,\n", - " edge_encoder=\"mlp\",\n", - " smooth_conv=True,\n", - " ):\n", - " super().__init__()\n", - " self.cutoff = cutoff\n", - " self.edge_encoder = edge_encoder\n", - " self.edge_order = edge_order\n", - "\n", - " \"\"\"\n", - " edge_encoder: Takes both edge type and edge length as input and outputs a vector [Note]: node embedding is done\n", - " in SchNetEncoder\n", - " \"\"\"\n", - " self.edge_encoder_global = MLPEdgeEncoder(hidden_dim, mlp_act) # get_edge_encoder(config)\n", - " self.edge_encoder_local = MLPEdgeEncoder(hidden_dim, mlp_act) # get_edge_encoder(config)\n", - "\n", - " \"\"\"\n", - " The graph neural network that extracts node-wise features.\n", - " \"\"\"\n", - " self.encoder_global = SchNetEncoder(\n", - " hidden_channels=hidden_dim,\n", - " num_filters=hidden_dim,\n", - " num_interactions=num_convs,\n", - " edge_channels=self.edge_encoder_global.out_channels,\n", - " cutoff=cutoff,\n", - " smooth=smooth_conv,\n", - " )\n", - " self.encoder_local = GINEncoder(\n", - " hidden_dim=hidden_dim,\n", - " num_convs=num_convs_local,\n", - " )\n", - "\n", - " \"\"\"\n", - " `output_mlp` takes a mixture of two nodewise features and edge features as input and outputs\n", - " gradients w.r.t. edge_length (out_dim = 1).\n", - " \"\"\"\n", - " self.grad_global_dist_mlp = MultiLayerPerceptron(\n", - " 2 * hidden_dim, [hidden_dim, hidden_dim // 2, 1], activation=mlp_act\n", - " )\n", - "\n", - " self.grad_local_dist_mlp = MultiLayerPerceptron(\n", - " 2 * hidden_dim, [hidden_dim, hidden_dim // 2, 1], activation=mlp_act\n", - " )\n", - "\n", - " \"\"\"\n", - " Incorporate parameters together\n", - " \"\"\"\n", - " self.model_global = nn.ModuleList([self.edge_encoder_global, self.encoder_global, self.grad_global_dist_mlp])\n", - " self.model_local = nn.ModuleList([self.edge_encoder_local, self.encoder_local, self.grad_local_dist_mlp])\n", - "\n", - " def _forward(\n", - " self,\n", - " atom_type,\n", - " pos,\n", - " bond_index,\n", - " bond_type,\n", - " batch,\n", - " time_step, # NOTE, model trained without timestep performed best\n", - " edge_index=None,\n", - " edge_type=None,\n", - " edge_length=None,\n", - " return_edges=False,\n", - " extend_order=True,\n", - " extend_radius=True,\n", - " is_sidechain=None,\n", - " ):\n", - " \"\"\"\n", - " Args:\n", - " atom_type: Types of atoms, (N, ).\n", - " bond_index: Indices of bonds (not extended, not radius-graph), (2, E).\n", - " bond_type: Bond types, (E, ).\n", - " batch: Node index to graph index, (N, ).\n", - " \"\"\"\n", - " N = atom_type.size(0)\n", - " if edge_index is None or edge_type is None or edge_length is None:\n", - " edge_index, edge_type = extend_graph_order_radius(\n", - " num_nodes=N,\n", - " pos=pos,\n", - " edge_index=bond_index,\n", - " edge_type=bond_type,\n", - " batch=batch,\n", - " order=self.edge_order,\n", - " cutoff=self.cutoff,\n", - " extend_order=extend_order,\n", - " extend_radius=extend_radius,\n", - " is_sidechain=is_sidechain,\n", - " )\n", - " edge_length = get_distance(pos, edge_index).unsqueeze(-1) # (E, 1)\n", - " local_edge_mask = is_local_edge(edge_type) # (E, )\n", - "\n", - " # with the parameterization of 
NCSNv2\n", - " # DDPM loss implicit handle the noise variance scale conditioning\n", - " sigma_edge = torch.ones(size=(edge_index.size(1), 1), device=pos.device) # (E, 1)\n", - "\n", - " # Encoding global\n", - " edge_attr_global = self.edge_encoder_global(edge_length=edge_length, edge_type=edge_type) # Embed edges\n", - "\n", - " # Global\n", - " node_attr_global = self.encoder_global(\n", - " z=atom_type,\n", - " edge_index=edge_index,\n", - " edge_length=edge_length,\n", - " edge_attr=edge_attr_global,\n", - " )\n", - " # Assemble pairwise features\n", - " h_pair_global = assemble_atom_pair_feature(\n", - " node_attr=node_attr_global,\n", - " edge_index=edge_index,\n", - " edge_attr=edge_attr_global,\n", - " ) # (E_global, 2H)\n", - " # Invariant features of edges (radius graph, global)\n", - " edge_inv_global = self.grad_global_dist_mlp(h_pair_global) * (1.0 / sigma_edge) # (E_global, 1)\n", - "\n", - " # Encoding local\n", - " edge_attr_local = self.edge_encoder_global(edge_length=edge_length, edge_type=edge_type) # Embed edges\n", - " # edge_attr += temb_edge\n", - "\n", - " # Local\n", - " node_attr_local = self.encoder_local(\n", - " z=atom_type,\n", - " edge_index=edge_index[:, local_edge_mask],\n", - " edge_attr=edge_attr_local[local_edge_mask],\n", - " )\n", - " # Assemble pairwise features\n", - " h_pair_local = assemble_atom_pair_feature(\n", - " node_attr=node_attr_local,\n", - " edge_index=edge_index[:, local_edge_mask],\n", - " edge_attr=edge_attr_local[local_edge_mask],\n", - " ) # (E_local, 2H)\n", - "\n", - " # Invariant features of edges (bond graph, local)\n", - " if isinstance(sigma_edge, torch.Tensor):\n", - " edge_inv_local = self.grad_local_dist_mlp(h_pair_local) * (\n", - " 1.0 / sigma_edge[local_edge_mask]\n", - " ) # (E_local, 1)\n", - " else:\n", - " edge_inv_local = self.grad_local_dist_mlp(h_pair_local) * (1.0 / sigma_edge) # (E_local, 1)\n", - "\n", - " if return_edges:\n", - " return edge_inv_global, edge_inv_local, edge_index, edge_type, edge_length, local_edge_mask\n", - " else:\n", - " return edge_inv_global, edge_inv_local\n", - "\n", - " def forward(\n", - " self,\n", - " sample,\n", - " timestep: Union[torch.Tensor, float, int],\n", - " return_dict: bool = True,\n", - " sigma=1.0,\n", - " global_start_sigma=0.5,\n", - " w_global=1.0,\n", - " extend_order=False,\n", - " extend_radius=True,\n", - " clip_local=None,\n", - " clip_global=1000.0,\n", - " ) -> Union[MoleculeGNNOutput, Tuple]:\n", - " r\"\"\"\n", - " Args:\n", - " sample: packed torch geometric object\n", - " timestep (`torch.Tensor` or `float` or `int): TODO verify type and shape (batch) timesteps\n", - " return_dict (`bool`, *optional*, defaults to `True`):\n", - " Whether or not to return a [`~models.molecule_gnn.MoleculeGNNOutput`] instead of a plain tuple.\n", - " Returns:\n", - " [`~models.molecule_gnn.MoleculeGNNOutput`] or `tuple`: [`~models.molecule_gnn.MoleculeGNNOutput`] if\n", - " `return_dict` is True, otherwise a `tuple`. 
When returning a tuple, the first element is the sample tensor.\n", - " \"\"\"\n", - "\n", - " # unpack sample\n", - " atom_type = sample.atom_type\n", - " bond_index = sample.edge_index\n", - " bond_type = sample.edge_type\n", - " num_graphs = sample.num_graphs\n", - " pos = sample.pos\n", - "\n", - " timesteps = torch.full(size=(num_graphs,), fill_value=timestep, dtype=torch.long, device=pos.device)\n", - "\n", - " edge_inv_global, edge_inv_local, edge_index, edge_type, edge_length, local_edge_mask = self._forward(\n", - " atom_type=atom_type,\n", - " pos=sample.pos,\n", - " bond_index=bond_index,\n", - " bond_type=bond_type,\n", - " batch=sample.batch,\n", - " time_step=timesteps,\n", - " return_edges=True,\n", - " extend_order=extend_order,\n", - " extend_radius=extend_radius,\n", - " ) # (E_global, 1), (E_local, 1)\n", - "\n", - " # Important equation in the paper for equivariant features - eqns 5-7 of GeoDiff\n", - " node_eq_local = graph_field_network(\n", - " edge_inv_local, pos, edge_index[:, local_edge_mask], edge_length[local_edge_mask]\n", - " )\n", - " if clip_local is not None:\n", - " node_eq_local = clip_norm(node_eq_local, limit=clip_local)\n", - "\n", - " # Global\n", - " if sigma < global_start_sigma:\n", - " edge_inv_global = edge_inv_global * (1 - local_edge_mask.view(-1, 1).float())\n", - " node_eq_global = graph_field_network(edge_inv_global, pos, edge_index, edge_length)\n", - " node_eq_global = clip_norm(node_eq_global, limit=clip_global)\n", - " else:\n", - " node_eq_global = 0\n", - "\n", - " # Sum\n", - " eps_pos = node_eq_local + node_eq_global * w_global\n", - "\n", - " if not return_dict:\n", - " return (-eps_pos,)\n", - "\n", - " return MoleculeGNNOutput(sample=torch.Tensor(-eps_pos).to(pos.device))" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "CCIrPYSJj9wd" - }, - "source": [ - "### Load pretrained model" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "YdrAr6Ch--Ab" - }, - "source": [ - "#### Load a model\n", - "The model used is a design an\n", - "equivariant convolutional layer, named graph field network (GFN).\n", - "\n", - "The warning about `betas` and `alphas` can be ignored, those were moved to the scheduler." 
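For review context, the "Load a model" step above pulls the pretrained GFN weights from the Hub into the `MoleculeGNN` class defined earlier in this notebook. A minimal sketch of that step is given below; the checkpoint id and every scheduler hyperparameter are illustrative assumptions, not values taken from this patch.

```python
# Sketch only: MoleculeGNN inherits ModelMixin, so from_pretrained() works on it.
# The repo id and all scheduler arguments below are assumptions for illustration.
from diffusers import DDPMScheduler

model = MoleculeGNN.from_pretrained("fusing/gfn-molecule-gen-drugs")  # assumed checkpoint name
model = model.to("cuda").eval()

# The `betas`/`alphas` warning mentioned above is expected: those buffers now live in the scheduler.
scheduler = DDPMScheduler(
    num_train_timesteps=1000,  # assumed; must match how the checkpoint was trained
    beta_schedule="sigmoid",   # assumed
    beta_start=1e-7,
    beta_end=2e-3,
    clip_sample=False,
)
```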
- ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "colab": { - "base_uri": "https://localhost:8080/", - "height": 172, - "referenced_widgets": [ - "d90f304e9560472eacfbdd11e46765eb", - "1c6246f15b654f4daa11c9bcf997b78c", - "c2321b3bff6f490ca12040a20308f555", - "b7feb522161f4cf4b7cc7c1a078ff12d", - "e2d368556e494ae7ae4e2e992af2cd4f", - "bbef741e76ec41b7ab7187b487a383df", - "561f742d418d4721b0670cc8dd62e22c", - "872915dd1bb84f538c44e26badabafdd", - "d022575f1fa2446d891650897f187b4d", - "fdc393f3468c432aa0ada05e238a5436", - "2c9362906e4b40189f16d14aa9a348da", - "6010fc8daa7a44d5aec4b830ec2ebaa1", - "7e0bb1b8d65249d3974200686b193be2", - "ba98aa6d6a884e4ab8bbb5dfb5e4cf7a", - "6526646be5ed415c84d1245b040e629b", - "24d31fc3576e43dd9f8301d2ef3a37ab", - "2918bfaadc8d4b1a9832522c40dfefb8", - "a4bfdca35cc54dae8812720f1b276a08", - "e4901541199b45c6a18824627692fc39", - "f915cf874246446595206221e900b2fe", - "a9e388f22a9742aaaf538e22575c9433", - "42f6c3db29d7484ba6b4f73590abd2f4" - ] - }, - "id": "DyCo0nsqjbml", - "outputId": "d6bce9d5-c51e-43a4-e680-e1e81bdfaf45" - }, - "outputs": [ - { - "data": { - "application/vnd.jupyter.widget-view+json": { - "model_id": "d90f304e9560472eacfbdd11e46765eb", - "version_major": 2, - "version_minor": 0 + }, + { + "cell_type": "markdown", + "metadata": { + "id": "1g_6zOabItDk" }, - "text/plain": [ - "Downloading: 0%| | 0.00/3.27M [00:00] 124.78K 180KB/s in 0.7s \n", - "\n", - "2022-10-12 18:32:20 (180 KB/s) - ‘molecules.pkl’ saved [127774/127774]\n", - "\n" - ] - } - ], - "source": [ - "import torch\n", - "\n", - "\n", - "!wget https://huggingface.co/datasets/fusing/geodiff-example-data/resolve/main/data/molecules.pkl\n", - "dataset = torch.load('/content/molecules.pkl')" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "QZcmy1EvKQRk" - }, - "source": [ - "Print out one entry of the dataset, it contains molecular formulas, atom types, positions, and more." - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "colab": { - "base_uri": "https://localhost:8080/" }, - "id": "JVjz6iH_H6Eh", - "outputId": "898cb0cf-a0b3-411b-fd4c-bea1fbfd17fe" - }, - "outputs": [ { - "data": { - "text/plain": [ - "Data(atom_type=[51], bond_edge_index=[2, 108], edge_index=[2, 598], edge_order=[598], edge_type=[598], idx=[1], is_bond=[598], num_nodes_per_graph=[1], num_pos_ref=[1], nx=, pos=[51, 3], pos_ref=[255, 3], rdmol=, smiles=\"CC1CCCN(C(=O)C2CCN(S(=O)(=O)c3cccc4nonc34)CC2)C1\")" + "cell_type": "markdown", + "metadata": { + "id": "VfthW90vI0nw" + }, + "source": [ + "Install Conda for some more complex dependencies for geometric networks." 
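The Conda bootstrap that the added install cells perform can be summarized in a short sketch (Colab-only; it mirrors the `condacolab` calls in the cells that follow, and `install()` restarts the Colab kernel, so later cells run in a fresh session):

```python
# Requires `pip install -q condacolab` first (the next added cell runs it).
import condacolab

condacolab.install()  # installs a Miniconda-based Python into the Colab runtime and restarts the kernel
# After the restart, condacolab.check() can be run to confirm the environment is active
# (it prints the "Everything looks OK!" message seen in the added cell output below).
```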
] - }, - "execution_count": 20, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "dataset[0]" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "vHNiZAUxNgoy" - }, - "source": [ - "## Run the diffusion process" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "jZ1KZrxKqENg" - }, - "source": [ - "#### Helper Functions" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "id": "s240tYueqKKf" - }, - "outputs": [], - "source": [ - "import copy\n", - "import os\n", - "\n", - "from torch_geometric.data import Batch, Data\n", - "from torch_scatter import scatter_mean\n", - "from tqdm import tqdm\n", - "\n", - "\n", - "def repeat_data(data: Data, num_repeat) -> Batch:\n", - " datas = [copy.deepcopy(data) for i in range(num_repeat)]\n", - " return Batch.from_data_list(datas)\n", - "\n", - "def repeat_batch(batch: Batch, num_repeat) -> Batch:\n", - " datas = batch.to_data_list()\n", - " new_data = []\n", - " for i in range(num_repeat):\n", - " new_data += copy.deepcopy(datas)\n", - " return Batch.from_data_list(new_data)" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "AMnQTk0eqT7Z" - }, - "source": [ - "#### Constants" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "id": "WYGkzqgzrHmF" - }, - "outputs": [], - "source": [ - "num_samples = 1 # solutions per molecule\n", - "num_molecules = 3\n", - "\n", - "DEVICE = 'cuda'\n", - "sampling_type = 'ddpm_noisy' #'' # paper also uses \"generalize\" and \"ld\"\n", - "# constants for inference\n", - "w_global = 0.5 #0,.3 for qm9\n", - "global_start_sigma = 0.5\n", - "eta = 1.0\n", - "clip_local = None\n", - "clip_pos = None\n", - "\n", - "# constands for data handling\n", - "save_traj = False\n", - "save_data = False\n", - "output_dir = '/content/'" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "-xD5bJ3SqM7t" - }, - "source": [ - "#### Generate samples!\n", - "Note that the 3d representation of a molecule is referred to as the **conformation**" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "colab": { - "base_uri": "https://localhost:8080/" - }, - "id": "x9xuLUNg26z1", - "outputId": "236d2a60-09ed-4c4d-97c1-6e3c0f2d26c4" - }, - "outputs": [ - { - "name": "stderr", - "output_type": "stream", - "text": [ - "/usr/local/lib/python3.7/dist-packages/ipykernel_launcher.py:4: UserWarning: To copy construct from a tensor, it is recommended to use sourceTensor.clone().detach() or sourceTensor.clone().detach().requires_grad_(True), rather than torch.tensor(sourceTensor).\n", - " after removing the cwd from sys.path.\n", - "100%|██████████| 5/5 [00:55<00:00, 11.06s/it]\n" - ] - } - ], - "source": [ - "results = []\n", - "\n", - "# define sigmas\n", - "sigmas = torch.tensor(1.0 - scheduler.alphas_cumprod).sqrt() / torch.tensor(scheduler.alphas_cumprod).sqrt()\n", - "sigmas = sigmas.to(DEVICE)\n", - "\n", - "for count, data in enumerate(tqdm(dataset)):\n", - " num_samples = max(data.pos_ref.size(0) // data.num_nodes, 1)\n", - "\n", - " data_input = data.clone()\n", - " data_input['pos_ref'] = None\n", - " batch = repeat_data(data_input, num_samples).to(DEVICE)\n", - "\n", - " # initial configuration\n", - " pos_init = torch.randn(batch.num_nodes, 3).to(DEVICE)\n", - "\n", - " # for logging animation of denoising\n", - " pos_traj = []\n", - " with torch.no_grad():\n", - "\n", - " # scale initial sample\n", - " pos = pos_init * sigmas[-1]\n", - " for t in scheduler.timesteps:\n", - " 
batch.pos = pos\n", - "\n", - " # generate geometry with model, then filter it\n", - " epsilon = model.forward(batch, t, sigma=sigmas[t], return_dict=False)[0]\n", - "\n", - " # Update\n", - " reconstructed_pos = scheduler.step(epsilon, t, pos)[\"prev_sample\"].to(DEVICE)\n", - "\n", - " pos = reconstructed_pos\n", - "\n", - " if torch.isnan(pos).any():\n", - " print(\"NaN detected. Please restart.\")\n", - " raise FloatingPointError()\n", - "\n", - " # recenter graph of positions for next iteration\n", - " pos = pos - scatter_mean(pos, batch.batch, dim=0)[batch.batch]\n", - "\n", - " # optional clipping\n", - " if clip_pos is not None:\n", - " pos = torch.clamp(pos, min=-clip_pos, max=clip_pos)\n", - " pos_traj.append(pos.clone().cpu())\n", - "\n", - " pos_gen = pos.cpu()\n", - " if save_traj:\n", - " pos_gen_traj = pos_traj.cpu()\n", - " data.pos_gen = torch.stack(pos_gen_traj)\n", - " else:\n", - " data.pos_gen = pos_gen\n", - " results.append(data)\n", - "\n", - "\n", - "if save_data:\n", - " save_path = os.path.join(output_dir, 'samples_all.pkl')\n", - "\n", - " with open(save_path, 'wb') as f:\n", - " pickle.dump(results, f)" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "fSApwSaZNndW" - }, - "source": [ - "## Render the results!" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "d47Zxo2OKdgZ" - }, - "source": [ - "This function allows us to render 3d in colab." - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "id": "e9Cd0kCAv9b8" - }, - "outputs": [], - "source": [ - "from google.colab import output\n", - "\n", - "\n", - "output.enable_custom_widget_manager()" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "RjaVuR15NqzF" - }, - "source": [ - "### Helper functions" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "28rBYa9NKhlz" - }, - "source": [ - "Here is a helper function for copying the generated tensors into a format used by RDKit & NGLViewer." - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "id": "LKdKdwxcyTQ6" - }, - "outputs": [], - "source": [ - "from copy import deepcopy\n", - "\n", - "\n", - "def set_rdmol_positions(rdkit_mol, pos):\n", - " \"\"\"\n", - " Args:\n", - " rdkit_mol: An `rdkit.Chem.rdchem.Mol` object.\n", - " pos: (N_atoms, 3)\n", - " \"\"\"\n", - " mol = deepcopy(rdkit_mol)\n", - " set_rdmol_positions_(mol, pos)\n", - " return mol\n", - "\n", - "def set_rdmol_positions_(mol, pos):\n", - " \"\"\"\n", - " Args:\n", - " rdkit_mol: An `rdkit.Chem.rdchem.Mol` object.\n", - " pos: (N_atoms, 3)\n", - " \"\"\"\n", - " for i in range(pos.shape[0]):\n", - " mol.GetConformer(0).SetAtomPosition(i, pos[i].tolist())\n", - " return mol\n" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "NuE10hcpKmzK" - }, - "source": [ - "Process the generated data to make it easy to view." 
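The `sigmas` tensor defined at the top of the sampling loop above is just the noise-to-signal ratio of the DDPM schedule,

$$\sigma_t \;=\; \sqrt{\frac{1-\bar{\alpha}_t}{\bar{\alpha}_t}},$$

where $\bar{\alpha}_t$ is `scheduler.alphas_cumprod`. This is why the initial Gaussian sample is scaled by `sigmas[-1]` (the largest noise level) before the reverse process starts, and why the `sigma < global_start_sigma` check in `forward` gates when the global, radius-graph term is added back into the position update.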
- ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "colab": { - "base_uri": "https://localhost:8080/" - }, - "id": "KieVE1vc0_Vs", - "outputId": "6faa185d-b1bc-47e8-be18-30d1e557e7c8" - }, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "collect 5 generated molecules in `mols`\n" - ] - } - ], - "source": [ - "# the model can generate multiple conformations per 2d geometry\n", - "num_gen = results[0]['pos_gen'].shape[0]\n", - "\n", - "# init storage objects\n", - "mols_gen = []\n", - "mols_orig = []\n", - "for to_process in results:\n", - "\n", - " # store the reference 3d position\n", - " to_process['pos_ref'] = to_process['pos_ref'].reshape(-1, to_process['rdmol'].GetNumAtoms(), 3)\n", - "\n", - " # store the generated 3d position\n", - " to_process['pos_gen'] = to_process['pos_gen'].reshape(-1, to_process['rdmol'].GetNumAtoms(), 3)\n", - "\n", - " # copy data to new object\n", - " new_mol = set_rdmol_positions(to_process.rdmol, to_process['pos_gen'][0])\n", - "\n", - " # append results\n", - " mols_gen.append(new_mol)\n", - " mols_orig.append(to_process.rdmol)\n", - "\n", - "print(f\"collect {len(mols_gen)} generated molecules in `mols`\")" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "tin89JwMKp4v" - }, - "source": [ - "Import tools to visualize the 2d chemical diagram of the molecule." - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "id": "yqV6gllSZn38" - }, - "outputs": [], - "source": [ - "from IPython.display import SVG, display\n", - "from rdkit import Chem\n", - "from rdkit.Chem.Draw import rdMolDraw2D as MD2" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "TFNKmGddVoOk" - }, - "source": [ - "Select molecule to visualize" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "id": "KzuwLlrrVaGc" - }, - "outputs": [], - "source": [ - "idx = 0\n", - "assert idx < len(results), \"selected molecule that was not generated\"" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "hkb8w0_SNtU8" - }, - "source": [ - "### Viewing" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "I3R4QBQeKttN" - }, - "source": [ - "This 2D rendering is the equivalent of the **input to the model**!" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "colab": { - "base_uri": "https://localhost:8080/", - "height": 321 - }, - "id": "gkQRWjraaKex", - "outputId": "9c3d1a91-a51d-475d-9e34-2be2459abc47" - }, - "outputs": [ - { - "data": { - "image/svg+xml": "\n\n \n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n", - "text/plain": [ - "" + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "2WNFzSnbiE0k", + "colab": { + "base_uri": "https://localhost:8080/" + }, + "outputId": "690d0d4d-9d0a-4ead-c6dc-086f113f532f" + }, + "outputs": [ + { + "output_type": "stream", + "name": "stdout", + "text": [ + "\u001b[33mWARNING: Running pip as the 'root' user can result in broken permissions and conflicting behaviour with the system package manager. 
It is recommended to use a virtual environment instead: https://pip.pypa.io/warnings/venv\u001b[0m\u001b[33m\n", + "\u001b[0m" + ] + } + ], + "source": [ + "!pip install -q condacolab" ] - }, - "metadata": {}, - "output_type": "display_data" - } - ], - "source": [ - "mc = Chem.MolFromSmiles(dataset[0]['smiles'])\n", - "molSize=(450,300)\n", - "drawer = MD2.MolDraw2DSVG(molSize[0],molSize[1])\n", - "drawer.DrawMolecule(mc)\n", - "drawer.FinishDrawing()\n", - "svg = drawer.GetDrawingText()\n", - "display(SVG(svg.replace('svg:','')))" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "z4FDMYMxKw2I" - }, - "source": [ - "Generate the 3d molecule!" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "colab": { - "base_uri": "https://localhost:8080/", - "height": 17, - "referenced_widgets": [ - "695ab5bbf30a4ab19df1f9f33469f314", - "eac6a8dcdc9d4335a2e51031793ead29" - ] - }, - "id": "aT1Bkb8YxJfV", - "outputId": "b98870ae-049d-4386-b676-166e9526bda2" - }, - "outputs": [ - { - "data": { - "application/vnd.jupyter.widget-view+json": { - "model_id": "695ab5bbf30a4ab19df1f9f33469f314", - "version_major": 2, - "version_minor": 0 + }, + { + "cell_type": "markdown", + "metadata": { + "id": "NUsbWYCUI7Km" }, - "text/plain": [] - }, - "metadata": { - "application/vnd.jupyter.widget-view+json": { - "colab": { - "custom_widget_manager": { - "url": "https://ssl.gstatic.com/colaboratory-static/widgets/colab-cdn-widget-manager/d2e234f7cc04bf79/manager.min.js" + "source": [ + "Setup Conda" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/" + }, + "id": "FZelreINdmd0", + "outputId": "635f0cb8-0af4-499f-e0a4-b3790cb12e9f" + }, + "outputs": [ + { + "output_type": "stream", + "name": "stdout", + "text": [ + "✨🍰✨ Everything looks OK!\n" + ] } - } - } - }, - "output_type": "display_data" - } - ], - "source": [ - "from nglview import show_rdkit as show" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "colab": { - "base_uri": "https://localhost:8080/", - "height": 337, - "referenced_widgets": [ - "be446195da2b4ff2aec21ec5ff963a54", - "c6596896148b4a8a9c57963b67c7782f", - "2489b5e5648541fbbdceadb05632a050", - "01e0ba4e5da04914b4652b8d58565d7b", - "c30e6c2f3e2a44dbbb3d63bd519acaa4", - "f31c6e40e9b2466a9064a2669933ecd5", - "19308ccac642498ab8b58462e3f1b0bb", - "4a081cdc2ec3421ca79dd933b7e2b0c4", - "e5c0d75eb5e1447abd560c8f2c6017e1", - "5146907ef6764654ad7d598baebc8b58", - "144ec959b7604a2cabb5ca46ae5e5379", - "abce2a80e6304df3899109c6d6cac199", - "65195cb7a4134f4887e9dd19f3676462" - ] - }, - "id": "pxtq8I-I18C-", - "outputId": "72ed63ac-d2ec-4f5c-a0b1-4e7c1840a4e7" - }, - "outputs": [ - { - "data": { - "application/vnd.jupyter.widget-view+json": { - "model_id": "be446195da2b4ff2aec21ec5ff963a54", - "version_major": 2, - "version_minor": 0 + ], + "source": [ + "import condacolab\n", + "condacolab.install()" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "JzDHaPU7I9Sn" }, - "text/plain": [ - "NGLWidget()" + "source": [ + "Install pytorch requirements (this takes a few minutes, go grab yourself a coffee 🤗)" ] - }, - "metadata": { - "application/vnd.jupyter.widget-view+json": { - "colab": { - "custom_widget_manager": { - "url": "https://ssl.gstatic.com/colaboratory-static/widgets/colab-cdn-widget-manager/d2e234f7cc04bf79/manager.min.js" + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "colab": { + "base_uri": 
"https://localhost:8080/" + }, + "id": "JMxRjHhL7w8V", + "outputId": "6ed511b3-9262-49e8-b340-08e76b05ebd8" + }, + "outputs": [ + { + "output_type": "stream", + "name": "stdout", + "text": [ + "Collecting package metadata (current_repodata.json): - \b\b\\ \b\b| \b\b/ \b\b- \b\b\\ \b\b| \b\b/ \b\b- \b\b\\ \b\b| \b\b/ \b\b- \b\b\\ \b\b| \b\b/ \b\b- \b\b\\ \b\b| \b\b/ \b\b- \b\b\\ \b\b| \b\b/ \b\b- \b\b\\ \b\b| \b\b/ \b\b- \b\b\\ \b\b| \b\b/ \b\bdone\n", + "Solving environment: \\ \b\b| \b\b/ \b\b- \b\b\\ \b\b| \b\b/ \b\b- \b\b\\ \b\b| \b\b/ \b\b- \b\b\\ \b\b| \b\b/ \b\b- \b\b\\ \b\b| \b\b/ \b\b- \b\b\\ \b\b| \b\b/ \b\b- \b\b\\ \b\b| \b\b/ \b\b- \b\b\\ \b\b| \b\b/ \b\b- \b\b\\ \b\bdone\n", + "\n", + "## Package Plan ##\n", + "\n", + " environment location: /usr/local\n", + "\n", + " added / updated specs:\n", + " - cudatoolkit=11.1\n", + " - pytorch\n", + " - torchaudio\n", + " - torchvision\n", + "\n", + "\n", + "The following packages will be downloaded:\n", + "\n", + " package | build\n", + " ---------------------------|-----------------\n", + " conda-22.9.0 | py37h89c1867_1 960 KB conda-forge\n", + " ------------------------------------------------------------\n", + " Total: 960 KB\n", + "\n", + "The following packages will be UPDATED:\n", + "\n", + " conda 4.14.0-py37h89c1867_0 --> 22.9.0-py37h89c1867_1\n", + "\n", + "\n", + "\n", + "Downloading and Extracting Packages\n", + "conda-22.9.0 | 960 KB | : 100% 1.0/1 [00:00<00:00, 4.15it/s]\n", + "Preparing transaction: / \b\bdone\n", + "Verifying transaction: \\ \b\bdone\n", + "Executing transaction: / \b\bdone\n", + "Retrieving notices: ...working... done\n" + ] } - } + ], + "source": [ + "!conda install pytorch torchvision torchaudio cudatoolkit=11.1 -c pytorch-lts -c nvidia\n", + "# !conda install pytorch==1.8.0 torchvision==0.9.0 torchaudio==0.8.0 cudatoolkit=11.1 -c pytorch -c conda-forge" + ] + }, + { + "cell_type": "markdown", + "source": [ + "Need to remove a pathspec for colab that specifies the incorrect cuda version." 
+ ], + "metadata": { + "id": "QDS6FPZ0Tu5b" } - }, - "output_type": "display_data" - } - ], - "source": [ - "# new molecule\n", - "show(mols_gen[idx])" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "id": "KJr4h2mwXeTo" - }, - "outputs": [], - "source": [] - } - ], - "metadata": { - "accelerator": "GPU", - "colab": { - "provenance": [] - }, - "gpuClass": "standard", - "kernelspec": { - "display_name": "Python 3", - "name": "python3" - }, - "language_info": { - "name": "python" - }, - "widgets": { - "application/vnd.jupyter.widget-state+json": { - "01e0ba4e5da04914b4652b8d58565d7b": { - "model_module": "@jupyter-widgets/controls", - "model_module_version": "1.5.0", - "model_name": "HBoxModel", - "state": { - "_dom_classes": [], - "_model_module": "@jupyter-widgets/controls", - "_model_module_version": "1.5.0", - "_model_name": "HBoxModel", - "_view_count": null, - "_view_module": "@jupyter-widgets/controls", - "_view_module_version": "1.5.0", - "_view_name": "HBoxView", - "box_style": "", - "children": [ - "IPY_MODEL_e5c0d75eb5e1447abd560c8f2c6017e1", - "IPY_MODEL_5146907ef6764654ad7d598baebc8b58" + }, + { + "cell_type": "code", + "source": [ + "!rm /usr/local/conda-meta/pinned" ], - "layout": "IPY_MODEL_144ec959b7604a2cabb5ca46ae5e5379" - } - }, - "144ec959b7604a2cabb5ca46ae5e5379": { - "model_module": "@jupyter-widgets/base", - "model_module_version": "1.2.0", - "model_name": "LayoutModel", - "state": { - "_model_module": "@jupyter-widgets/base", - "_model_module_version": "1.2.0", - "_model_name": "LayoutModel", - "_view_count": null, - "_view_module": "@jupyter-widgets/base", - "_view_module_version": "1.2.0", - "_view_name": "LayoutView", - "align_content": null, - "align_items": null, - "align_self": null, - "border": null, - "bottom": null, - "display": null, - "flex": null, - "flex_flow": null, - "grid_area": null, - "grid_auto_columns": null, - "grid_auto_flow": null, - "grid_auto_rows": null, - "grid_column": null, - "grid_gap": null, - "grid_row": null, - "grid_template_areas": null, - "grid_template_columns": null, - "grid_template_rows": null, - "height": null, - "justify_content": null, - "justify_items": null, - "left": null, - "margin": null, - "max_height": null, - "max_width": null, - "min_height": null, - "min_width": null, - "object_fit": null, - "object_position": null, - "order": null, - "overflow": null, - "overflow_x": null, - "overflow_y": null, - "padding": null, - "right": null, - "top": null, - "visibility": null, - "width": null - } - }, - "19308ccac642498ab8b58462e3f1b0bb": { - "model_module": "@jupyter-widgets/base", - "model_module_version": "1.2.0", - "model_name": "LayoutModel", - "state": { - "_model_module": "@jupyter-widgets/base", - "_model_module_version": "1.2.0", - "_model_name": "LayoutModel", - "_view_count": null, - "_view_module": "@jupyter-widgets/base", - "_view_module_version": "1.2.0", - "_view_name": "LayoutView", - "align_content": null, - "align_items": null, - "align_self": null, - "border": null, - "bottom": null, - "display": null, - "flex": null, - "flex_flow": null, - "grid_area": null, - "grid_auto_columns": null, - "grid_auto_flow": null, - "grid_auto_rows": null, - "grid_column": null, - "grid_gap": null, - "grid_row": null, - "grid_template_areas": null, - "grid_template_columns": null, - "grid_template_rows": null, - "height": null, - "justify_content": null, - "justify_items": null, - "left": null, - "margin": null, - "max_height": null, - "max_width": null, - "min_height": null, - "min_width": 
null, - "object_fit": null, - "object_position": null, - "order": null, - "overflow": null, - "overflow_x": null, - "overflow_y": null, - "padding": null, - "right": null, - "top": null, - "visibility": null, - "width": null - } - }, - "1c6246f15b654f4daa11c9bcf997b78c": { - "model_module": "@jupyter-widgets/controls", - "model_module_version": "1.5.0", - "model_name": "HTMLModel", - "state": { - "_dom_classes": [], - "_model_module": "@jupyter-widgets/controls", - "_model_module_version": "1.5.0", - "_model_name": "HTMLModel", - "_view_count": null, - "_view_module": "@jupyter-widgets/controls", - "_view_module_version": "1.5.0", - "_view_name": "HTMLView", - "description": "", - "description_tooltip": null, - "layout": "IPY_MODEL_bbef741e76ec41b7ab7187b487a383df", - "placeholder": "​", - "style": "IPY_MODEL_561f742d418d4721b0670cc8dd62e22c", - "value": "Downloading: 100%" - } - }, - "2489b5e5648541fbbdceadb05632a050": { - "model_module": "@jupyter-widgets/controls", - "model_module_version": "1.5.0", - "model_name": "ButtonModel", - "state": { - "_dom_classes": [], - "_model_module": "@jupyter-widgets/controls", - "_model_module_version": "1.5.0", - "_model_name": "ButtonModel", - "_view_count": null, - "_view_module": "@jupyter-widgets/controls", - "_view_module_version": "1.5.0", - "_view_name": "ButtonView", - "button_style": "", - "description": "", - "disabled": false, - "icon": "compress", - "layout": "IPY_MODEL_abce2a80e6304df3899109c6d6cac199", - "style": "IPY_MODEL_65195cb7a4134f4887e9dd19f3676462", - "tooltip": "" - } - }, - "24d31fc3576e43dd9f8301d2ef3a37ab": { - "model_module": "@jupyter-widgets/base", - "model_module_version": "1.2.0", - "model_name": "LayoutModel", - "state": { - "_model_module": "@jupyter-widgets/base", - "_model_module_version": "1.2.0", - "_model_name": "LayoutModel", - "_view_count": null, - "_view_module": "@jupyter-widgets/base", - "_view_module_version": "1.2.0", - "_view_name": "LayoutView", - "align_content": null, - "align_items": null, - "align_self": null, - "border": null, - "bottom": null, - "display": null, - "flex": null, - "flex_flow": null, - "grid_area": null, - "grid_auto_columns": null, - "grid_auto_flow": null, - "grid_auto_rows": null, - "grid_column": null, - "grid_gap": null, - "grid_row": null, - "grid_template_areas": null, - "grid_template_columns": null, - "grid_template_rows": null, - "height": null, - "justify_content": null, - "justify_items": null, - "left": null, - "margin": null, - "max_height": null, - "max_width": null, - "min_height": null, - "min_width": null, - "object_fit": null, - "object_position": null, - "order": null, - "overflow": null, - "overflow_x": null, - "overflow_y": null, - "padding": null, - "right": null, - "top": null, - "visibility": null, - "width": null - } - }, - "2918bfaadc8d4b1a9832522c40dfefb8": { - "model_module": "@jupyter-widgets/base", - "model_module_version": "1.2.0", - "model_name": "LayoutModel", - "state": { - "_model_module": "@jupyter-widgets/base", - "_model_module_version": "1.2.0", - "_model_name": "LayoutModel", - "_view_count": null, - "_view_module": "@jupyter-widgets/base", - "_view_module_version": "1.2.0", - "_view_name": "LayoutView", - "align_content": null, - "align_items": null, - "align_self": null, - "border": null, - "bottom": null, - "display": null, - "flex": null, - "flex_flow": null, - "grid_area": null, - "grid_auto_columns": null, - "grid_auto_flow": null, - "grid_auto_rows": null, - "grid_column": null, - "grid_gap": null, - "grid_row": null, - 
"grid_template_areas": null, - "grid_template_columns": null, - "grid_template_rows": null, - "height": null, - "justify_content": null, - "justify_items": null, - "left": null, - "margin": null, - "max_height": null, - "max_width": null, - "min_height": null, - "min_width": null, - "object_fit": null, - "object_position": null, - "order": null, - "overflow": null, - "overflow_x": null, - "overflow_y": null, - "padding": null, - "right": null, - "top": null, - "visibility": null, - "width": null - } - }, - "2c9362906e4b40189f16d14aa9a348da": { - "model_module": "@jupyter-widgets/controls", - "model_module_version": "1.5.0", - "model_name": "DescriptionStyleModel", - "state": { - "_model_module": "@jupyter-widgets/controls", - "_model_module_version": "1.5.0", - "_model_name": "DescriptionStyleModel", - "_view_count": null, - "_view_module": "@jupyter-widgets/base", - "_view_module_version": "1.2.0", - "_view_name": "StyleView", - "description_width": "" - } - }, - "42f6c3db29d7484ba6b4f73590abd2f4": { - "model_module": "@jupyter-widgets/controls", - "model_module_version": "1.5.0", - "model_name": "DescriptionStyleModel", - "state": { - "_model_module": "@jupyter-widgets/controls", - "_model_module_version": "1.5.0", - "_model_name": "DescriptionStyleModel", - "_view_count": null, - "_view_module": "@jupyter-widgets/base", - "_view_module_version": "1.2.0", - "_view_name": "StyleView", - "description_width": "" - } - }, - "4a081cdc2ec3421ca79dd933b7e2b0c4": { - "model_module": "@jupyter-widgets/controls", - "model_module_version": "1.5.0", - "model_name": "SliderStyleModel", - "state": { - "_model_module": "@jupyter-widgets/controls", - "_model_module_version": "1.5.0", - "_model_name": "SliderStyleModel", - "_view_count": null, - "_view_module": "@jupyter-widgets/base", - "_view_module_version": "1.2.0", - "_view_name": "StyleView", - "description_width": "", - "handle_color": null - } - }, - "5146907ef6764654ad7d598baebc8b58": { - "model_module": "@jupyter-widgets/controls", - "model_module_version": "1.5.0", - "model_name": "IntSliderModel", - "state": { - "_dom_classes": [], - "_model_module": "@jupyter-widgets/controls", - "_model_module_version": "1.5.0", - "_model_name": "IntSliderModel", - "_view_count": null, - "_view_module": "@jupyter-widgets/controls", - "_view_module_version": "1.5.0", - "_view_name": "IntSliderView", - "continuous_update": true, - "description": "", - "description_tooltip": null, - "disabled": false, - "layout": "IPY_MODEL_19308ccac642498ab8b58462e3f1b0bb", - "max": 0, - "min": 0, - "orientation": "horizontal", - "readout": true, - "readout_format": "d", - "step": 1, - "style": "IPY_MODEL_4a081cdc2ec3421ca79dd933b7e2b0c4", - "value": 0 - } - }, - "561f742d418d4721b0670cc8dd62e22c": { - "model_module": "@jupyter-widgets/controls", - "model_module_version": "1.5.0", - "model_name": "DescriptionStyleModel", - "state": { - "_model_module": "@jupyter-widgets/controls", - "_model_module_version": "1.5.0", - "_model_name": "DescriptionStyleModel", - "_view_count": null, - "_view_module": "@jupyter-widgets/base", - "_view_module_version": "1.2.0", - "_view_name": "StyleView", - "description_width": "" - } - }, - "6010fc8daa7a44d5aec4b830ec2ebaa1": { - "model_module": "@jupyter-widgets/controls", - "model_module_version": "1.5.0", - "model_name": "HBoxModel", - "state": { - "_dom_classes": [], - "_model_module": "@jupyter-widgets/controls", - "_model_module_version": "1.5.0", - "_model_name": "HBoxModel", - "_view_count": null, - "_view_module": 
"@jupyter-widgets/controls", - "_view_module_version": "1.5.0", - "_view_name": "HBoxView", - "box_style": "", - "children": [ - "IPY_MODEL_7e0bb1b8d65249d3974200686b193be2", - "IPY_MODEL_ba98aa6d6a884e4ab8bbb5dfb5e4cf7a", - "IPY_MODEL_6526646be5ed415c84d1245b040e629b" + "metadata": { + "id": "dq1lxR10TtrR", + "colab": { + "base_uri": "https://localhost:8080/" + }, + "outputId": "ed9c5a71-b449-418f-abb7-072b74e7f6c8" + }, + "execution_count": null, + "outputs": [ + { + "output_type": "stream", + "name": "stdout", + "text": [ + "rm: cannot remove '/usr/local/conda-meta/pinned': No such file or directory\n" + ] + } + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "Z1L3DdZOJB30" + }, + "source": [ + "Install torch geometric (used in the model later)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "D5ukfCOWfjzK", + "colab": { + "base_uri": "https://localhost:8080/" + }, + "outputId": "8437485a-5aa6-4d53-8f7f-23517ac1ace6" + }, + "outputs": [ + { + "output_type": "stream", + "name": "stdout", + "text": [ + "Collecting package metadata (current_repodata.json): - \b\b\\ \b\b| \b\b/ \b\b- \b\b\\ \b\b| \b\b/ \b\b- \b\b\\ \b\b| \b\b/ \b\b- \b\b\\ \b\b| \b\b/ \b\b- \b\b\\ \b\b| \b\b/ \b\b- \b\b\\ \b\b| \b\b/ \b\b- \b\b\\ \b\b| \b\b/ \b\b- \b\bdone\n", + "Solving environment: | \b\b/ \b\b- \b\b\\ \b\b| \b\b/ \b\b- \b\b\\ \b\b| \b\b/ \b\b- \b\b\\ \b\b| \b\b/ \b\b- \b\b\\ \b\b| \b\b/ \b\b- \b\b\\ \b\b| \b\b/ \b\b- \b\b\\ \b\b| \b\b/ \b\b- \b\b\\ \b\b| \b\b/ \b\b- \b\b\\ \b\b| \b\b/ \b\b- \b\b\\ \b\b| \b\b/ \b\b- \b\b\\ \b\b| \b\b/ \b\b- \b\b\\ \b\b| \b\b/ \b\b- \b\b\\ \b\b| \b\b/ \b\b- \b\b\\ \b\b| \b\b/ \b\b- \b\b\\ \b\b| \b\b/ \b\b- \b\b\\ \b\b| \b\b/ \b\b- \b\b\\ \b\b| \b\b/ \b\b- \b\b\\ \b\b| \b\b/ \b\b- \b\b\\ \b\b| \b\b/ \b\b- \b\b\\ \b\b| \b\b/ \b\b- \b\b\\ \b\b| \b\b/ \b\b- \b\b\\ \b\b| \b\b/ \b\b- \b\b\\ \b\b| \b\b/ \b\b- \b\b\\ \b\b| \b\b/ \b\b- \b\b\\ \b\b| \b\b/ \b\b- \b\b\\ \b\b| \b\b/ \b\b- \b\b\\ \b\b| \b\b/ \b\b- \b\b\\ \b\b| \b\b/ \b\b- \b\b\\ \b\b| \b\b/ \b\b- \b\b\\ \b\b| \b\b/ \b\b- \b\b\\ \b\b| \b\b/ \b\b- \b\b\\ \b\b| \b\b/ \b\b- \b\b\\ \b\b| \b\b/ \b\b- \b\b\\ \b\bdone\n", + "\n", + "## Package Plan ##\n", + "\n", + " environment location: /usr/local\n", + "\n", + " added / updated specs:\n", + " - pytorch-geometric=1.7.2\n", + "\n", + "\n", + "The following packages will be downloaded:\n", + "\n", + " package | build\n", + " ---------------------------|-----------------\n", + " decorator-4.4.2 | py_0 11 KB conda-forge\n", + " googledrivedownloader-0.4 | pyhd3deb0d_1 7 KB conda-forge\n", + " jinja2-3.1.2 | pyhd8ed1ab_1 99 KB conda-forge\n", + " joblib-1.2.0 | pyhd8ed1ab_0 205 KB conda-forge\n", + " markupsafe-2.1.1 | py37h540881e_1 22 KB conda-forge\n", + " networkx-2.5.1 | pyhd8ed1ab_0 1.2 MB conda-forge\n", + " pandas-1.2.3 | py37hdc94413_0 11.8 MB conda-forge\n", + " pyparsing-3.0.9 | pyhd8ed1ab_0 79 KB conda-forge\n", + " python-dateutil-2.8.2 | pyhd8ed1ab_0 240 KB conda-forge\n", + " python-louvain-0.15 | pyhd8ed1ab_1 13 KB conda-forge\n", + " pytorch-cluster-1.5.9 |py37_torch_1.8.0_cu111 1.2 MB rusty1s\n", + " pytorch-geometric-1.7.2 |py37_torch_1.8.0_cu111 445 KB rusty1s\n", + " pytorch-scatter-2.0.8 |py37_torch_1.8.0_cu111 6.1 MB rusty1s\n", + " pytorch-sparse-0.6.12 |py37_torch_1.8.0_cu111 2.9 MB rusty1s\n", + " pytorch-spline-conv-1.2.1 |py37_torch_1.8.0_cu111 736 KB rusty1s\n", + " pytz-2022.4 | pyhd8ed1ab_0 232 KB conda-forge\n", + " scikit-learn-1.0.2 | py37hf9e9bfc_0 7.8 MB conda-forge\n", + " scipy-1.7.3 | 
py37hf2a6cf1_0 21.8 MB conda-forge\n", + " setuptools-59.8.0 | py37h89c1867_1 1.0 MB conda-forge\n", + " threadpoolctl-3.1.0 | pyh8a188c0_0 18 KB conda-forge\n", + " ------------------------------------------------------------\n", + " Total: 55.9 MB\n", + "\n", + "The following NEW packages will be INSTALLED:\n", + "\n", + " decorator conda-forge/noarch::decorator-4.4.2-py_0 None\n", + " googledrivedownlo~ conda-forge/noarch::googledrivedownloader-0.4-pyhd3deb0d_1 None\n", + " jinja2 conda-forge/noarch::jinja2-3.1.2-pyhd8ed1ab_1 None\n", + " joblib conda-forge/noarch::joblib-1.2.0-pyhd8ed1ab_0 None\n", + " markupsafe conda-forge/linux-64::markupsafe-2.1.1-py37h540881e_1 None\n", + " networkx conda-forge/noarch::networkx-2.5.1-pyhd8ed1ab_0 None\n", + " pandas conda-forge/linux-64::pandas-1.2.3-py37hdc94413_0 None\n", + " pyparsing conda-forge/noarch::pyparsing-3.0.9-pyhd8ed1ab_0 None\n", + " python-dateutil conda-forge/noarch::python-dateutil-2.8.2-pyhd8ed1ab_0 None\n", + " python-louvain conda-forge/noarch::python-louvain-0.15-pyhd8ed1ab_1 None\n", + " pytorch-cluster rusty1s/linux-64::pytorch-cluster-1.5.9-py37_torch_1.8.0_cu111 None\n", + " pytorch-geometric rusty1s/linux-64::pytorch-geometric-1.7.2-py37_torch_1.8.0_cu111 None\n", + " pytorch-scatter rusty1s/linux-64::pytorch-scatter-2.0.8-py37_torch_1.8.0_cu111 None\n", + " pytorch-sparse rusty1s/linux-64::pytorch-sparse-0.6.12-py37_torch_1.8.0_cu111 None\n", + " pytorch-spline-co~ rusty1s/linux-64::pytorch-spline-conv-1.2.1-py37_torch_1.8.0_cu111 None\n", + " pytz conda-forge/noarch::pytz-2022.4-pyhd8ed1ab_0 None\n", + " scikit-learn conda-forge/linux-64::scikit-learn-1.0.2-py37hf9e9bfc_0 None\n", + " scipy conda-forge/linux-64::scipy-1.7.3-py37hf2a6cf1_0 None\n", + " threadpoolctl conda-forge/noarch::threadpoolctl-3.1.0-pyh8a188c0_0 None\n", + "\n", + "The following packages will be DOWNGRADED:\n", + "\n", + " setuptools 65.3.0-py37h89c1867_0 --> 59.8.0-py37h89c1867_1 None\n", + "\n", + "\n", + "\n", + "Downloading and Extracting Packages\n", + "scikit-learn-1.0.2 | 7.8 MB | : 100% 1.0/1 [00:01<00:00, 1.37s/it] \n", + "pytorch-scatter-2.0. | 6.1 MB | : 100% 1.0/1 [00:06<00:00, 6.18s/it]\n", + "pytorch-geometric-1. | 445 KB | : 100% 1.0/1 [00:02<00:00, 2.53s/it]\n", + "scipy-1.7.3 | 21.8 MB | : 100% 1.0/1 [00:03<00:00, 3.06s/it]\n", + "python-dateutil-2.8. | 240 KB | : 100% 1.0/1 [00:00<00:00, 21.48it/s]\n", + "pytorch-spline-conv- | 736 KB | : 100% 1.0/1 [00:01<00:00, 1.00s/it]\n", + "pytorch-sparse-0.6.1 | 2.9 MB | : 100% 1.0/1 [00:07<00:00, 7.51s/it]\n", + "pyparsing-3.0.9 | 79 KB | : 100% 1.0/1 [00:00<00:00, 26.32it/s]\n", + "pytorch-cluster-1.5. 
| 1.2 MB | : 100% 1.0/1 [00:02<00:00, 2.78s/it]\n", + "jinja2-3.1.2 | 99 KB | : 100% 1.0/1 [00:00<00:00, 20.28it/s]\n", + "decorator-4.4.2 | 11 KB | : 100% 1.0/1 [00:00<00:00, 21.57it/s]\n", + "joblib-1.2.0 | 205 KB | : 100% 1.0/1 [00:00<00:00, 15.04it/s]\n", + "pytz-2022.4 | 232 KB | : 100% 1.0/1 [00:00<00:00, 10.21it/s]\n", + "python-louvain-0.15 | 13 KB | : 100% 1.0/1 [00:00<00:00, 3.34it/s]\n", + "googledrivedownloade | 7 KB | : 100% 1.0/1 [00:00<00:00, 3.33it/s]\n", + "threadpoolctl-3.1.0 | 18 KB | : 100% 1.0/1 [00:00<00:00, 29.40it/s]\n", + "markupsafe-2.1.1 | 22 KB | : 100% 1.0/1 [00:00<00:00, 28.62it/s]\n", + "pandas-1.2.3 | 11.8 MB | : 100% 1.0/1 [00:02<00:00, 2.08s/it] \n", + "networkx-2.5.1 | 1.2 MB | : 100% 1.0/1 [00:01<00:00, 1.39s/it]\n", + "setuptools-59.8.0 | 1.0 MB | : 100% 1.0/1 [00:00<00:00, 4.25it/s]\n", + "Preparing transaction: / \b\b- \b\b\\ \b\bdone\n", + "Verifying transaction: / \b\b- \b\b\\ \b\b| \b\b/ \b\b- \b\b\\ \b\b| \b\b/ \b\b- \b\b\\ \b\b| \b\b/ \b\b- \b\b\\ \b\bdone\n", + "Executing transaction: / \b\b- \b\b\\ \b\b| \b\b/ \b\b- \b\b\\ \b\b| \b\b/ \b\b- \b\b\\ \b\b| \b\b/ \b\b- \b\b\\ \b\b| \b\b/ \b\b- \b\b\\ \b\b| \b\b/ \b\b- \b\bdone\n", + "Retrieving notices: ...working... done\n" + ] + } ], - "layout": "IPY_MODEL_24d31fc3576e43dd9f8301d2ef3a37ab" - } - }, - "65195cb7a4134f4887e9dd19f3676462": { - "model_module": "@jupyter-widgets/controls", - "model_module_version": "1.5.0", - "model_name": "ButtonStyleModel", - "state": { - "_model_module": "@jupyter-widgets/controls", - "_model_module_version": "1.5.0", - "_model_name": "ButtonStyleModel", - "_view_count": null, - "_view_module": "@jupyter-widgets/base", - "_view_module_version": "1.2.0", - "_view_name": "StyleView", - "button_color": null, - "font_weight": "" - } - }, - "6526646be5ed415c84d1245b040e629b": { - "model_module": "@jupyter-widgets/controls", - "model_module_version": "1.5.0", - "model_name": "HTMLModel", - "state": { - "_dom_classes": [], - "_model_module": "@jupyter-widgets/controls", - "_model_module_version": "1.5.0", - "_model_name": "HTMLModel", - "_view_count": null, - "_view_module": "@jupyter-widgets/controls", - "_view_module_version": "1.5.0", - "_view_name": "HTMLView", - "description": "", - "description_tooltip": null, - "layout": "IPY_MODEL_a9e388f22a9742aaaf538e22575c9433", - "placeholder": "​", - "style": "IPY_MODEL_42f6c3db29d7484ba6b4f73590abd2f4", - "value": " 401/401 [00:00<00:00, 13.5kB/s]" - } - }, - "695ab5bbf30a4ab19df1f9f33469f314": { - "model_module": "nglview-js-widgets", - "model_module_version": "3.0.1", - "model_name": "ColormakerRegistryModel", - "state": { - "_dom_classes": [], - "_model_module": "nglview-js-widgets", - "_model_module_version": "3.0.1", - "_model_name": "ColormakerRegistryModel", - "_msg_ar": [], - "_msg_q": [], - "_ready": false, - "_view_count": null, - "_view_module": "nglview-js-widgets", - "_view_module_version": "3.0.1", - "_view_name": "ColormakerRegistryView", - "layout": "IPY_MODEL_eac6a8dcdc9d4335a2e51031793ead29" - } - }, - "7e0bb1b8d65249d3974200686b193be2": { - "model_module": "@jupyter-widgets/controls", - "model_module_version": "1.5.0", - "model_name": "HTMLModel", - "state": { - "_dom_classes": [], - "_model_module": "@jupyter-widgets/controls", - "_model_module_version": "1.5.0", - "_model_name": "HTMLModel", - "_view_count": null, - "_view_module": "@jupyter-widgets/controls", - "_view_module_version": "1.5.0", - "_view_name": "HTMLView", - "description": "", - "description_tooltip": null, - "layout": 
"IPY_MODEL_2918bfaadc8d4b1a9832522c40dfefb8", - "placeholder": "​", - "style": "IPY_MODEL_a4bfdca35cc54dae8812720f1b276a08", - "value": "Downloading: 100%" - } - }, - "872915dd1bb84f538c44e26badabafdd": { - "model_module": "@jupyter-widgets/base", - "model_module_version": "1.2.0", - "model_name": "LayoutModel", - "state": { - "_model_module": "@jupyter-widgets/base", - "_model_module_version": "1.2.0", - "_model_name": "LayoutModel", - "_view_count": null, - "_view_module": "@jupyter-widgets/base", - "_view_module_version": "1.2.0", - "_view_name": "LayoutView", - "align_content": null, - "align_items": null, - "align_self": null, - "border": null, - "bottom": null, - "display": null, - "flex": null, - "flex_flow": null, - "grid_area": null, - "grid_auto_columns": null, - "grid_auto_flow": null, - "grid_auto_rows": null, - "grid_column": null, - "grid_gap": null, - "grid_row": null, - "grid_template_areas": null, - "grid_template_columns": null, - "grid_template_rows": null, - "height": null, - "justify_content": null, - "justify_items": null, - "left": null, - "margin": null, - "max_height": null, - "max_width": null, - "min_height": null, - "min_width": null, - "object_fit": null, - "object_position": null, - "order": null, - "overflow": null, - "overflow_x": null, - "overflow_y": null, - "padding": null, - "right": null, - "top": null, - "visibility": null, - "width": null - } - }, - "a4bfdca35cc54dae8812720f1b276a08": { - "model_module": "@jupyter-widgets/controls", - "model_module_version": "1.5.0", - "model_name": "DescriptionStyleModel", - "state": { - "_model_module": "@jupyter-widgets/controls", - "_model_module_version": "1.5.0", - "_model_name": "DescriptionStyleModel", - "_view_count": null, - "_view_module": "@jupyter-widgets/base", - "_view_module_version": "1.2.0", - "_view_name": "StyleView", - "description_width": "" - } - }, - "a9e388f22a9742aaaf538e22575c9433": { - "model_module": "@jupyter-widgets/base", - "model_module_version": "1.2.0", - "model_name": "LayoutModel", - "state": { - "_model_module": "@jupyter-widgets/base", - "_model_module_version": "1.2.0", - "_model_name": "LayoutModel", - "_view_count": null, - "_view_module": "@jupyter-widgets/base", - "_view_module_version": "1.2.0", - "_view_name": "LayoutView", - "align_content": null, - "align_items": null, - "align_self": null, - "border": null, - "bottom": null, - "display": null, - "flex": null, - "flex_flow": null, - "grid_area": null, - "grid_auto_columns": null, - "grid_auto_flow": null, - "grid_auto_rows": null, - "grid_column": null, - "grid_gap": null, - "grid_row": null, - "grid_template_areas": null, - "grid_template_columns": null, - "grid_template_rows": null, - "height": null, - "justify_content": null, - "justify_items": null, - "left": null, - "margin": null, - "max_height": null, - "max_width": null, - "min_height": null, - "min_width": null, - "object_fit": null, - "object_position": null, - "order": null, - "overflow": null, - "overflow_x": null, - "overflow_y": null, - "padding": null, - "right": null, - "top": null, - "visibility": null, - "width": null - } - }, - "abce2a80e6304df3899109c6d6cac199": { - "model_module": "@jupyter-widgets/base", - "model_module_version": "1.2.0", - "model_name": "LayoutModel", - "state": { - "_model_module": "@jupyter-widgets/base", - "_model_module_version": "1.2.0", - "_model_name": "LayoutModel", - "_view_count": null, - "_view_module": "@jupyter-widgets/base", - "_view_module_version": "1.2.0", - "_view_name": "LayoutView", - "align_content": null, - 
"align_items": null, - "align_self": null, - "border": null, - "bottom": null, - "display": null, - "flex": null, - "flex_flow": null, - "grid_area": null, - "grid_auto_columns": null, - "grid_auto_flow": null, - "grid_auto_rows": null, - "grid_column": null, - "grid_gap": null, - "grid_row": null, - "grid_template_areas": null, - "grid_template_columns": null, - "grid_template_rows": null, - "height": null, - "justify_content": null, - "justify_items": null, - "left": null, - "margin": null, - "max_height": null, - "max_width": null, - "min_height": null, - "min_width": null, - "object_fit": null, - "object_position": null, - "order": null, - "overflow": null, - "overflow_x": null, - "overflow_y": null, - "padding": null, - "right": null, - "top": null, - "visibility": null, - "width": "34px" - } - }, - "b7feb522161f4cf4b7cc7c1a078ff12d": { - "model_module": "@jupyter-widgets/controls", - "model_module_version": "1.5.0", - "model_name": "HTMLModel", - "state": { - "_dom_classes": [], - "_model_module": "@jupyter-widgets/controls", - "_model_module_version": "1.5.0", - "_model_name": "HTMLModel", - "_view_count": null, - "_view_module": "@jupyter-widgets/controls", - "_view_module_version": "1.5.0", - "_view_name": "HTMLView", - "description": "", - "description_tooltip": null, - "layout": "IPY_MODEL_fdc393f3468c432aa0ada05e238a5436", - "placeholder": "​", - "style": "IPY_MODEL_2c9362906e4b40189f16d14aa9a348da", - "value": " 3.27M/3.27M [00:01<00:00, 3.25MB/s]" - } - }, - "ba98aa6d6a884e4ab8bbb5dfb5e4cf7a": { - "model_module": "@jupyter-widgets/controls", - "model_module_version": "1.5.0", - "model_name": "FloatProgressModel", - "state": { - "_dom_classes": [], - "_model_module": "@jupyter-widgets/controls", - "_model_module_version": "1.5.0", - "_model_name": "FloatProgressModel", - "_view_count": null, - "_view_module": "@jupyter-widgets/controls", - "_view_module_version": "1.5.0", - "_view_name": "ProgressView", - "bar_style": "success", - "description": "", - "description_tooltip": null, - "layout": "IPY_MODEL_e4901541199b45c6a18824627692fc39", - "max": 401, - "min": 0, - "orientation": "horizontal", - "style": "IPY_MODEL_f915cf874246446595206221e900b2fe", - "value": 401 - } - }, - "bbef741e76ec41b7ab7187b487a383df": { - "model_module": "@jupyter-widgets/base", - "model_module_version": "1.2.0", - "model_name": "LayoutModel", - "state": { - "_model_module": "@jupyter-widgets/base", - "_model_module_version": "1.2.0", - "_model_name": "LayoutModel", - "_view_count": null, - "_view_module": "@jupyter-widgets/base", - "_view_module_version": "1.2.0", - "_view_name": "LayoutView", - "align_content": null, - "align_items": null, - "align_self": null, - "border": null, - "bottom": null, - "display": null, - "flex": null, - "flex_flow": null, - "grid_area": null, - "grid_auto_columns": null, - "grid_auto_flow": null, - "grid_auto_rows": null, - "grid_column": null, - "grid_gap": null, - "grid_row": null, - "grid_template_areas": null, - "grid_template_columns": null, - "grid_template_rows": null, - "height": null, - "justify_content": null, - "justify_items": null, - "left": null, - "margin": null, - "max_height": null, - "max_width": null, - "min_height": null, - "min_width": null, - "object_fit": null, - "object_position": null, - "order": null, - "overflow": null, - "overflow_x": null, - "overflow_y": null, - "padding": null, - "right": null, - "top": null, - "visibility": null, - "width": null - } - }, - "be446195da2b4ff2aec21ec5ff963a54": { - "model_module": "nglview-js-widgets", - 
"model_module_version": "3.0.1", - "model_name": "NGLModel", - "state": { - "_camera_orientation": [ - -15.519693580202304, - -14.065056548036177, - -23.53197484807691, - 0, - -23.357853515109753, - 20.94055073042662, - 2.888695042134944, - 0, - 14.352363398292775, - 18.870825741878015, - -20.744689572909344, - 0, - 0.2724999189376831, - 0.6940000057220459, - -0.3734999895095825, - 1 + "source": [ + "!conda install -c rusty1s pytorch-geometric=1.7.2" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "ppxv6Mdkalbc" + }, + "source": [ + "### Install Diffusers" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "mgQA_XN-XGY2", + "colab": { + "base_uri": "https://localhost:8080/" + }, + "outputId": "85392615-b6a4-4052-9d2a-79604be62c94" + }, + "outputs": [ + { + "output_type": "stream", + "name": "stdout", + "text": [ + "/content\n", + "Cloning into 'diffusers'...\n", + "remote: Enumerating objects: 9298, done.\u001b[K\n", + "remote: Counting objects: 100% (40/40), done.\u001b[K\n", + "remote: Compressing objects: 100% (23/23), done.\u001b[K\n", + "remote: Total 9298 (delta 17), reused 23 (delta 11), pack-reused 9258\u001b[K\n", + "Receiving objects: 100% (9298/9298), 7.38 MiB | 5.28 MiB/s, done.\n", + "Resolving deltas: 100% (6168/6168), done.\n", + " Installing build dependencies ... \u001b[?25l\u001b[?25hdone\n", + " Getting requirements to build wheel ... \u001b[?25l\u001b[?25hdone\n", + " Preparing metadata (pyproject.toml) ... \u001b[?25l\u001b[?25hdone\n", + "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m757.0/757.0 kB\u001b[0m \u001b[31m52.8 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", + "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m163.5/163.5 kB\u001b[0m \u001b[31m21.9 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", + "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m40.8/40.8 kB\u001b[0m \u001b[31m5.5 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", + "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m596.3/596.3 kB\u001b[0m \u001b[31m51.7 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", + "\u001b[?25h Building wheel for diffusers (pyproject.toml) ... \u001b[?25l\u001b[?25hdone\n", + "\u001b[33mWARNING: Running pip as the 'root' user can result in broken permissions and conflicting behaviour with the system package manager. 
It is recommended to use a virtual environment instead: https://pip.pypa.io/warnings/venv\u001b[0m\u001b[33m\n", + "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m432.7/432.7 kB\u001b[0m \u001b[31m36.8 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", + "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m5.3/5.3 MB\u001b[0m \u001b[31m90.9 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", + "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m35.3/35.3 MB\u001b[0m \u001b[31m39.7 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", + "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m115.1/115.1 kB\u001b[0m \u001b[31m16.3 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", + "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m948.0/948.0 kB\u001b[0m \u001b[31m63.6 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", + "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m212.2/212.2 kB\u001b[0m \u001b[31m21.5 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", + "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m95.8/95.8 kB\u001b[0m \u001b[31m12.8 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", + "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m140.8/140.8 kB\u001b[0m \u001b[31m18.8 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", + "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m7.6/7.6 MB\u001b[0m \u001b[31m104.3 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", + "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m148.0/148.0 kB\u001b[0m \u001b[31m20.8 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", + "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m231.3/231.3 kB\u001b[0m \u001b[31m30.0 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", + "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m94.8/94.8 kB\u001b[0m \u001b[31m14.0 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", + "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m58.8/58.8 kB\u001b[0m \u001b[31m8.0 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", + "\u001b[?25h\u001b[33mWARNING: Running pip as the 'root' user can result in broken permissions and conflicting behaviour with the system package manager. 
It is recommended to use a virtual environment instead: https://pip.pypa.io/warnings/venv\u001b[0m\u001b[33m\n", + "\u001b[0m" + ] + } ], - "_camera_str": "orthographic", - "_dom_classes": [], - "_gui_theme": null, - "_ibtn_fullscreen": "IPY_MODEL_2489b5e5648541fbbdceadb05632a050", - "_igui": null, - "_iplayer": "IPY_MODEL_01e0ba4e5da04914b4652b8d58565d7b", - "_model_module": "nglview-js-widgets", - "_model_module_version": "3.0.1", - "_model_name": "NGLModel", - "_ngl_color_dict": {}, - "_ngl_coordinate_resource": {}, - "_ngl_full_stage_parameters": { - "ambientColor": 14540253, - "ambientIntensity": 0.2, - "backgroundColor": "white", - "cameraEyeSep": 0.3, - "cameraFov": 40, - "cameraType": "perspective", - "clipDist": 10, - "clipFar": 100, - "clipNear": 0, - "fogFar": 100, - "fogNear": 50, - "hoverTimeout": 0, - "impostor": true, - "lightColor": 14540253, - "lightIntensity": 1, - "mousePreset": "default", - "panSpeed": 1, - "quality": "medium", - "rotateSpeed": 2, - "sampleLevel": 0, - "tooltip": true, - "workerDefault": true, - "zoomSpeed": 1.2 + "source": [ + "%cd /content\n", + "\n", + "# install latest HF diffusers (will update to the release once added)\n", + "!git clone https://github.com/huggingface/diffusers.git\n", + "!pip install -q /content/diffusers\n", + "\n", + "# dependencies for diffusers\n", + "!pip install -q datasets transformers" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "LZO6AJKuJKO8" }, - "_ngl_msg_archive": [ - { - "args": [ - { - "binary": false, - "data": "HETATM 1 C1 UNL 1 -0.025 3.128 2.316 1.00 0.00 C \nHETATM 2 H1 UNL 1 0.183 3.657 2.823 1.00 0.00 H \nHETATM 3 C2 UNL 1 0.590 3.559 0.963 1.00 0.00 C \nHETATM 4 C3 UNL 1 0.056 4.479 0.406 1.00 0.00 C \nHETATM 5 C4 UNL 1 -0.219 4.802 -1.065 1.00 0.00 C \nHETATM 6 H2 UNL 1 0.686 4.431 -1.575 1.00 0.00 H \nHETATM 7 H3 UNL 1 -0.524 5.217 -1.274 1.00 0.00 H \nHETATM 8 C5 UNL 1 -1.284 3.766 -1.342 1.00 0.00 C \nHETATM 9 N1 UNL 1 -1.073 2.494 -0.580 1.00 0.00 N \nHETATM 10 C6 UNL 1 -1.909 1.494 -0.964 1.00 0.00 C \nHETATM 11 O1 UNL 1 -2.487 1.531 -2.092 1.00 0.00 O \nHETATM 12 C7 UNL 1 -2.232 0.242 -0.130 1.00 0.00 C \nHETATM 13 C8 UNL 1 -2.161 -1.057 -1.037 1.00 0.00 C \nHETATM 14 C9 UNL 1 -0.744 -1.111 -1.610 1.00 0.00 C \nHETATM 15 N2 UNL 1 0.290 -0.917 -0.628 1.00 0.00 N \nHETATM 16 S1 UNL 1 1.717 -1.597 -0.914 1.00 0.00 S \nHETATM 17 O2 UNL 1 1.960 -1.671 -2.338 1.00 0.00 O \nHETATM 18 O3 UNL 1 2.713 -0.968 -0.082 1.00 0.00 O \nHETATM 19 C10 UNL 1 1.425 -3.170 -0.345 1.00 0.00 C \nHETATM 20 C11 UNL 1 1.225 -4.400 -1.271 1.00 0.00 C \nHETATM 21 C12 UNL 1 1.314 -5.913 -0.895 1.00 0.00 C \nHETATM 22 C13 UNL 1 1.823 -6.229 0.386 1.00 0.00 C \nHETATM 23 C14 UNL 1 2.031 -5.110 1.365 1.00 0.00 C \nHETATM 24 N3 UNL 1 1.850 -5.267 2.712 1.00 0.00 N \nHETATM 25 O4 UNL 1 1.382 -4.029 3.126 1.00 0.00 O \nHETATM 26 N4 UNL 1 1.300 -3.023 2.154 1.00 0.00 N \nHETATM 27 C15 UNL 1 1.731 -3.672 1.032 1.00 0.00 C \nHETATM 28 H4 UNL 1 2.380 -6.874 0.436 1.00 0.00 H \nHETATM 29 H5 UNL 1 0.704 -6.526 -1.420 1.00 0.00 H \nHETATM 30 H6 UNL 1 1.144 -4.035 -2.291 1.00 0.00 H \nHETATM 31 C16 UNL 1 0.044 -0.371 0.685 1.00 0.00 C \nHETATM 32 C17 UNL 1 -1.352 -0.045 1.077 1.00 0.00 C \nHETATM 33 H7 UNL 1 -1.395 0.770 1.768 1.00 0.00 H \nHETATM 34 H8 UNL 1 -1.792 -0.941 1.582 1.00 0.00 H \nHETATM 35 H9 UNL 1 0.583 -1.035 1.393 1.00 0.00 H \nHETATM 36 H10 UNL 1 0.664 0.613 0.663 1.00 0.00 H \nHETATM 37 H11 UNL 1 -0.631 -0.267 -2.335 1.00 0.00 H \nHETATM 38 H12 UNL 1 -0.571 -2.046 -2.098 1.00 0.00 H \nHETATM 39 H13 UNL 1 
-2.872 -0.992 -1.826 1.00 0.00 H \nHETATM 40 H14 UNL 1 -2.370 -1.924 -0.444 1.00 0.00 H \nHETATM 41 H15 UNL 1 -3.258 0.364 0.197 1.00 0.00 H \nHETATM 42 C18 UNL 1 0.276 2.337 -0.078 1.00 0.00 C \nHETATM 43 H16 UNL 1 0.514 1.371 0.252 1.00 0.00 H \nHETATM 44 H17 UNL 1 0.988 2.413 -0.949 1.00 0.00 H \nHETATM 45 H18 UNL 1 -1.349 3.451 -2.379 1.00 0.00 H \nHETATM 46 H19 UNL 1 -2.224 4.055 -0.958 1.00 0.00 H \nHETATM 47 H20 UNL 1 0.793 5.486 0.669 1.00 0.00 H \nHETATM 48 H21 UNL 1 -0.849 4.974 0.937 1.00 0.00 H \nHETATM 49 H22 UNL 1 1.667 3.431 1.070 1.00 0.00 H \nHETATM 50 H23 UNL 1 0.379 2.143 2.689 1.00 0.00 H \nHETATM 51 H24 UNL 1 -1.094 2.983 2.223 1.00 0.00 H \nCONECT 1 2 3 50 51\nCONECT 3 4 42 49\nCONECT 4 5 47 48\nCONECT 5 6 7 8\nCONECT 8 9 45 46\nCONECT 9 10 42\nCONECT 10 11 11 12\nCONECT 12 13 32 41\nCONECT 13 14 39 40\nCONECT 14 15 37 38\nCONECT 15 16 31\nCONECT 16 17 17 18 18\nCONECT 16 19\nCONECT 19 20 20 27\nCONECT 20 21 30\nCONECT 21 22 22 29\nCONECT 22 23 28\nCONECT 23 24 24 27\nCONECT 24 25\nCONECT 25 26\nCONECT 26 27 27\nCONECT 31 32 35 36\nCONECT 32 33 34\nCONECT 42 43 44\nEND\n", - "type": "blob" - } - ], - "kwargs": { - "defaultRepresentation": true, - "ext": "pdb" + "source": [ + "Check that torch is installed correctly and utilizing the GPU in the colab" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "gZt7BNi1e1PA", + "colab": { + "base_uri": "https://localhost:8080/", + "height": 53 + }, + "outputId": "a0e1832c-9c02-49aa-cff8-1339e6cdc889" + }, + "outputs": [ + { + "output_type": "stream", + "name": "stdout", + "text": [ + "True\n" + ] + }, + { + "output_type": "execute_result", + "data": { + "text/plain": [ + "'1.8.2'" + ], + "application/vnd.google.colaboratory.intrinsic+json": { + "type": "string" + } + }, + "metadata": {}, + "execution_count": 8 + } + ], + "source": [ + "import torch\n", + "print(torch.cuda.is_available())\n", + "torch.__version__" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "KLE7CqlfJNUO" + }, + "source": [ + "### Install Chemistry-specific Dependencies\n", + "\n", + "Install RDKit, a tool for working with and visualizing chemsitry in python (you use this to visualize the generate models later)." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "0CPv_NvehRz3", + "colab": { + "base_uri": "https://localhost:8080/" + }, + "outputId": "6ee0ae4e-4511-4816-de29-22b1c21d49bc" + }, + "outputs": [ + { + "output_type": "stream", + "name": "stdout", + "text": [ + "Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/\n", + "Collecting rdkit\n", + " Downloading rdkit-2022.3.5-cp37-cp37m-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (36.8 MB)\n", + "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m36.8/36.8 MB\u001b[0m \u001b[31m34.6 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", + "\u001b[?25hRequirement already satisfied: Pillow in /usr/local/lib/python3.7/site-packages (from rdkit) (9.2.0)\n", + "Requirement already satisfied: numpy in /usr/local/lib/python3.7/site-packages (from rdkit) (1.21.6)\n", + "Installing collected packages: rdkit\n", + "Successfully installed rdkit-2022.3.5\n", + "\u001b[33mWARNING: Running pip as the 'root' user can result in broken permissions and conflicting behaviour with the system package manager. 
It is recommended to use a virtual environment instead: https://pip.pypa.io/warnings/venv\u001b[0m\u001b[33m\n", + "\u001b[0m" + ] + } + ], + "source": [ + "!pip install rdkit" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "88GaDbDPxJ5I" + }, + "source": [ + "### Get viewer from nglview\n", + "\n", + "The model you will use outputs a position matrix tensor. This pytorch geometric data object will have many features (positions, known features, edge features -- all tensors).\n", + "The data we give to the model will also have a rdmol object (which can extract features to geometric if needed).\n", + "The rdmol in this object is a source of ground truth for the generated molecules.\n", + "\n", + "You will use one rendering function from nglviewer later!\n", + "\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "jcl8GCS2mz6t", + "colab": { + "base_uri": "https://localhost:8080/", + "height": 1000 + }, + "outputId": "99b5cc40-67bb-4d8e-faa0-47d7cb33e98f" + }, + "outputs": [ + { + "output_type": "stream", + "name": "stdout", + "text": [ + "Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/\n", + "Collecting nglview\n", + " Downloading nglview-3.0.3.tar.gz (5.7 MB)\n", + "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m5.7/5.7 MB\u001b[0m \u001b[31m91.2 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", + "\u001b[?25h Installing build dependencies ... \u001b[?25l\u001b[?25hdone\n", + " Getting requirements to build wheel ... \u001b[?25l\u001b[?25hdone\n", + " Preparing metadata (pyproject.toml) ... \u001b[?25l\u001b[?25hdone\n", + "Requirement already satisfied: numpy in /usr/local/lib/python3.7/site-packages (from nglview) (1.21.6)\n", + "Collecting jupyterlab-widgets\n", + " Downloading jupyterlab_widgets-3.0.3-py3-none-any.whl (384 kB)\n", + "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m384.1/384.1 kB\u001b[0m \u001b[31m40.6 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", + "\u001b[?25hCollecting ipywidgets>=7\n", + " Downloading ipywidgets-8.0.2-py3-none-any.whl (134 kB)\n", + "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m134.4/134.4 kB\u001b[0m \u001b[31m21.2 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", + "\u001b[?25hCollecting widgetsnbextension~=4.0\n", + " Downloading widgetsnbextension-4.0.3-py3-none-any.whl (2.0 MB)\n", + "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m2.0/2.0 MB\u001b[0m \u001b[31m84.4 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", + "\u001b[?25hCollecting ipython>=6.1.0\n", + " Downloading ipython-7.34.0-py3-none-any.whl (793 kB)\n", + "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m793.8/793.8 kB\u001b[0m \u001b[31m60.7 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", + "\u001b[?25hCollecting ipykernel>=4.5.1\n", + " Downloading ipykernel-6.16.0-py3-none-any.whl (138 kB)\n", + "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m138.4/138.4 kB\u001b[0m \u001b[31m20.9 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", + "\u001b[?25hCollecting traitlets>=4.3.1\n", + " Downloading traitlets-5.4.0-py3-none-any.whl (107 kB)\n", + "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m107.1/107.1 kB\u001b[0m \u001b[31m17.0 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", + "\u001b[?25hRequirement already satisfied: packaging in 
/usr/local/lib/python3.7/site-packages (from ipykernel>=4.5.1->ipywidgets>=7->nglview) (21.3)\n", + "Collecting pyzmq>=17\n", + " Downloading pyzmq-24.0.1-cp37-cp37m-manylinux_2_5_x86_64.manylinux1_x86_64.whl (1.1 MB)\n", + "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m1.1/1.1 MB\u001b[0m \u001b[31m68.4 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", + "\u001b[?25hCollecting matplotlib-inline>=0.1\n", + " Downloading matplotlib_inline-0.1.6-py3-none-any.whl (9.4 kB)\n", + "Collecting tornado>=6.1\n", + " Downloading tornado-6.2-cp37-abi3-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl (423 kB)\n", + "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m424.0/424.0 kB\u001b[0m \u001b[31m41.2 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", + "\u001b[?25hCollecting nest-asyncio\n", + " Downloading nest_asyncio-1.5.6-py3-none-any.whl (5.2 kB)\n", + "Collecting debugpy>=1.0\n", + " Downloading debugpy-1.6.3-cp37-cp37m-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_12_x86_64.manylinux2010_x86_64.whl (1.8 MB)\n", + "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m1.8/1.8 MB\u001b[0m \u001b[31m83.4 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", + "\u001b[?25hCollecting psutil\n", + " Downloading psutil-5.9.2-cp37-cp37m-manylinux_2_12_x86_64.manylinux2010_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl (281 kB)\n", + "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m281.3/281.3 kB\u001b[0m \u001b[31m33.1 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", + "\u001b[?25hCollecting jupyter-client>=6.1.12\n", + " Downloading jupyter_client-7.4.2-py3-none-any.whl (132 kB)\n", + "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m132.2/132.2 kB\u001b[0m \u001b[31m19.7 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", + "\u001b[?25hCollecting pickleshare\n", + " Downloading pickleshare-0.7.5-py2.py3-none-any.whl (6.9 kB)\n", + "Requirement already satisfied: setuptools>=18.5 in /usr/local/lib/python3.7/site-packages (from ipython>=6.1.0->ipywidgets>=7->nglview) (59.8.0)\n", + "Collecting backcall\n", + " Downloading backcall-0.2.0-py2.py3-none-any.whl (11 kB)\n", + "Collecting pexpect>4.3\n", + " Downloading pexpect-4.8.0-py2.py3-none-any.whl (59 kB)\n", + "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m59.0/59.0 kB\u001b[0m \u001b[31m7.3 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", + "\u001b[?25hCollecting pygments\n", + " Downloading Pygments-2.13.0-py3-none-any.whl (1.1 MB)\n", + "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m1.1/1.1 MB\u001b[0m \u001b[31m70.9 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", + "\u001b[?25hCollecting jedi>=0.16\n", + " Downloading jedi-0.18.1-py2.py3-none-any.whl (1.6 MB)\n", + "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m1.6/1.6 MB\u001b[0m \u001b[31m83.5 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", + "\u001b[?25hCollecting prompt-toolkit!=3.0.0,!=3.0.1,<3.1.0,>=2.0.0\n", + " Downloading prompt_toolkit-3.0.31-py3-none-any.whl (382 kB)\n", + "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m382.3/382.3 kB\u001b[0m \u001b[31m40.0 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", + "\u001b[?25hRequirement already satisfied: decorator in /usr/local/lib/python3.7/site-packages (from ipython>=6.1.0->ipywidgets>=7->nglview) (4.4.2)\n", + 
"Collecting parso<0.9.0,>=0.8.0\n", + " Downloading parso-0.8.3-py2.py3-none-any.whl (100 kB)\n", + "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m100.8/100.8 kB\u001b[0m \u001b[31m14.7 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", + "\u001b[?25hRequirement already satisfied: python-dateutil>=2.8.2 in /usr/local/lib/python3.7/site-packages (from jupyter-client>=6.1.12->ipykernel>=4.5.1->ipywidgets>=7->nglview) (2.8.2)\n", + "Collecting entrypoints\n", + " Downloading entrypoints-0.4-py3-none-any.whl (5.3 kB)\n", + "Collecting jupyter-core>=4.9.2\n", + " Downloading jupyter_core-4.11.1-py3-none-any.whl (88 kB)\n", + "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m88.4/88.4 kB\u001b[0m \u001b[31m14.0 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", + "\u001b[?25hCollecting ptyprocess>=0.5\n", + " Downloading ptyprocess-0.7.0-py2.py3-none-any.whl (13 kB)\n", + "Collecting wcwidth\n", + " Downloading wcwidth-0.2.5-py2.py3-none-any.whl (30 kB)\n", + "Requirement already satisfied: pyparsing!=3.0.5,>=2.0.2 in /usr/local/lib/python3.7/site-packages (from packaging->ipykernel>=4.5.1->ipywidgets>=7->nglview) (3.0.9)\n", + "Requirement already satisfied: six>=1.5 in /usr/local/lib/python3.7/site-packages (from python-dateutil>=2.8.2->jupyter-client>=6.1.12->ipykernel>=4.5.1->ipywidgets>=7->nglview) (1.16.0)\n", + "Building wheels for collected packages: nglview\n", + " Building wheel for nglview (pyproject.toml) ... \u001b[?25l\u001b[?25hdone\n", + " Created wheel for nglview: filename=nglview-3.0.3-py3-none-any.whl size=8057538 sha256=b7e1071bb91822e48515bf27f4e6b197c6e85e06b90912b3439edc8be1e29514\n", + " Stored in directory: /root/.cache/pip/wheels/01/0c/49/c6f79d8edba8fe89752bf20de2d99040bfa57db0548975c5d5\n", + "Successfully built nglview\n", + "Installing collected packages: wcwidth, ptyprocess, pickleshare, backcall, widgetsnbextension, traitlets, tornado, pyzmq, pygments, psutil, prompt-toolkit, pexpect, parso, nest-asyncio, jupyterlab-widgets, entrypoints, debugpy, matplotlib-inline, jupyter-core, jedi, jupyter-client, ipython, ipykernel, ipywidgets, nglview\n", + "Successfully installed backcall-0.2.0 debugpy-1.6.3 entrypoints-0.4 ipykernel-6.16.0 ipython-7.34.0 ipywidgets-8.0.2 jedi-0.18.1 jupyter-client-7.4.2 jupyter-core-4.11.1 jupyterlab-widgets-3.0.3 matplotlib-inline-0.1.6 nest-asyncio-1.5.6 nglview-3.0.3 parso-0.8.3 pexpect-4.8.0 pickleshare-0.7.5 prompt-toolkit-3.0.31 psutil-5.9.2 ptyprocess-0.7.0 pygments-2.13.0 pyzmq-24.0.1 tornado-6.2 traitlets-5.4.0 wcwidth-0.2.5 widgetsnbextension-4.0.3\n", + "\u001b[33mWARNING: Running pip as the 'root' user can result in broken permissions and conflicting behaviour with the system package manager. 
It is recommended to use a virtual environment instead: https://pip.pypa.io/warnings/venv\u001b[0m\u001b[33m\n", + "\u001b[0m" + ] }, - "methodName": "loadFile", - "reconstruc_color_scheme": false, - "target": "Stage", - "type": "call_method" - } + { + "output_type": "display_data", + "data": { + "application/vnd.colab-display-data+json": { + "pip_warning": { + "packages": [ + "pexpect", + "pickleshare", + "wcwidth" + ] + } + } + }, + "metadata": {} + } + ], + "source": [ + "!pip install nglview" + ] + }, + { + "cell_type": "markdown", + "source": [ + "## Create a diffusion model" + ], + "metadata": { + "id": "8t8_e_uVLdKB" + } + }, + { + "cell_type": "markdown", + "source": [ + "### Model class(es)" + ], + "metadata": { + "id": "G0rMncVtNSqU" + } + }, + { + "cell_type": "markdown", + "source": [ + "Imports" + ], + "metadata": { + "id": "L5FEXz5oXkzt" + } + }, + { + "cell_type": "code", + "source": [ + "# Model adapted from GeoDiff https://github.com/MinkaiXu/GeoDiff\n", + "# Model inspired by https://github.com/DeepGraphLearning/torchdrug/tree/master/torchdrug/models\n", + "from dataclasses import dataclass\n", + "from typing import Callable, Tuple, Union\n", + "\n", + "import numpy as np\n", + "import torch\n", + "import torch.nn.functional as F\n", + "from torch import Tensor, nn\n", + "from torch.nn import Embedding, Linear, Module, ModuleList, Sequential\n", + "\n", + "from torch_geometric.nn import MessagePassing, radius, radius_graph\n", + "from torch_geometric.typing import Adj, OptPairTensor, OptTensor, Size\n", + "from torch_geometric.utils import dense_to_sparse, to_dense_adj\n", + "from torch_scatter import scatter_add\n", + "from torch_sparse import SparseTensor, coalesce\n", + "\n", + "from diffusers.configuration_utils import ConfigMixin, register_to_config\n", + "from diffusers.modeling_utils import ModelMixin\n", + "from diffusers.utils import BaseOutput\n" + ], + "metadata": { + "id": "-3-P4w5sXkRU" + }, + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "markdown", + "source": [ + "Helper classes" + ], + "metadata": { + "id": "EzJQXPN_XrMX" + } + }, + { + "cell_type": "code", + "source": [ + "@dataclass\n", + "class MoleculeGNNOutput(BaseOutput):\n", + " \"\"\"\n", + " Args:\n", + " sample (`torch.Tensor` of shape `(batch_size, num_channels, height, width)`):\n", + " Hidden states output. Output of last layer of model.\n", + " \"\"\"\n", + "\n", + " sample: torch.Tensor\n", + "\n", + "\n", + "class MultiLayerPerceptron(nn.Module):\n", + " \"\"\"\n", + " Multi-layer Perceptron. 
Note there is no activation or dropout in the last layer.\n", + " Args:\n", + " input_dim (int): input dimension\n", + " hidden_dim (list of int): hidden dimensions\n", + " activation (str or function, optional): activation function\n", + " dropout (float, optional): dropout rate\n", + " \"\"\"\n", + "\n", + " def __init__(self, input_dim, hidden_dims, activation=\"relu\", dropout=0):\n", + " super(MultiLayerPerceptron, self).__init__()\n", + "\n", + " self.dims = [input_dim] + hidden_dims\n", + " if isinstance(activation, str):\n", + " self.activation = getattr(F, activation)\n", + " else:\n", + " print(f\"Warning, activation passed {activation} is not string and ignored\")\n", + " self.activation = None\n", + " if dropout > 0:\n", + " self.dropout = nn.Dropout(dropout)\n", + " else:\n", + " self.dropout = None\n", + "\n", + " self.layers = nn.ModuleList()\n", + " for i in range(len(self.dims) - 1):\n", + " self.layers.append(nn.Linear(self.dims[i], self.dims[i + 1]))\n", + "\n", + " def forward(self, x):\n", + " \"\"\"\"\"\"\n", + " for i, layer in enumerate(self.layers):\n", + " x = layer(x)\n", + " if i < len(self.layers) - 1:\n", + " if self.activation:\n", + " x = self.activation(x)\n", + " if self.dropout:\n", + " x = self.dropout(x)\n", + " return x\n", + "\n", + "\n", + "class ShiftedSoftplus(torch.nn.Module):\n", + " def __init__(self):\n", + " super(ShiftedSoftplus, self).__init__()\n", + " self.shift = torch.log(torch.tensor(2.0)).item()\n", + "\n", + " def forward(self, x):\n", + " return F.softplus(x) - self.shift\n", + "\n", + "\n", + "class CFConv(MessagePassing):\n", + " def __init__(self, in_channels, out_channels, num_filters, mlp, cutoff, smooth):\n", + " super(CFConv, self).__init__(aggr=\"add\")\n", + " self.lin1 = Linear(in_channels, num_filters, bias=False)\n", + " self.lin2 = Linear(num_filters, out_channels)\n", + " self.nn = mlp\n", + " self.cutoff = cutoff\n", + " self.smooth = smooth\n", + "\n", + " self.reset_parameters()\n", + "\n", + " def reset_parameters(self):\n", + " torch.nn.init.xavier_uniform_(self.lin1.weight)\n", + " torch.nn.init.xavier_uniform_(self.lin2.weight)\n", + " self.lin2.bias.data.fill_(0)\n", + "\n", + " def forward(self, x, edge_index, edge_length, edge_attr):\n", + " if self.smooth:\n", + " C = 0.5 * (torch.cos(edge_length * np.pi / self.cutoff) + 1.0)\n", + " C = C * (edge_length <= self.cutoff) * (edge_length >= 0.0) # Modification: cutoff\n", + " else:\n", + " C = (edge_length <= self.cutoff).float()\n", + " W = self.nn(edge_attr) * C.view(-1, 1)\n", + "\n", + " x = self.lin1(x)\n", + " x = self.propagate(edge_index, x=x, W=W)\n", + " x = self.lin2(x)\n", + " return x\n", + "\n", + " def message(self, x_j: torch.Tensor, W) -> torch.Tensor:\n", + " return x_j * W\n", + "\n", + "\n", + "class InteractionBlock(torch.nn.Module):\n", + " def __init__(self, hidden_channels, num_gaussians, num_filters, cutoff, smooth):\n", + " super(InteractionBlock, self).__init__()\n", + " mlp = Sequential(\n", + " Linear(num_gaussians, num_filters),\n", + " ShiftedSoftplus(),\n", + " Linear(num_filters, num_filters),\n", + " )\n", + " self.conv = CFConv(hidden_channels, hidden_channels, num_filters, mlp, cutoff, smooth)\n", + " self.act = ShiftedSoftplus()\n", + " self.lin = Linear(hidden_channels, hidden_channels)\n", + "\n", + " def forward(self, x, edge_index, edge_length, edge_attr):\n", + " x = self.conv(x, edge_index, edge_length, edge_attr)\n", + " x = self.act(x)\n", + " x = self.lin(x)\n", + " return x\n", + "\n", + "\n", + "class 
SchNetEncoder(Module):\n", + " def __init__(\n", + " self, hidden_channels=128, num_filters=128, num_interactions=6, edge_channels=100, cutoff=10.0, smooth=False\n", + " ):\n", + " super().__init__()\n", + "\n", + " self.hidden_channels = hidden_channels\n", + " self.num_filters = num_filters\n", + " self.num_interactions = num_interactions\n", + " self.cutoff = cutoff\n", + "\n", + " self.embedding = Embedding(100, hidden_channels, max_norm=10.0)\n", + "\n", + " self.interactions = ModuleList()\n", + " for _ in range(num_interactions):\n", + " block = InteractionBlock(hidden_channels, edge_channels, num_filters, cutoff, smooth)\n", + " self.interactions.append(block)\n", + "\n", + " def forward(self, z, edge_index, edge_length, edge_attr, embed_node=True):\n", + " if embed_node:\n", + " assert z.dim() == 1 and z.dtype == torch.long\n", + " h = self.embedding(z)\n", + " else:\n", + " h = z\n", + " for interaction in self.interactions:\n", + " h = h + interaction(h, edge_index, edge_length, edge_attr)\n", + "\n", + " return h\n", + "\n", + "\n", + "class GINEConv(MessagePassing):\n", + " \"\"\"\n", + " Custom class of the graph isomorphism operator from the \"How Powerful are Graph Neural Networks?\n", + " https://arxiv.org/abs/1810.00826 paper. Note that this implementation has the added option of a custom activation.\n", + " \"\"\"\n", + "\n", + " def __init__(self, mlp: Callable, eps: float = 0.0, train_eps: bool = False, activation=\"softplus\", **kwargs):\n", + " super(GINEConv, self).__init__(aggr=\"add\", **kwargs)\n", + " self.nn = mlp\n", + " self.initial_eps = eps\n", + "\n", + " if isinstance(activation, str):\n", + " self.activation = getattr(F, activation)\n", + " else:\n", + " self.activation = None\n", + "\n", + " if train_eps:\n", + " self.eps = torch.nn.Parameter(torch.Tensor([eps]))\n", + " else:\n", + " self.register_buffer(\"eps\", torch.Tensor([eps]))\n", + "\n", + " def forward(\n", + " self, x: Union[Tensor, OptPairTensor], edge_index: Adj, edge_attr: OptTensor = None, size: Size = None\n", + " ) -> torch.Tensor:\n", + " \"\"\"\"\"\"\n", + " if isinstance(x, torch.Tensor):\n", + " x: OptPairTensor = (x, x)\n", + "\n", + " # Node and edge feature dimensionalites need to match.\n", + " if isinstance(edge_index, torch.Tensor):\n", + " assert edge_attr is not None\n", + " assert x[0].size(-1) == edge_attr.size(-1)\n", + " elif isinstance(edge_index, SparseTensor):\n", + " assert x[0].size(-1) == edge_index.size(-1)\n", + "\n", + " # propagate_type: (x: OptPairTensor, edge_attr: OptTensor)\n", + " out = self.propagate(edge_index, x=x, edge_attr=edge_attr, size=size)\n", + "\n", + " x_r = x[1]\n", + " if x_r is not None:\n", + " out += (1 + self.eps) * x_r\n", + "\n", + " return self.nn(out)\n", + "\n", + " def message(self, x_j: torch.Tensor, edge_attr: torch.Tensor) -> torch.Tensor:\n", + " if self.activation:\n", + " return self.activation(x_j + edge_attr)\n", + " else:\n", + " return x_j + edge_attr\n", + "\n", + " def __repr__(self):\n", + " return \"{}(nn={})\".format(self.__class__.__name__, self.nn)\n", + "\n", + "\n", + "class GINEncoder(torch.nn.Module):\n", + " def __init__(self, hidden_dim, num_convs=3, activation=\"relu\", short_cut=True, concat_hidden=False):\n", + " super().__init__()\n", + "\n", + " self.hidden_dim = hidden_dim\n", + " self.num_convs = num_convs\n", + " self.short_cut = short_cut\n", + " self.concat_hidden = concat_hidden\n", + " self.node_emb = nn.Embedding(100, hidden_dim)\n", + "\n", + " if isinstance(activation, str):\n", + " 
self.activation = getattr(F, activation)\n", + " else:\n", + " self.activation = None\n", + "\n", + " self.convs = nn.ModuleList()\n", + " for i in range(self.num_convs):\n", + " self.convs.append(\n", + " GINEConv(\n", + " MultiLayerPerceptron(hidden_dim, [hidden_dim, hidden_dim], activation=activation),\n", + " activation=activation,\n", + " )\n", + " )\n", + "\n", + " def forward(self, z, edge_index, edge_attr):\n", + " \"\"\"\n", + " Input:\n", + " data: (torch_geometric.data.Data): batched graph edge_index: bond indices of the original graph (num_node,\n", + " hidden) edge_attr: edge feature tensor with shape (num_edge, hidden)\n", + " Output:\n", + " node_feature: graph feature\n", + " \"\"\"\n", + "\n", + " node_attr = self.node_emb(z) # (num_node, hidden)\n", + "\n", + " hiddens = []\n", + " conv_input = node_attr # (num_node, hidden)\n", + "\n", + " for conv_idx, conv in enumerate(self.convs):\n", + " hidden = conv(conv_input, edge_index, edge_attr)\n", + " if conv_idx < len(self.convs) - 1 and self.activation is not None:\n", + " hidden = self.activation(hidden)\n", + " assert hidden.shape == conv_input.shape\n", + " if self.short_cut and hidden.shape == conv_input.shape:\n", + " hidden += conv_input\n", + "\n", + " hiddens.append(hidden)\n", + " conv_input = hidden\n", + "\n", + " if self.concat_hidden:\n", + " node_feature = torch.cat(hiddens, dim=-1)\n", + " else:\n", + " node_feature = hiddens[-1]\n", + "\n", + " return node_feature\n", + "\n", + "\n", + "class MLPEdgeEncoder(Module):\n", + " def __init__(self, hidden_dim=100, activation=\"relu\"):\n", + " super().__init__()\n", + " self.hidden_dim = hidden_dim\n", + " self.bond_emb = Embedding(100, embedding_dim=self.hidden_dim)\n", + " self.mlp = MultiLayerPerceptron(1, [self.hidden_dim, self.hidden_dim], activation=activation)\n", + "\n", + " @property\n", + " def out_channels(self):\n", + " return self.hidden_dim\n", + "\n", + " def forward(self, edge_length, edge_type):\n", + " \"\"\"\n", + " Input:\n", + " edge_length: The length of edges, shape=(E, 1). edge_type: The type pf edges, shape=(E,)\n", + " Returns:\n", + " edge_attr: The representation of edges. (E, 2 * num_gaussians)\n", + " \"\"\"\n", + " d_emb = self.mlp(edge_length) # (num_edge, hidden_dim)\n", + " edge_attr = self.bond_emb(edge_type) # (num_edge, hidden_dim)\n", + " return d_emb * edge_attr # (num_edge, hidden)\n", + "\n", + "\n", + "def assemble_atom_pair_feature(node_attr, edge_index, edge_attr):\n", + " h_row, h_col = node_attr[edge_index[0]], node_attr[edge_index[1]]\n", + " h_pair = torch.cat([h_row * h_col, edge_attr], dim=-1) # (E, 2H)\n", + " return h_pair\n", + "\n", + "\n", + "def _extend_graph_order(num_nodes, edge_index, edge_type, order=3):\n", + " \"\"\"\n", + " Args:\n", + " num_nodes: Number of atoms.\n", + " edge_index: Bond indices of the original graph.\n", + " edge_type: Bond types of the original graph.\n", + " order: Extension order.\n", + " Returns:\n", + " new_edge_index: Extended edge indices. 
new_edge_type: Extended edge types.\n", + " \"\"\"\n", + "\n", + " def binarize(x):\n", + " return torch.where(x > 0, torch.ones_like(x), torch.zeros_like(x))\n", + "\n", + " def get_higher_order_adj_matrix(adj, order):\n", + " \"\"\"\n", + " Args:\n", + " adj: (N, N)\n", + " type_mat: (N, N)\n", + " Returns:\n", + " Following attributes will be updated:\n", + " - edge_index\n", + " - edge_type\n", + " Following attributes will be added to the data object:\n", + " - bond_edge_index: Original edge_index.\n", + " \"\"\"\n", + " adj_mats = [\n", + " torch.eye(adj.size(0), dtype=torch.long, device=adj.device),\n", + " binarize(adj + torch.eye(adj.size(0), dtype=torch.long, device=adj.device)),\n", + " ]\n", + "\n", + " for i in range(2, order + 1):\n", + " adj_mats.append(binarize(adj_mats[i - 1] @ adj_mats[1]))\n", + " order_mat = torch.zeros_like(adj)\n", + "\n", + " for i in range(1, order + 1):\n", + " order_mat += (adj_mats[i] - adj_mats[i - 1]) * i\n", + "\n", + " return order_mat\n", + "\n", + " num_types = 22\n", + " # given from len(BOND_TYPES), where BOND_TYPES = {t: i for i, t in enumerate(BT.names.values())}\n", + " # from rdkit.Chem.rdchem import BondType as BT\n", + " N = num_nodes\n", + " adj = to_dense_adj(edge_index).squeeze(0)\n", + " adj_order = get_higher_order_adj_matrix(adj, order) # (N, N)\n", + "\n", + " type_mat = to_dense_adj(edge_index, edge_attr=edge_type).squeeze(0) # (N, N)\n", + " type_highorder = torch.where(adj_order > 1, num_types + adj_order - 1, torch.zeros_like(adj_order))\n", + " assert (type_mat * type_highorder == 0).all()\n", + " type_new = type_mat + type_highorder\n", + "\n", + " new_edge_index, new_edge_type = dense_to_sparse(type_new)\n", + " _, edge_order = dense_to_sparse(adj_order)\n", + "\n", + " # data.bond_edge_index = data.edge_index # Save original edges\n", + " new_edge_index, new_edge_type = coalesce(new_edge_index, new_edge_type.long(), N, N) # modify data\n", + "\n", + " return new_edge_index, new_edge_type\n", + "\n", + "\n", + "def _extend_to_radius_graph(pos, edge_index, edge_type, cutoff, batch, unspecified_type_number=0, is_sidechain=None):\n", + " assert edge_type.dim() == 1\n", + " N = pos.size(0)\n", + "\n", + " bgraph_adj = torch.sparse.LongTensor(edge_index, edge_type, torch.Size([N, N]))\n", + "\n", + " if is_sidechain is None:\n", + " rgraph_edge_index = radius_graph(pos, r=cutoff, batch=batch) # (2, E_r)\n", + " else:\n", + " # fetch sidechain and its batch index\n", + " is_sidechain = is_sidechain.bool()\n", + " dummy_index = torch.arange(pos.size(0), device=pos.device)\n", + " sidechain_pos = pos[is_sidechain]\n", + " sidechain_index = dummy_index[is_sidechain]\n", + " sidechain_batch = batch[is_sidechain]\n", + "\n", + " assign_index = radius(x=pos, y=sidechain_pos, r=cutoff, batch_x=batch, batch_y=sidechain_batch)\n", + " r_edge_index_x = assign_index[1]\n", + " r_edge_index_y = assign_index[0]\n", + " r_edge_index_y = sidechain_index[r_edge_index_y]\n", + "\n", + " rgraph_edge_index1 = torch.stack((r_edge_index_x, r_edge_index_y)) # (2, E)\n", + " rgraph_edge_index2 = torch.stack((r_edge_index_y, r_edge_index_x)) # (2, E)\n", + " rgraph_edge_index = torch.cat((rgraph_edge_index1, rgraph_edge_index2), dim=-1) # (2, 2E)\n", + " # delete self loop\n", + " rgraph_edge_index = rgraph_edge_index[:, (rgraph_edge_index[0] != rgraph_edge_index[1])]\n", + "\n", + " rgraph_adj = torch.sparse.LongTensor(\n", + " rgraph_edge_index,\n", + " torch.ones(rgraph_edge_index.size(1)).long().to(pos.device) * unspecified_type_number,\n", + " 
torch.Size([N, N]),\n", + " )\n", + "\n", + " composed_adj = (bgraph_adj + rgraph_adj).coalesce() # Sparse (N, N, T)\n", + "\n", + " new_edge_index = composed_adj.indices()\n", + " new_edge_type = composed_adj.values().long()\n", + "\n", + " return new_edge_index, new_edge_type\n", + "\n", + "\n", + "def extend_graph_order_radius(\n", + " num_nodes,\n", + " pos,\n", + " edge_index,\n", + " edge_type,\n", + " batch,\n", + " order=3,\n", + " cutoff=10.0,\n", + " extend_order=True,\n", + " extend_radius=True,\n", + " is_sidechain=None,\n", + "):\n", + " if extend_order:\n", + " edge_index, edge_type = _extend_graph_order(\n", + " num_nodes=num_nodes, edge_index=edge_index, edge_type=edge_type, order=order\n", + " )\n", + "\n", + " if extend_radius:\n", + " edge_index, edge_type = _extend_to_radius_graph(\n", + " pos=pos, edge_index=edge_index, edge_type=edge_type, cutoff=cutoff, batch=batch, is_sidechain=is_sidechain\n", + " )\n", + "\n", + " return edge_index, edge_type\n", + "\n", + "\n", + "def get_distance(pos, edge_index):\n", + " return (pos[edge_index[0]] - pos[edge_index[1]]).norm(dim=-1)\n", + "\n", + "\n", + "def graph_field_network(score_d, pos, edge_index, edge_length):\n", + " \"\"\"\n", + " Transformation to make the epsilon predicted from the diffusion model roto-translational equivariant. See equations\n", + " 5-7 of the GeoDiff Paper https://arxiv.org/pdf/2203.02923.pdf\n", + " \"\"\"\n", + " N = pos.size(0)\n", + " dd_dr = (1.0 / edge_length) * (pos[edge_index[0]] - pos[edge_index[1]]) # (E, 3)\n", + " score_pos = scatter_add(dd_dr * score_d, edge_index[0], dim=0, dim_size=N) + scatter_add(\n", + " -dd_dr * score_d, edge_index[1], dim=0, dim_size=N\n", + " ) # (N, 3)\n", + " return score_pos\n", + "\n", + "\n", + "def clip_norm(vec, limit, p=2):\n", + " norm = torch.norm(vec, dim=-1, p=2, keepdim=True)\n", + " denom = torch.where(norm > limit, limit / norm, torch.ones_like(norm))\n", + " return vec * denom\n", + "\n", + "\n", + "def is_local_edge(edge_type):\n", + " return edge_type > 0\n" + ], + "metadata": { + "id": "oR1Y56QiLY90" + }, + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "markdown", + "source": [ + "Main model class!" 
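Before the main model class, it is worth restating what `graph_field_network` (defined in the helper cell above and called from the model's `forward`) computes; this is only a compact rewrite of its docstring and of equations 5-7 of the GeoDiff paper. For an edge $(i, j)$ with length $d_{ij} = \lVert \mathbf{x}_i - \mathbf{x}_j \rVert$ and predicted scalar score $s_{ij}$,

$$
\frac{\partial d_{ij}}{\partial \mathbf{x}_i} = \frac{\mathbf{x}_i - \mathbf{x}_j}{d_{ij}},
\qquad
\mathbf{s}_i = \sum_{j \in \mathcal{N}(i)} s_{ij}\,\frac{\mathbf{x}_i - \mathbf{x}_j}{d_{ij}},
$$

which is exactly what the pair of `scatter_add` calls accumulates: one signed contribution for each endpoint of every edge. Because $\mathbf{s}_i$ is a weighted sum of coordinate differences, it rotates with the molecule and is unchanged by translation, i.e. the per-atom update is roto-translationally equivariant.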
+ ], + "metadata": { + "id": "QWrHJFcYXyUB" + } + }, + { + "cell_type": "code", + "source": [ + "class MoleculeGNN(ModelMixin, ConfigMixin):\n", + " @register_to_config\n", + " def __init__(\n", + " self,\n", + " hidden_dim=128,\n", + " num_convs=6,\n", + " num_convs_local=4,\n", + " cutoff=10.0,\n", + " mlp_act=\"relu\",\n", + " edge_order=3,\n", + " edge_encoder=\"mlp\",\n", + " smooth_conv=True,\n", + " ):\n", + " super().__init__()\n", + " self.cutoff = cutoff\n", + " self.edge_encoder = edge_encoder\n", + " self.edge_order = edge_order\n", + "\n", + " \"\"\"\n", + " edge_encoder: Takes both edge type and edge length as input and outputs a vector [Note]: node embedding is done\n", + " in SchNetEncoder\n", + " \"\"\"\n", + " self.edge_encoder_global = MLPEdgeEncoder(hidden_dim, mlp_act) # get_edge_encoder(config)\n", + " self.edge_encoder_local = MLPEdgeEncoder(hidden_dim, mlp_act) # get_edge_encoder(config)\n", + "\n", + " \"\"\"\n", + " The graph neural network that extracts node-wise features.\n", + " \"\"\"\n", + " self.encoder_global = SchNetEncoder(\n", + " hidden_channels=hidden_dim,\n", + " num_filters=hidden_dim,\n", + " num_interactions=num_convs,\n", + " edge_channels=self.edge_encoder_global.out_channels,\n", + " cutoff=cutoff,\n", + " smooth=smooth_conv,\n", + " )\n", + " self.encoder_local = GINEncoder(\n", + " hidden_dim=hidden_dim,\n", + " num_convs=num_convs_local,\n", + " )\n", + "\n", + " \"\"\"\n", + " `output_mlp` takes a mixture of two nodewise features and edge features as input and outputs\n", + " gradients w.r.t. edge_length (out_dim = 1).\n", + " \"\"\"\n", + " self.grad_global_dist_mlp = MultiLayerPerceptron(\n", + " 2 * hidden_dim, [hidden_dim, hidden_dim // 2, 1], activation=mlp_act\n", + " )\n", + "\n", + " self.grad_local_dist_mlp = MultiLayerPerceptron(\n", + " 2 * hidden_dim, [hidden_dim, hidden_dim // 2, 1], activation=mlp_act\n", + " )\n", + "\n", + " \"\"\"\n", + " Incorporate parameters together\n", + " \"\"\"\n", + " self.model_global = nn.ModuleList([self.edge_encoder_global, self.encoder_global, self.grad_global_dist_mlp])\n", + " self.model_local = nn.ModuleList([self.edge_encoder_local, self.encoder_local, self.grad_local_dist_mlp])\n", + "\n", + " def _forward(\n", + " self,\n", + " atom_type,\n", + " pos,\n", + " bond_index,\n", + " bond_type,\n", + " batch,\n", + " time_step, # NOTE, model trained without timestep performed best\n", + " edge_index=None,\n", + " edge_type=None,\n", + " edge_length=None,\n", + " return_edges=False,\n", + " extend_order=True,\n", + " extend_radius=True,\n", + " is_sidechain=None,\n", + " ):\n", + " \"\"\"\n", + " Args:\n", + " atom_type: Types of atoms, (N, ).\n", + " bond_index: Indices of bonds (not extended, not radius-graph), (2, E).\n", + " bond_type: Bond types, (E, ).\n", + " batch: Node index to graph index, (N, ).\n", + " \"\"\"\n", + " N = atom_type.size(0)\n", + " if edge_index is None or edge_type is None or edge_length is None:\n", + " edge_index, edge_type = extend_graph_order_radius(\n", + " num_nodes=N,\n", + " pos=pos,\n", + " edge_index=bond_index,\n", + " edge_type=bond_type,\n", + " batch=batch,\n", + " order=self.edge_order,\n", + " cutoff=self.cutoff,\n", + " extend_order=extend_order,\n", + " extend_radius=extend_radius,\n", + " is_sidechain=is_sidechain,\n", + " )\n", + " edge_length = get_distance(pos, edge_index).unsqueeze(-1) # (E, 1)\n", + " local_edge_mask = is_local_edge(edge_type) # (E, )\n", + "\n", + " # with the parameterization of NCSNv2\n", + " # DDPM loss implicit handle the 
noise variance scale conditioning\n", + " sigma_edge = torch.ones(size=(edge_index.size(1), 1), device=pos.device) # (E, 1)\n", + "\n", + " # Encoding global\n", + " edge_attr_global = self.edge_encoder_global(edge_length=edge_length, edge_type=edge_type) # Embed edges\n", + "\n", + " # Global\n", + " node_attr_global = self.encoder_global(\n", + " z=atom_type,\n", + " edge_index=edge_index,\n", + " edge_length=edge_length,\n", + " edge_attr=edge_attr_global,\n", + " )\n", + " # Assemble pairwise features\n", + " h_pair_global = assemble_atom_pair_feature(\n", + " node_attr=node_attr_global,\n", + " edge_index=edge_index,\n", + " edge_attr=edge_attr_global,\n", + " ) # (E_global, 2H)\n", + " # Invariant features of edges (radius graph, global)\n", + " edge_inv_global = self.grad_global_dist_mlp(h_pair_global) * (1.0 / sigma_edge) # (E_global, 1)\n", + "\n", + " # Encoding local\n", + " edge_attr_local = self.edge_encoder_global(edge_length=edge_length, edge_type=edge_type) # Embed edges\n", + " # edge_attr += temb_edge\n", + "\n", + " # Local\n", + " node_attr_local = self.encoder_local(\n", + " z=atom_type,\n", + " edge_index=edge_index[:, local_edge_mask],\n", + " edge_attr=edge_attr_local[local_edge_mask],\n", + " )\n", + " # Assemble pairwise features\n", + " h_pair_local = assemble_atom_pair_feature(\n", + " node_attr=node_attr_local,\n", + " edge_index=edge_index[:, local_edge_mask],\n", + " edge_attr=edge_attr_local[local_edge_mask],\n", + " ) # (E_local, 2H)\n", + "\n", + " # Invariant features of edges (bond graph, local)\n", + " if isinstance(sigma_edge, torch.Tensor):\n", + " edge_inv_local = self.grad_local_dist_mlp(h_pair_local) * (\n", + " 1.0 / sigma_edge[local_edge_mask]\n", + " ) # (E_local, 1)\n", + " else:\n", + " edge_inv_local = self.grad_local_dist_mlp(h_pair_local) * (1.0 / sigma_edge) # (E_local, 1)\n", + "\n", + " if return_edges:\n", + " return edge_inv_global, edge_inv_local, edge_index, edge_type, edge_length, local_edge_mask\n", + " else:\n", + " return edge_inv_global, edge_inv_local\n", + "\n", + " def forward(\n", + " self,\n", + " sample,\n", + " timestep: Union[torch.Tensor, float, int],\n", + " return_dict: bool = True,\n", + " sigma=1.0,\n", + " global_start_sigma=0.5,\n", + " w_global=1.0,\n", + " extend_order=False,\n", + " extend_radius=True,\n", + " clip_local=None,\n", + " clip_global=1000.0,\n", + " ) -> Union[MoleculeGNNOutput, Tuple]:\n", + " r\"\"\"\n", + " Args:\n", + " sample: packed torch geometric object\n", + " timestep (`torch.Tensor` or `float` or `int): TODO verify type and shape (batch) timesteps\n", + " return_dict (`bool`, *optional*, defaults to `True`):\n", + " Whether or not to return a [`~models.molecule_gnn.MoleculeGNNOutput`] instead of a plain tuple.\n", + " Returns:\n", + " [`~models.molecule_gnn.MoleculeGNNOutput`] or `tuple`: [`~models.molecule_gnn.MoleculeGNNOutput`] if\n", + " `return_dict` is True, otherwise a `tuple`. 
When returning a tuple, the first element is the sample tensor.\n", + " \"\"\"\n", + "\n", + " # unpack sample\n", + " atom_type = sample.atom_type\n", + " bond_index = sample.edge_index\n", + " bond_type = sample.edge_type\n", + " num_graphs = sample.num_graphs\n", + " pos = sample.pos\n", + "\n", + " timesteps = torch.full(size=(num_graphs,), fill_value=timestep, dtype=torch.long, device=pos.device)\n", + "\n", + " edge_inv_global, edge_inv_local, edge_index, edge_type, edge_length, local_edge_mask = self._forward(\n", + " atom_type=atom_type,\n", + " pos=sample.pos,\n", + " bond_index=bond_index,\n", + " bond_type=bond_type,\n", + " batch=sample.batch,\n", + " time_step=timesteps,\n", + " return_edges=True,\n", + " extend_order=extend_order,\n", + " extend_radius=extend_radius,\n", + " ) # (E_global, 1), (E_local, 1)\n", + "\n", + " # Important equation in the paper for equivariant features - eqns 5-7 of GeoDiff\n", + " node_eq_local = graph_field_network(\n", + " edge_inv_local, pos, edge_index[:, local_edge_mask], edge_length[local_edge_mask]\n", + " )\n", + " if clip_local is not None:\n", + " node_eq_local = clip_norm(node_eq_local, limit=clip_local)\n", + "\n", + " # Global\n", + " if sigma < global_start_sigma:\n", + " edge_inv_global = edge_inv_global * (1 - local_edge_mask.view(-1, 1).float())\n", + " node_eq_global = graph_field_network(edge_inv_global, pos, edge_index, edge_length)\n", + " node_eq_global = clip_norm(node_eq_global, limit=clip_global)\n", + " else:\n", + " node_eq_global = 0\n", + "\n", + " # Sum\n", + " eps_pos = node_eq_local + node_eq_global * w_global\n", + "\n", + " if not return_dict:\n", + " return (-eps_pos,)\n", + "\n", + " return MoleculeGNNOutput(sample=torch.Tensor(-eps_pos).to(pos.device))" ], - "_ngl_original_stage_parameters": { - "ambientColor": 14540253, - "ambientIntensity": 0.2, - "backgroundColor": "white", - "cameraEyeSep": 0.3, - "cameraFov": 40, - "cameraType": "perspective", - "clipDist": 10, - "clipFar": 100, - "clipNear": 0, - "fogFar": 100, - "fogNear": 50, - "hoverTimeout": 0, - "impostor": true, - "lightColor": 14540253, - "lightIntensity": 1, - "mousePreset": "default", - "panSpeed": 1, - "quality": "medium", - "rotateSpeed": 2, - "sampleLevel": 0, - "tooltip": true, - "workerDefault": true, - "zoomSpeed": 1.2 + "metadata": { + "id": "MCeZA1qQXzoK" + }, + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "CCIrPYSJj9wd" + }, + "source": [ + "### Load pretrained model" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "YdrAr6Ch--Ab" + }, + "source": [ + "#### Load a model\n", + "The model used is a design an\n", + "equivariant convolutional layer, named graph field network (GFN).\n", + "\n", + "The warning about `betas` and `alphas` can be ignored, those were moved to the scheduler." 
+ ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "DyCo0nsqjbml", + "colab": { + "base_uri": "https://localhost:8080/", + "height": 172, + "referenced_widgets": [ + "d90f304e9560472eacfbdd11e46765eb", + "1c6246f15b654f4daa11c9bcf997b78c", + "c2321b3bff6f490ca12040a20308f555", + "b7feb522161f4cf4b7cc7c1a078ff12d", + "e2d368556e494ae7ae4e2e992af2cd4f", + "bbef741e76ec41b7ab7187b487a383df", + "561f742d418d4721b0670cc8dd62e22c", + "872915dd1bb84f538c44e26badabafdd", + "d022575f1fa2446d891650897f187b4d", + "fdc393f3468c432aa0ada05e238a5436", + "2c9362906e4b40189f16d14aa9a348da", + "6010fc8daa7a44d5aec4b830ec2ebaa1", + "7e0bb1b8d65249d3974200686b193be2", + "ba98aa6d6a884e4ab8bbb5dfb5e4cf7a", + "6526646be5ed415c84d1245b040e629b", + "24d31fc3576e43dd9f8301d2ef3a37ab", + "2918bfaadc8d4b1a9832522c40dfefb8", + "a4bfdca35cc54dae8812720f1b276a08", + "e4901541199b45c6a18824627692fc39", + "f915cf874246446595206221e900b2fe", + "a9e388f22a9742aaaf538e22575c9433", + "42f6c3db29d7484ba6b4f73590abd2f4" + ] + }, + "outputId": "d6bce9d5-c51e-43a4-e680-e1e81bdfaf45" }, - "_ngl_repr_dict": { - "0": { - "0": { - "params": { - "aspectRatio": 1.5, - "assembly": "default", - "bondScale": 0.3, - "bondSpacing": 0.75, - "clipCenter": { - "x": 0, - "y": 0, - "z": 0 + "outputs": [ + { + "output_type": "display_data", + "data": { + "text/plain": [ + "Downloading: 0%| | 0.00/3.27M [00:00] 124.78K 180KB/s in 0.7s \n", + "\n", + "2022-10-12 18:32:20 (180 KB/s) - ‘molecules.pkl’ saved [127774/127774]\n", + "\n" + ] } - }, - "1": { - "0": { - "params": { - "aspectRatio": 1.5, - "assembly": "default", - "bondScale": 0.3, - "bondSpacing": 0.75, - "clipCenter": { - "x": 0, - "y": 0, - "z": 0 + ], + "source": [ + "import torch\n", + "import numpy as np\n", + "\n", + "!wget https://huggingface.co/datasets/fusing/geodiff-example-data/resolve/main/data/molecules.pkl\n", + "dataset = torch.load('/content/molecules.pkl')" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "QZcmy1EvKQRk" + }, + "source": [ + "Print out one entry of the dataset, it contains molecular formulas, atom types, positions, and more." 
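Beyond the raw repr, here is a short optional sketch (assuming the download cell above has run) for pulling out a few of those fields by name:

```python
data = dataset[0]
print(data.smiles)                              # 2D structure as a SMILES string
print(data.atom_type.shape, data.pos.shape)     # per-atom types and a stored 3D conformation
# number of reference conformations packed into pos_ref -- the same expression
# the sampling loop later uses to pick num_samples
print(data.pos_ref.shape[0] // data.num_nodes)
```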
+ ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "JVjz6iH_H6Eh", + "colab": { + "base_uri": "https://localhost:8080/" + }, + "outputId": "898cb0cf-a0b3-411b-fd4c-bea1fbfd17fe" + }, + "outputs": [ + { + "output_type": "execute_result", + "data": { + "text/plain": [ + "Data(atom_type=[51], bond_edge_index=[2, 108], edge_index=[2, 598], edge_order=[598], edge_type=[598], idx=[1], is_bond=[598], num_nodes_per_graph=[1], num_pos_ref=[1], nx=, pos=[51, 3], pos_ref=[255, 3], rdmol=, smiles=\"CC1CCCN(C(=O)C2CCN(S(=O)(=O)c3cccc4nonc34)CC2)C1\")" + ] }, - "clipNear": 0, - "clipRadius": 0, - "colorMode": "hcl", - "colorReverse": false, - "colorScale": "", - "colorScheme": "element", - "colorValue": 9474192, - "cylinderOnly": false, - "defaultAssembly": "", - "depthWrite": true, - "diffuse": 16777215, - "diffuseInterior": false, - "disableImpostor": false, - "disablePicking": false, - "flatShaded": false, - "interiorColor": 2236962, - "interiorDarkening": 0, - "lazy": false, - "lineOnly": false, - "linewidth": 2, - "matrix": { - "elements": [ - 1, - 0, - 0, - 0, - 0, - 1, - 0, - 0, - 0, - 0, - 1, - 0, - 0, - 0, - 0, - 1 - ] + "metadata": {}, + "execution_count": 20 + } + ], + "source": [ + "dataset[0]" + ] + }, + { + "cell_type": "markdown", + "source": [ + "## Run the diffusion process" + ], + "metadata": { + "id": "vHNiZAUxNgoy" + } + }, + { + "cell_type": "markdown", + "metadata": { + "id": "jZ1KZrxKqENg" + }, + "source": [ + "#### Helper Functions" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "s240tYueqKKf" + }, + "outputs": [], + "source": [ + "from torch_geometric.data import Data, Batch\n", + "from torch_scatter import scatter_add, scatter_mean\n", + "from tqdm import tqdm\n", + "import copy\n", + "import os\n", + "\n", + "def repeat_data(data: Data, num_repeat) -> Batch:\n", + " datas = [copy.deepcopy(data) for i in range(num_repeat)]\n", + " return Batch.from_data_list(datas)\n", + "\n", + "def repeat_batch(batch: Batch, num_repeat) -> Batch:\n", + " datas = batch.to_data_list()\n", + " new_data = []\n", + " for i in range(num_repeat):\n", + " new_data += copy.deepcopy(datas)\n", + " return Batch.from_data_list(new_data)" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "AMnQTk0eqT7Z" + }, + "source": [ + "#### Constants" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "WYGkzqgzrHmF" + }, + "outputs": [], + "source": [ + "num_samples = 1 # solutions per molecule\n", + "num_molecules = 3\n", + "\n", + "DEVICE = 'cuda'\n", + "sampling_type = 'ddpm_noisy' #'' # paper also uses \"generalize\" and \"ld\"\n", + "# constants for inference\n", + "w_global = 0.5 #0,.3 for qm9\n", + "global_start_sigma = 0.5\n", + "eta = 1.0\n", + "clip_local = None\n", + "clip_pos = None\n", + "\n", + "# constands for data handling\n", + "save_traj = False\n", + "save_data = False\n", + "output_dir = '/content/'" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "-xD5bJ3SqM7t" + }, + "source": [ + "#### Generate samples!\n", + "Note that the 3d representation of a molecule is referred to as the **conformation**" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "x9xuLUNg26z1", + "colab": { + "base_uri": "https://localhost:8080/" + }, + "outputId": "236d2a60-09ed-4c4d-97c1-6e3c0f2d26c4" + }, + "outputs": [ + { + "output_type": "stream", + "name": "stderr", + "text": [ + "/usr/local/lib/python3.7/dist-packages/ipykernel_launcher.py:4: 
UserWarning: To copy construct from a tensor, it is recommended to use sourceTensor.clone().detach() or sourceTensor.clone().detach().requires_grad_(True), rather than torch.tensor(sourceTensor).\n", + " after removing the cwd from sys.path.\n", + "100%|██████████| 5/5 [00:55<00:00, 11.06s/it]\n" + ] + } + ], + "source": [ + "results = []\n", + "\n", + "# define sigmas\n", + "sigmas = torch.tensor(1.0 - scheduler.alphas_cumprod).sqrt() / torch.tensor(scheduler.alphas_cumprod).sqrt()\n", + "sigmas = sigmas.to(DEVICE)\n", + "\n", + "for count, data in enumerate(tqdm(dataset)):\n", + " num_samples = max(data.pos_ref.size(0) // data.num_nodes, 1)\n", + "\n", + " data_input = data.clone()\n", + " data_input['pos_ref'] = None\n", + " batch = repeat_data(data_input, num_samples).to(DEVICE)\n", + "\n", + " # initial configuration\n", + " pos_init = torch.randn(batch.num_nodes, 3).to(DEVICE)\n", + "\n", + " # for logging animation of denoising\n", + " pos_traj = []\n", + " with torch.no_grad():\n", + "\n", + " # scale initial sample\n", + " pos = pos_init * sigmas[-1]\n", + " for t in scheduler.timesteps:\n", + " batch.pos = pos\n", + "\n", + " # generate geometry with model, then filter it\n", + " epsilon = model.forward(batch, t, sigma=sigmas[t], return_dict=False)[0]\n", + "\n", + " # Update\n", + " reconstructed_pos = scheduler.step(epsilon, t, pos)[\"prev_sample\"].to(DEVICE)\n", + "\n", + " pos = reconstructed_pos\n", + "\n", + " if torch.isnan(pos).any():\n", + " print(\"NaN detected. Please restart.\")\n", + " raise FloatingPointError()\n", + "\n", + " # recenter graph of positions for next iteration\n", + " pos = pos - scatter_mean(pos, batch.batch, dim=0)[batch.batch]\n", + "\n", + " # optional clipping\n", + " if clip_pos is not None:\n", + " pos = torch.clamp(pos, min=-clip_pos, max=clip_pos)\n", + " pos_traj.append(pos.clone().cpu())\n", + "\n", + " pos_gen = pos.cpu()\n", + " if save_traj:\n", + " pos_gen_traj = pos_traj.cpu()\n", + " data.pos_gen = torch.stack(pos_gen_traj)\n", + " else:\n", + " data.pos_gen = pos_gen\n", + " results.append(data)\n", + "\n", + "\n", + "if save_data:\n", + " save_path = os.path.join(output_dir, 'samples_all.pkl')\n", + "\n", + " with open(save_path, 'wb') as f:\n", + " pickle.dump(results, f)" + ] + }, + { + "cell_type": "markdown", + "source": [ + "## Render the results!" + ], + "metadata": { + "id": "fSApwSaZNndW" + } + }, + { + "cell_type": "markdown", + "metadata": { + "id": "d47Zxo2OKdgZ" + }, + "source": [ + "This function allows us to render 3d in colab." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "e9Cd0kCAv9b8" + }, + "outputs": [], + "source": [ + "from google.colab import output\n", + "output.enable_custom_widget_manager()" + ] + }, + { + "cell_type": "markdown", + "source": [ + "### Helper functions" + ], + "metadata": { + "id": "RjaVuR15NqzF" + } + }, + { + "cell_type": "markdown", + "metadata": { + "id": "28rBYa9NKhlz" + }, + "source": [ + "Here is a helper function for copying the generated tensors into a format used by RDKit & NGLViewer." 
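One aside on the sampling cell above before moving to the post-processing helpers: its `save_data` branch calls `pickle.dump`, but none of the cells shown here import `pickle`, so add the import if you enable saving. A minimal sketch for writing the results and reading them back, reusing `output_dir` and `results` from the cells above:

```python
import os
import pickle  # needed by the save_data branch of the sampling loop

save_path = os.path.join(output_dir, 'samples_all.pkl')

# write the generated Data objects ...
with open(save_path, 'wb') as f:
    pickle.dump(results, f)

# ... and load them back in a later session
with open(save_path, 'rb') as f:
    results = pickle.load(f)
```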
+ ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "LKdKdwxcyTQ6" + }, + "outputs": [], + "source": [ + "from copy import deepcopy\n", + "def set_rdmol_positions(rdkit_mol, pos):\n", + " \"\"\"\n", + " Args:\n", + " rdkit_mol: An `rdkit.Chem.rdchem.Mol` object.\n", + " pos: (N_atoms, 3)\n", + " \"\"\"\n", + " mol = deepcopy(rdkit_mol)\n", + " set_rdmol_positions_(mol, pos)\n", + " return mol\n", + "\n", + "def set_rdmol_positions_(mol, pos):\n", + " \"\"\"\n", + " Args:\n", + " rdkit_mol: An `rdkit.Chem.rdchem.Mol` object.\n", + " pos: (N_atoms, 3)\n", + " \"\"\"\n", + " for i in range(pos.shape[0]):\n", + " mol.GetConformer(0).SetAtomPosition(i, pos[i].tolist())\n", + " return mol\n" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "NuE10hcpKmzK" + }, + "source": [ + "Process the generated data to make it easy to view." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "KieVE1vc0_Vs", + "colab": { + "base_uri": "https://localhost:8080/" + }, + "outputId": "6faa185d-b1bc-47e8-be18-30d1e557e7c8" + }, + "outputs": [ + { + "output_type": "stream", + "name": "stdout", + "text": [ + "collect 5 generated molecules in `mols`\n" + ] + } + ], + "source": [ + "# the model can generate multiple conformations per 2d geometry\n", + "num_gen = results[0]['pos_gen'].shape[0]\n", + "\n", + "# init storage objects\n", + "mols_gen = []\n", + "mols_orig = []\n", + "for to_process in results:\n", + "\n", + " # store the reference 3d position\n", + " to_process['pos_ref'] = to_process['pos_ref'].reshape(-1, to_process['rdmol'].GetNumAtoms(), 3)\n", + "\n", + " # store the generated 3d position\n", + " to_process['pos_gen'] = to_process['pos_gen'].reshape(-1, to_process['rdmol'].GetNumAtoms(), 3)\n", + "\n", + " # copy data to new object\n", + " new_mol = set_rdmol_positions(to_process.rdmol, to_process['pos_gen'][0])\n", + "\n", + " # append results\n", + " mols_gen.append(new_mol)\n", + " mols_orig.append(to_process.rdmol)\n", + "\n", + "print(f\"collect {len(mols_gen)} generated molecules in `mols`\")" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "tin89JwMKp4v" + }, + "source": [ + "Import tools to visualize the 2d chemical diagram of the molecule." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "yqV6gllSZn38" + }, + "outputs": [], + "source": [ + "from rdkit.Chem import AllChem\n", + "from rdkit import Chem\n", + "from rdkit.Chem.Draw import rdMolDraw2D as MD2\n", + "from IPython.display import SVG, display" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "TFNKmGddVoOk" + }, + "source": [ + "Select molecule to visualize" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "KzuwLlrrVaGc" + }, + "outputs": [], + "source": [ + "idx = 0\n", + "assert idx < len(results), \"selected molecule that was not generated\"" + ] + }, + { + "cell_type": "markdown", + "source": [ + "### Viewing" + ], + "metadata": { + "id": "hkb8w0_SNtU8" + } + }, + { + "cell_type": "markdown", + "metadata": { + "id": "I3R4QBQeKttN" + }, + "source": [ + "This 2D rendering is the equivalent of the **input to the model**!" 
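Before the 2D drawing, an optional check that is not part of the original notebook: since `mols_gen[idx]` and `mols_orig[idx]` describe the same molecule, and the dataset `rdmol` carries a ground-truth conformer as noted earlier, RDKit can align the generated conformer onto the reference and report an RMSD. `AlignMol` ignores symmetry-equivalent atom orderings and rigidly transforms the probe conformer in place, so treat the number only as a rough sanity check.

```python
from rdkit.Chem import rdMolAlign

# Align the generated conformer (probe) onto the dataset conformer (reference);
# the return value is the RMSD in Angstroms after the rigid-body fit.
rmsd = rdMolAlign.AlignMol(mols_gen[idx], mols_orig[idx])
print(f"RMSD between generated and reference conformer: {rmsd:.3f} Angstroms")
```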
+ ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "gkQRWjraaKex", + "colab": { + "base_uri": "https://localhost:8080/", + "height": 321 + }, + "outputId": "9c3d1a91-a51d-475d-9e34-2be2459abc47" + }, + "outputs": [ + { + "output_type": "display_data", + "data": { + "text/plain": [ + "" + ], + "image/svg+xml": "\n\n \n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n" }, - "metalness": 0, - "multipleBond": "off", - "opacity": 1, - "openEnded": true, - "quality": "high", - "radialSegments": 20, - "radiusData": {}, - "radiusScale": 2, - "radiusSize": 0.15, - "radiusType": "size", - "roughness": 0.4, - "sele": "", - "side": "double", - "sphereDetail": 2, - "useInteriorColor": true, - "visible": true, - "wireframe": false - }, - "type": "ball+stick" + "metadata": {} } - } + ], + "source": [ + "mc = Chem.MolFromSmiles(dataset[0]['smiles'])\n", + "molSize=(450,300)\n", + "drawer = MD2.MolDraw2DSVG(molSize[0],molSize[1])\n", + "drawer.DrawMolecule(mc)\n", + "drawer.FinishDrawing()\n", + "svg = drawer.GetDrawingText()\n", + "display(SVG(svg.replace('svg:','')))" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "z4FDMYMxKw2I" }, - "_ngl_serialize": false, - "_ngl_version": "", - "_ngl_view_id": [ - "FB989FD1-5B9C-446B-8914-6B58AF85446D" + "source": [ + "Generate the 3d molecule!" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "aT1Bkb8YxJfV", + "colab": { + "base_uri": "https://localhost:8080/", + "height": 17, + "referenced_widgets": [ + "695ab5bbf30a4ab19df1f9f33469f314", + "eac6a8dcdc9d4335a2e51031793ead29" + ] + }, + "outputId": "b98870ae-049d-4386-b676-166e9526bda2" + }, + "outputs": [ + { + "output_type": "display_data", + "data": { + "text/plain": [], + "application/vnd.jupyter.widget-view+json": { + "version_major": 2, + "version_minor": 0, + "model_id": "695ab5bbf30a4ab19df1f9f33469f314" + } + }, + "metadata": { + "application/vnd.jupyter.widget-view+json": { + "colab": { + "custom_widget_manager": { + "url": "https://ssl.gstatic.com/colaboratory-static/widgets/colab-cdn-widget-manager/d2e234f7cc04bf79/manager.min.js" + } + } + } + } + } ], - "_player_dict": {}, - "_scene_position": {}, - "_scene_rotation": {}, - "_synced_model_ids": [], - "_synced_repr_model_ids": [], - "_view_count": null, - "_view_height": "", - "_view_module": "nglview-js-widgets", - "_view_module_version": "3.0.1", - "_view_name": "NGLView", - "_view_width": "", - "background": "white", - "frame": 0, - "gui_style": null, - "layout": "IPY_MODEL_c6596896148b4a8a9c57963b67c7782f", - "max_frame": 0, - "n_components": 2, - "picked": {} - } - }, - "c2321b3bff6f490ca12040a20308f555": { - "model_module": "@jupyter-widgets/controls", - "model_module_version": "1.5.0", - "model_name": "FloatProgressModel", - "state": { - "_dom_classes": [], - "_model_module": "@jupyter-widgets/controls", - "_model_module_version": "1.5.0", - "_model_name": "FloatProgressModel", - "_view_count": null, - "_view_module": "@jupyter-widgets/controls", - "_view_module_version": "1.5.0", - "_view_name": "ProgressView", - "bar_style": "success", - "description": "", - "description_tooltip": null, - "layout": "IPY_MODEL_872915dd1bb84f538c44e26badabafdd", - "max": 3271865, - "min": 0, - "orientation": "horizontal", - "style": "IPY_MODEL_d022575f1fa2446d891650897f187b4d", - "value": 3271865 - } - }, - "c30e6c2f3e2a44dbbb3d63bd519acaa4": { - "model_module": "@jupyter-widgets/base", - 
"model_module_version": "1.2.0", - "model_name": "LayoutModel", - "state": { - "_model_module": "@jupyter-widgets/base", - "_model_module_version": "1.2.0", - "_model_name": "LayoutModel", - "_view_count": null, - "_view_module": "@jupyter-widgets/base", - "_view_module_version": "1.2.0", - "_view_name": "LayoutView", - "align_content": null, - "align_items": null, - "align_self": null, - "border": null, - "bottom": null, - "display": null, - "flex": null, - "flex_flow": null, - "grid_area": null, - "grid_auto_columns": null, - "grid_auto_flow": null, - "grid_auto_rows": null, - "grid_column": null, - "grid_gap": null, - "grid_row": null, - "grid_template_areas": null, - "grid_template_columns": null, - "grid_template_rows": null, - "height": null, - "justify_content": null, - "justify_items": null, - "left": null, - "margin": null, - "max_height": null, - "max_width": null, - "min_height": null, - "min_width": null, - "object_fit": null, - "object_position": null, - "order": null, - "overflow": null, - "overflow_x": null, - "overflow_y": null, - "padding": null, - "right": null, - "top": null, - "visibility": null, - "width": null - } - }, - "c6596896148b4a8a9c57963b67c7782f": { - "model_module": "@jupyter-widgets/base", - "model_module_version": "1.2.0", - "model_name": "LayoutModel", - "state": { - "_model_module": "@jupyter-widgets/base", - "_model_module_version": "1.2.0", - "_model_name": "LayoutModel", - "_view_count": null, - "_view_module": "@jupyter-widgets/base", - "_view_module_version": "1.2.0", - "_view_name": "LayoutView", - "align_content": null, - "align_items": null, - "align_self": null, - "border": null, - "bottom": null, - "display": null, - "flex": null, - "flex_flow": null, - "grid_area": null, - "grid_auto_columns": null, - "grid_auto_flow": null, - "grid_auto_rows": null, - "grid_column": null, - "grid_gap": null, - "grid_row": null, - "grid_template_areas": null, - "grid_template_columns": null, - "grid_template_rows": null, - "height": null, - "justify_content": null, - "justify_items": null, - "left": null, - "margin": null, - "max_height": null, - "max_width": null, - "min_height": null, - "min_width": null, - "object_fit": null, - "object_position": null, - "order": null, - "overflow": null, - "overflow_x": null, - "overflow_y": null, - "padding": null, - "right": null, - "top": null, - "visibility": null, - "width": null - } - }, - "d022575f1fa2446d891650897f187b4d": { - "model_module": "@jupyter-widgets/controls", - "model_module_version": "1.5.0", - "model_name": "ProgressStyleModel", - "state": { - "_model_module": "@jupyter-widgets/controls", - "_model_module_version": "1.5.0", - "_model_name": "ProgressStyleModel", - "_view_count": null, - "_view_module": "@jupyter-widgets/base", - "_view_module_version": "1.2.0", - "_view_name": "StyleView", - "bar_color": null, - "description_width": "" - } - }, - "d90f304e9560472eacfbdd11e46765eb": { - "model_module": "@jupyter-widgets/controls", - "model_module_version": "1.5.0", - "model_name": "HBoxModel", - "state": { - "_dom_classes": [], - "_model_module": "@jupyter-widgets/controls", - "_model_module_version": "1.5.0", - "_model_name": "HBoxModel", - "_view_count": null, - "_view_module": "@jupyter-widgets/controls", - "_view_module_version": "1.5.0", - "_view_name": "HBoxView", - "box_style": "", - "children": [ - "IPY_MODEL_1c6246f15b654f4daa11c9bcf997b78c", - "IPY_MODEL_c2321b3bff6f490ca12040a20308f555", - "IPY_MODEL_b7feb522161f4cf4b7cc7c1a078ff12d" + "source": [ + "from nglview import show_rdkit as show" + 
] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "pxtq8I-I18C-", + "colab": { + "base_uri": "https://localhost:8080/", + "height": 337, + "referenced_widgets": [ + "be446195da2b4ff2aec21ec5ff963a54", + "c6596896148b4a8a9c57963b67c7782f", + "2489b5e5648541fbbdceadb05632a050", + "01e0ba4e5da04914b4652b8d58565d7b", + "c30e6c2f3e2a44dbbb3d63bd519acaa4", + "f31c6e40e9b2466a9064a2669933ecd5", + "19308ccac642498ab8b58462e3f1b0bb", + "4a081cdc2ec3421ca79dd933b7e2b0c4", + "e5c0d75eb5e1447abd560c8f2c6017e1", + "5146907ef6764654ad7d598baebc8b58", + "144ec959b7604a2cabb5ca46ae5e5379", + "abce2a80e6304df3899109c6d6cac199", + "65195cb7a4134f4887e9dd19f3676462" + ] + }, + "outputId": "72ed63ac-d2ec-4f5c-a0b1-4e7c1840a4e7" + }, + "outputs": [ + { + "output_type": "display_data", + "data": { + "text/plain": [ + "NGLWidget()" + ], + "application/vnd.jupyter.widget-view+json": { + "version_major": 2, + "version_minor": 0, + "model_id": "be446195da2b4ff2aec21ec5ff963a54" + } + }, + "metadata": { + "application/vnd.jupyter.widget-view+json": { + "colab": { + "custom_widget_manager": { + "url": "https://ssl.gstatic.com/colaboratory-static/widgets/colab-cdn-widget-manager/d2e234f7cc04bf79/manager.min.js" + } + } + } + } + } ], - "layout": "IPY_MODEL_e2d368556e494ae7ae4e2e992af2cd4f" - } - }, - "e2d368556e494ae7ae4e2e992af2cd4f": { - "model_module": "@jupyter-widgets/base", - "model_module_version": "1.2.0", - "model_name": "LayoutModel", - "state": { - "_model_module": "@jupyter-widgets/base", - "_model_module_version": "1.2.0", - "_model_name": "LayoutModel", - "_view_count": null, - "_view_module": "@jupyter-widgets/base", - "_view_module_version": "1.2.0", - "_view_name": "LayoutView", - "align_content": null, - "align_items": null, - "align_self": null, - "border": null, - "bottom": null, - "display": null, - "flex": null, - "flex_flow": null, - "grid_area": null, - "grid_auto_columns": null, - "grid_auto_flow": null, - "grid_auto_rows": null, - "grid_column": null, - "grid_gap": null, - "grid_row": null, - "grid_template_areas": null, - "grid_template_columns": null, - "grid_template_rows": null, - "height": null, - "justify_content": null, - "justify_items": null, - "left": null, - "margin": null, - "max_height": null, - "max_width": null, - "min_height": null, - "min_width": null, - "object_fit": null, - "object_position": null, - "order": null, - "overflow": null, - "overflow_x": null, - "overflow_y": null, - "padding": null, - "right": null, - "top": null, - "visibility": null, - "width": null - } - }, - "e4901541199b45c6a18824627692fc39": { - "model_module": "@jupyter-widgets/base", - "model_module_version": "1.2.0", - "model_name": "LayoutModel", - "state": { - "_model_module": "@jupyter-widgets/base", - "_model_module_version": "1.2.0", - "_model_name": "LayoutModel", - "_view_count": null, - "_view_module": "@jupyter-widgets/base", - "_view_module_version": "1.2.0", - "_view_name": "LayoutView", - "align_content": null, - "align_items": null, - "align_self": null, - "border": null, - "bottom": null, - "display": null, - "flex": null, - "flex_flow": null, - "grid_area": null, - "grid_auto_columns": null, - "grid_auto_flow": null, - "grid_auto_rows": null, - "grid_column": null, - "grid_gap": null, - "grid_row": null, - "grid_template_areas": null, - "grid_template_columns": null, - "grid_template_rows": null, - "height": null, - "justify_content": null, - "justify_items": null, - "left": null, - "margin": null, - "max_height": null, - "max_width": null, - 
"min_height": null, - "min_width": null, - "object_fit": null, - "object_position": null, - "order": null, - "overflow": null, - "overflow_x": null, - "overflow_y": null, - "padding": null, - "right": null, - "top": null, - "visibility": null, - "width": null - } - }, - "e5c0d75eb5e1447abd560c8f2c6017e1": { - "model_module": "@jupyter-widgets/controls", - "model_module_version": "1.5.0", - "model_name": "PlayModel", - "state": { - "_dom_classes": [], - "_model_module": "@jupyter-widgets/controls", - "_model_module_version": "1.5.0", - "_model_name": "PlayModel", - "_playing": false, - "_repeat": false, - "_view_count": null, - "_view_module": "@jupyter-widgets/controls", - "_view_module_version": "1.5.0", - "_view_name": "PlayView", - "description": "", - "description_tooltip": null, - "disabled": false, - "interval": 100, - "layout": "IPY_MODEL_c30e6c2f3e2a44dbbb3d63bd519acaa4", - "max": 0, - "min": 0, - "show_repeat": true, - "step": 1, - "style": "IPY_MODEL_f31c6e40e9b2466a9064a2669933ecd5", - "value": 0 - } - }, - "eac6a8dcdc9d4335a2e51031793ead29": { - "model_module": "@jupyter-widgets/base", - "model_module_version": "1.2.0", - "model_name": "LayoutModel", - "state": { - "_model_module": "@jupyter-widgets/base", - "_model_module_version": "1.2.0", - "_model_name": "LayoutModel", - "_view_count": null, - "_view_module": "@jupyter-widgets/base", - "_view_module_version": "1.2.0", - "_view_name": "LayoutView", - "align_content": null, - "align_items": null, - "align_self": null, - "border": null, - "bottom": null, - "display": null, - "flex": null, - "flex_flow": null, - "grid_area": null, - "grid_auto_columns": null, - "grid_auto_flow": null, - "grid_auto_rows": null, - "grid_column": null, - "grid_gap": null, - "grid_row": null, - "grid_template_areas": null, - "grid_template_columns": null, - "grid_template_rows": null, - "height": null, - "justify_content": null, - "justify_items": null, - "left": null, - "margin": null, - "max_height": null, - "max_width": null, - "min_height": null, - "min_width": null, - "object_fit": null, - "object_position": null, - "order": null, - "overflow": null, - "overflow_x": null, - "overflow_y": null, - "padding": null, - "right": null, - "top": null, - "visibility": null, - "width": null - } - }, - "f31c6e40e9b2466a9064a2669933ecd5": { - "model_module": "@jupyter-widgets/controls", - "model_module_version": "1.5.0", - "model_name": "DescriptionStyleModel", - "state": { - "_model_module": "@jupyter-widgets/controls", - "_model_module_version": "1.5.0", - "_model_name": "DescriptionStyleModel", - "_view_count": null, - "_view_module": "@jupyter-widgets/base", - "_view_module_version": "1.2.0", - "_view_name": "StyleView", - "description_width": "" - } - }, - "f915cf874246446595206221e900b2fe": { - "model_module": "@jupyter-widgets/controls", - "model_module_version": "1.5.0", - "model_name": "ProgressStyleModel", - "state": { - "_model_module": "@jupyter-widgets/controls", - "_model_module_version": "1.5.0", - "_model_name": "ProgressStyleModel", - "_view_count": null, - "_view_module": "@jupyter-widgets/base", - "_view_module_version": "1.2.0", - "_view_name": "StyleView", - "bar_color": null, - "description_width": "" - } - }, - "fdc393f3468c432aa0ada05e238a5436": { - "model_module": "@jupyter-widgets/base", - "model_module_version": "1.2.0", - "model_name": "LayoutModel", - "state": { - "_model_module": "@jupyter-widgets/base", - "_model_module_version": "1.2.0", - "_model_name": "LayoutModel", - "_view_count": null, - "_view_module": 
"@jupyter-widgets/base", - "_view_module_version": "1.2.0", - "_view_name": "LayoutView", - "align_content": null, - "align_items": null, - "align_self": null, - "border": null, - "bottom": null, - "display": null, - "flex": null, - "flex_flow": null, - "grid_area": null, - "grid_auto_columns": null, - "grid_auto_flow": null, - "grid_auto_rows": null, - "grid_column": null, - "grid_gap": null, - "grid_row": null, - "grid_template_areas": null, - "grid_template_columns": null, - "grid_template_rows": null, - "height": null, - "justify_content": null, - "justify_items": null, - "left": null, - "margin": null, - "max_height": null, - "max_width": null, - "min_height": null, - "min_width": null, - "object_fit": null, - "object_position": null, - "order": null, - "overflow": null, - "overflow_x": null, - "overflow_y": null, - "padding": null, - "right": null, - "top": null, - "visibility": null, - "width": null - } + "source": [ + "# new molecule\n", + "show(mols_gen[idx])" + ] + }, + { + "cell_type": "code", + "source": [], + "metadata": { + "id": "KJr4h2mwXeTo" + }, + "execution_count": null, + "outputs": [] + } + ], + "metadata": { + "accelerator": "GPU", + "colab": { + "provenance": [] + }, + "gpuClass": "standard", + "kernelspec": { + "display_name": "Python 3", + "name": "python3" + }, + "language_info": { + "name": "python" + }, + "widgets": { + "application/vnd.jupyter.widget-state+json": { + "d90f304e9560472eacfbdd11e46765eb": { + "model_module": "@jupyter-widgets/controls", + "model_name": "HBoxModel", + "model_module_version": "1.5.0", + "state": { + "_dom_classes": [], + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "_model_name": "HBoxModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/controls", + "_view_module_version": "1.5.0", + "_view_name": "HBoxView", + "box_style": "", + "children": [ + "IPY_MODEL_1c6246f15b654f4daa11c9bcf997b78c", + "IPY_MODEL_c2321b3bff6f490ca12040a20308f555", + "IPY_MODEL_b7feb522161f4cf4b7cc7c1a078ff12d" + ], + "layout": "IPY_MODEL_e2d368556e494ae7ae4e2e992af2cd4f" + } + }, + "1c6246f15b654f4daa11c9bcf997b78c": { + "model_module": "@jupyter-widgets/controls", + "model_name": "HTMLModel", + "model_module_version": "1.5.0", + "state": { + "_dom_classes": [], + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "_model_name": "HTMLModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/controls", + "_view_module_version": "1.5.0", + "_view_name": "HTMLView", + "description": "", + "description_tooltip": null, + "layout": "IPY_MODEL_bbef741e76ec41b7ab7187b487a383df", + "placeholder": "​", + "style": "IPY_MODEL_561f742d418d4721b0670cc8dd62e22c", + "value": "Downloading: 100%" + } + }, + "c2321b3bff6f490ca12040a20308f555": { + "model_module": "@jupyter-widgets/controls", + "model_name": "FloatProgressModel", + "model_module_version": "1.5.0", + "state": { + "_dom_classes": [], + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "_model_name": "FloatProgressModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/controls", + "_view_module_version": "1.5.0", + "_view_name": "ProgressView", + "bar_style": "success", + "description": "", + "description_tooltip": null, + "layout": "IPY_MODEL_872915dd1bb84f538c44e26badabafdd", + "max": 3271865, + "min": 0, + "orientation": "horizontal", + "style": "IPY_MODEL_d022575f1fa2446d891650897f187b4d", + "value": 3271865 + } + }, + "b7feb522161f4cf4b7cc7c1a078ff12d": { + 
"model_module": "@jupyter-widgets/controls", + "model_name": "HTMLModel", + "model_module_version": "1.5.0", + "state": { + "_dom_classes": [], + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "_model_name": "HTMLModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/controls", + "_view_module_version": "1.5.0", + "_view_name": "HTMLView", + "description": "", + "description_tooltip": null, + "layout": "IPY_MODEL_fdc393f3468c432aa0ada05e238a5436", + "placeholder": "​", + "style": "IPY_MODEL_2c9362906e4b40189f16d14aa9a348da", + "value": " 3.27M/3.27M [00:01<00:00, 3.25MB/s]" + } + }, + "e2d368556e494ae7ae4e2e992af2cd4f": { + "model_module": "@jupyter-widgets/base", + "model_name": "LayoutModel", + "model_module_version": "1.2.0", + "state": { + "_model_module": "@jupyter-widgets/base", + "_model_module_version": "1.2.0", + "_model_name": "LayoutModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "1.2.0", + "_view_name": "LayoutView", + "align_content": null, + "align_items": null, + "align_self": null, + "border": null, + "bottom": null, + "display": null, + "flex": null, + "flex_flow": null, + "grid_area": null, + "grid_auto_columns": null, + "grid_auto_flow": null, + "grid_auto_rows": null, + "grid_column": null, + "grid_gap": null, + "grid_row": null, + "grid_template_areas": null, + "grid_template_columns": null, + "grid_template_rows": null, + "height": null, + "justify_content": null, + "justify_items": null, + "left": null, + "margin": null, + "max_height": null, + "max_width": null, + "min_height": null, + "min_width": null, + "object_fit": null, + "object_position": null, + "order": null, + "overflow": null, + "overflow_x": null, + "overflow_y": null, + "padding": null, + "right": null, + "top": null, + "visibility": null, + "width": null + } + }, + "bbef741e76ec41b7ab7187b487a383df": { + "model_module": "@jupyter-widgets/base", + "model_name": "LayoutModel", + "model_module_version": "1.2.0", + "state": { + "_model_module": "@jupyter-widgets/base", + "_model_module_version": "1.2.0", + "_model_name": "LayoutModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "1.2.0", + "_view_name": "LayoutView", + "align_content": null, + "align_items": null, + "align_self": null, + "border": null, + "bottom": null, + "display": null, + "flex": null, + "flex_flow": null, + "grid_area": null, + "grid_auto_columns": null, + "grid_auto_flow": null, + "grid_auto_rows": null, + "grid_column": null, + "grid_gap": null, + "grid_row": null, + "grid_template_areas": null, + "grid_template_columns": null, + "grid_template_rows": null, + "height": null, + "justify_content": null, + "justify_items": null, + "left": null, + "margin": null, + "max_height": null, + "max_width": null, + "min_height": null, + "min_width": null, + "object_fit": null, + "object_position": null, + "order": null, + "overflow": null, + "overflow_x": null, + "overflow_y": null, + "padding": null, + "right": null, + "top": null, + "visibility": null, + "width": null + } + }, + "561f742d418d4721b0670cc8dd62e22c": { + "model_module": "@jupyter-widgets/controls", + "model_name": "DescriptionStyleModel", + "model_module_version": "1.5.0", + "state": { + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "_model_name": "DescriptionStyleModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "1.2.0", + "_view_name": 
"StyleView", + "description_width": "" + } + }, + "872915dd1bb84f538c44e26badabafdd": { + "model_module": "@jupyter-widgets/base", + "model_name": "LayoutModel", + "model_module_version": "1.2.0", + "state": { + "_model_module": "@jupyter-widgets/base", + "_model_module_version": "1.2.0", + "_model_name": "LayoutModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "1.2.0", + "_view_name": "LayoutView", + "align_content": null, + "align_items": null, + "align_self": null, + "border": null, + "bottom": null, + "display": null, + "flex": null, + "flex_flow": null, + "grid_area": null, + "grid_auto_columns": null, + "grid_auto_flow": null, + "grid_auto_rows": null, + "grid_column": null, + "grid_gap": null, + "grid_row": null, + "grid_template_areas": null, + "grid_template_columns": null, + "grid_template_rows": null, + "height": null, + "justify_content": null, + "justify_items": null, + "left": null, + "margin": null, + "max_height": null, + "max_width": null, + "min_height": null, + "min_width": null, + "object_fit": null, + "object_position": null, + "order": null, + "overflow": null, + "overflow_x": null, + "overflow_y": null, + "padding": null, + "right": null, + "top": null, + "visibility": null, + "width": null + } + }, + "d022575f1fa2446d891650897f187b4d": { + "model_module": "@jupyter-widgets/controls", + "model_name": "ProgressStyleModel", + "model_module_version": "1.5.0", + "state": { + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "_model_name": "ProgressStyleModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "1.2.0", + "_view_name": "StyleView", + "bar_color": null, + "description_width": "" + } + }, + "fdc393f3468c432aa0ada05e238a5436": { + "model_module": "@jupyter-widgets/base", + "model_name": "LayoutModel", + "model_module_version": "1.2.0", + "state": { + "_model_module": "@jupyter-widgets/base", + "_model_module_version": "1.2.0", + "_model_name": "LayoutModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "1.2.0", + "_view_name": "LayoutView", + "align_content": null, + "align_items": null, + "align_self": null, + "border": null, + "bottom": null, + "display": null, + "flex": null, + "flex_flow": null, + "grid_area": null, + "grid_auto_columns": null, + "grid_auto_flow": null, + "grid_auto_rows": null, + "grid_column": null, + "grid_gap": null, + "grid_row": null, + "grid_template_areas": null, + "grid_template_columns": null, + "grid_template_rows": null, + "height": null, + "justify_content": null, + "justify_items": null, + "left": null, + "margin": null, + "max_height": null, + "max_width": null, + "min_height": null, + "min_width": null, + "object_fit": null, + "object_position": null, + "order": null, + "overflow": null, + "overflow_x": null, + "overflow_y": null, + "padding": null, + "right": null, + "top": null, + "visibility": null, + "width": null + } + }, + "2c9362906e4b40189f16d14aa9a348da": { + "model_module": "@jupyter-widgets/controls", + "model_name": "DescriptionStyleModel", + "model_module_version": "1.5.0", + "state": { + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "_model_name": "DescriptionStyleModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "1.2.0", + "_view_name": "StyleView", + "description_width": "" + } + }, + "6010fc8daa7a44d5aec4b830ec2ebaa1": { + "model_module": 
"@jupyter-widgets/controls", + "model_name": "HBoxModel", + "model_module_version": "1.5.0", + "state": { + "_dom_classes": [], + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "_model_name": "HBoxModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/controls", + "_view_module_version": "1.5.0", + "_view_name": "HBoxView", + "box_style": "", + "children": [ + "IPY_MODEL_7e0bb1b8d65249d3974200686b193be2", + "IPY_MODEL_ba98aa6d6a884e4ab8bbb5dfb5e4cf7a", + "IPY_MODEL_6526646be5ed415c84d1245b040e629b" + ], + "layout": "IPY_MODEL_24d31fc3576e43dd9f8301d2ef3a37ab" + } + }, + "7e0bb1b8d65249d3974200686b193be2": { + "model_module": "@jupyter-widgets/controls", + "model_name": "HTMLModel", + "model_module_version": "1.5.0", + "state": { + "_dom_classes": [], + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "_model_name": "HTMLModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/controls", + "_view_module_version": "1.5.0", + "_view_name": "HTMLView", + "description": "", + "description_tooltip": null, + "layout": "IPY_MODEL_2918bfaadc8d4b1a9832522c40dfefb8", + "placeholder": "​", + "style": "IPY_MODEL_a4bfdca35cc54dae8812720f1b276a08", + "value": "Downloading: 100%" + } + }, + "ba98aa6d6a884e4ab8bbb5dfb5e4cf7a": { + "model_module": "@jupyter-widgets/controls", + "model_name": "FloatProgressModel", + "model_module_version": "1.5.0", + "state": { + "_dom_classes": [], + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "_model_name": "FloatProgressModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/controls", + "_view_module_version": "1.5.0", + "_view_name": "ProgressView", + "bar_style": "success", + "description": "", + "description_tooltip": null, + "layout": "IPY_MODEL_e4901541199b45c6a18824627692fc39", + "max": 401, + "min": 0, + "orientation": "horizontal", + "style": "IPY_MODEL_f915cf874246446595206221e900b2fe", + "value": 401 + } + }, + "6526646be5ed415c84d1245b040e629b": { + "model_module": "@jupyter-widgets/controls", + "model_name": "HTMLModel", + "model_module_version": "1.5.0", + "state": { + "_dom_classes": [], + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "_model_name": "HTMLModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/controls", + "_view_module_version": "1.5.0", + "_view_name": "HTMLView", + "description": "", + "description_tooltip": null, + "layout": "IPY_MODEL_a9e388f22a9742aaaf538e22575c9433", + "placeholder": "​", + "style": "IPY_MODEL_42f6c3db29d7484ba6b4f73590abd2f4", + "value": " 401/401 [00:00<00:00, 13.5kB/s]" + } + }, + "24d31fc3576e43dd9f8301d2ef3a37ab": { + "model_module": "@jupyter-widgets/base", + "model_name": "LayoutModel", + "model_module_version": "1.2.0", + "state": { + "_model_module": "@jupyter-widgets/base", + "_model_module_version": "1.2.0", + "_model_name": "LayoutModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "1.2.0", + "_view_name": "LayoutView", + "align_content": null, + "align_items": null, + "align_self": null, + "border": null, + "bottom": null, + "display": null, + "flex": null, + "flex_flow": null, + "grid_area": null, + "grid_auto_columns": null, + "grid_auto_flow": null, + "grid_auto_rows": null, + "grid_column": null, + "grid_gap": null, + "grid_row": null, + "grid_template_areas": null, + "grid_template_columns": null, + "grid_template_rows": null, + "height": null, + 
"justify_content": null, + "justify_items": null, + "left": null, + "margin": null, + "max_height": null, + "max_width": null, + "min_height": null, + "min_width": null, + "object_fit": null, + "object_position": null, + "order": null, + "overflow": null, + "overflow_x": null, + "overflow_y": null, + "padding": null, + "right": null, + "top": null, + "visibility": null, + "width": null + } + }, + "2918bfaadc8d4b1a9832522c40dfefb8": { + "model_module": "@jupyter-widgets/base", + "model_name": "LayoutModel", + "model_module_version": "1.2.0", + "state": { + "_model_module": "@jupyter-widgets/base", + "_model_module_version": "1.2.0", + "_model_name": "LayoutModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "1.2.0", + "_view_name": "LayoutView", + "align_content": null, + "align_items": null, + "align_self": null, + "border": null, + "bottom": null, + "display": null, + "flex": null, + "flex_flow": null, + "grid_area": null, + "grid_auto_columns": null, + "grid_auto_flow": null, + "grid_auto_rows": null, + "grid_column": null, + "grid_gap": null, + "grid_row": null, + "grid_template_areas": null, + "grid_template_columns": null, + "grid_template_rows": null, + "height": null, + "justify_content": null, + "justify_items": null, + "left": null, + "margin": null, + "max_height": null, + "max_width": null, + "min_height": null, + "min_width": null, + "object_fit": null, + "object_position": null, + "order": null, + "overflow": null, + "overflow_x": null, + "overflow_y": null, + "padding": null, + "right": null, + "top": null, + "visibility": null, + "width": null + } + }, + "a4bfdca35cc54dae8812720f1b276a08": { + "model_module": "@jupyter-widgets/controls", + "model_name": "DescriptionStyleModel", + "model_module_version": "1.5.0", + "state": { + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "_model_name": "DescriptionStyleModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "1.2.0", + "_view_name": "StyleView", + "description_width": "" + } + }, + "e4901541199b45c6a18824627692fc39": { + "model_module": "@jupyter-widgets/base", + "model_name": "LayoutModel", + "model_module_version": "1.2.0", + "state": { + "_model_module": "@jupyter-widgets/base", + "_model_module_version": "1.2.0", + "_model_name": "LayoutModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "1.2.0", + "_view_name": "LayoutView", + "align_content": null, + "align_items": null, + "align_self": null, + "border": null, + "bottom": null, + "display": null, + "flex": null, + "flex_flow": null, + "grid_area": null, + "grid_auto_columns": null, + "grid_auto_flow": null, + "grid_auto_rows": null, + "grid_column": null, + "grid_gap": null, + "grid_row": null, + "grid_template_areas": null, + "grid_template_columns": null, + "grid_template_rows": null, + "height": null, + "justify_content": null, + "justify_items": null, + "left": null, + "margin": null, + "max_height": null, + "max_width": null, + "min_height": null, + "min_width": null, + "object_fit": null, + "object_position": null, + "order": null, + "overflow": null, + "overflow_x": null, + "overflow_y": null, + "padding": null, + "right": null, + "top": null, + "visibility": null, + "width": null + } + }, + "f915cf874246446595206221e900b2fe": { + "model_module": "@jupyter-widgets/controls", + "model_name": "ProgressStyleModel", + "model_module_version": "1.5.0", + "state": { + "_model_module": 
"@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "_model_name": "ProgressStyleModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "1.2.0", + "_view_name": "StyleView", + "bar_color": null, + "description_width": "" + } + }, + "a9e388f22a9742aaaf538e22575c9433": { + "model_module": "@jupyter-widgets/base", + "model_name": "LayoutModel", + "model_module_version": "1.2.0", + "state": { + "_model_module": "@jupyter-widgets/base", + "_model_module_version": "1.2.0", + "_model_name": "LayoutModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "1.2.0", + "_view_name": "LayoutView", + "align_content": null, + "align_items": null, + "align_self": null, + "border": null, + "bottom": null, + "display": null, + "flex": null, + "flex_flow": null, + "grid_area": null, + "grid_auto_columns": null, + "grid_auto_flow": null, + "grid_auto_rows": null, + "grid_column": null, + "grid_gap": null, + "grid_row": null, + "grid_template_areas": null, + "grid_template_columns": null, + "grid_template_rows": null, + "height": null, + "justify_content": null, + "justify_items": null, + "left": null, + "margin": null, + "max_height": null, + "max_width": null, + "min_height": null, + "min_width": null, + "object_fit": null, + "object_position": null, + "order": null, + "overflow": null, + "overflow_x": null, + "overflow_y": null, + "padding": null, + "right": null, + "top": null, + "visibility": null, + "width": null + } + }, + "42f6c3db29d7484ba6b4f73590abd2f4": { + "model_module": "@jupyter-widgets/controls", + "model_name": "DescriptionStyleModel", + "model_module_version": "1.5.0", + "state": { + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "_model_name": "DescriptionStyleModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "1.2.0", + "_view_name": "StyleView", + "description_width": "" + } + }, + "695ab5bbf30a4ab19df1f9f33469f314": { + "model_module": "nglview-js-widgets", + "model_name": "ColormakerRegistryModel", + "model_module_version": "3.0.1", + "state": { + "_dom_classes": [], + "_model_module": "nglview-js-widgets", + "_model_module_version": "3.0.1", + "_model_name": "ColormakerRegistryModel", + "_msg_ar": [], + "_msg_q": [], + "_ready": false, + "_view_count": null, + "_view_module": "nglview-js-widgets", + "_view_module_version": "3.0.1", + "_view_name": "ColormakerRegistryView", + "layout": "IPY_MODEL_eac6a8dcdc9d4335a2e51031793ead29" + } + }, + "eac6a8dcdc9d4335a2e51031793ead29": { + "model_module": "@jupyter-widgets/base", + "model_name": "LayoutModel", + "model_module_version": "1.2.0", + "state": { + "_model_module": "@jupyter-widgets/base", + "_model_module_version": "1.2.0", + "_model_name": "LayoutModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "1.2.0", + "_view_name": "LayoutView", + "align_content": null, + "align_items": null, + "align_self": null, + "border": null, + "bottom": null, + "display": null, + "flex": null, + "flex_flow": null, + "grid_area": null, + "grid_auto_columns": null, + "grid_auto_flow": null, + "grid_auto_rows": null, + "grid_column": null, + "grid_gap": null, + "grid_row": null, + "grid_template_areas": null, + "grid_template_columns": null, + "grid_template_rows": null, + "height": null, + "justify_content": null, + "justify_items": null, + "left": null, + "margin": null, + "max_height": null, + "max_width": null, + 
"min_height": null, + "min_width": null, + "object_fit": null, + "object_position": null, + "order": null, + "overflow": null, + "overflow_x": null, + "overflow_y": null, + "padding": null, + "right": null, + "top": null, + "visibility": null, + "width": null + } + }, + "be446195da2b4ff2aec21ec5ff963a54": { + "model_module": "nglview-js-widgets", + "model_name": "NGLModel", + "model_module_version": "3.0.1", + "state": { + "_camera_orientation": [ + -15.519693580202304, + -14.065056548036177, + -23.53197484807691, + 0, + -23.357853515109753, + 20.94055073042662, + 2.888695042134944, + 0, + 14.352363398292777, + 18.870825741878015, + -20.744689572909344, + 0, + 0.2724999189376831, + 0.6940000057220459, + -0.3734999895095825, + 1 + ], + "_camera_str": "orthographic", + "_dom_classes": [], + "_gui_theme": null, + "_ibtn_fullscreen": "IPY_MODEL_2489b5e5648541fbbdceadb05632a050", + "_igui": null, + "_iplayer": "IPY_MODEL_01e0ba4e5da04914b4652b8d58565d7b", + "_model_module": "nglview-js-widgets", + "_model_module_version": "3.0.1", + "_model_name": "NGLModel", + "_ngl_color_dict": {}, + "_ngl_coordinate_resource": {}, + "_ngl_full_stage_parameters": { + "impostor": true, + "quality": "medium", + "workerDefault": true, + "sampleLevel": 0, + "backgroundColor": "white", + "rotateSpeed": 2, + "zoomSpeed": 1.2, + "panSpeed": 1, + "clipNear": 0, + "clipFar": 100, + "clipDist": 10, + "fogNear": 50, + "fogFar": 100, + "cameraFov": 40, + "cameraEyeSep": 0.3, + "cameraType": "perspective", + "lightColor": 14540253, + "lightIntensity": 1, + "ambientColor": 14540253, + "ambientIntensity": 0.2, + "hoverTimeout": 0, + "tooltip": true, + "mousePreset": "default" + }, + "_ngl_msg_archive": [ + { + "target": "Stage", + "type": "call_method", + "methodName": "loadFile", + "reconstruc_color_scheme": false, + "args": [ + { + "type": "blob", + "data": "HETATM 1 C1 UNL 1 -0.025 3.128 2.316 1.00 0.00 C \nHETATM 2 H1 UNL 1 0.183 3.657 2.823 1.00 0.00 H \nHETATM 3 C2 UNL 1 0.590 3.559 0.963 1.00 0.00 C \nHETATM 4 C3 UNL 1 0.056 4.479 0.406 1.00 0.00 C \nHETATM 5 C4 UNL 1 -0.219 4.802 -1.065 1.00 0.00 C \nHETATM 6 H2 UNL 1 0.686 4.431 -1.575 1.00 0.00 H \nHETATM 7 H3 UNL 1 -0.524 5.217 -1.274 1.00 0.00 H \nHETATM 8 C5 UNL 1 -1.284 3.766 -1.342 1.00 0.00 C \nHETATM 9 N1 UNL 1 -1.073 2.494 -0.580 1.00 0.00 N \nHETATM 10 C6 UNL 1 -1.909 1.494 -0.964 1.00 0.00 C \nHETATM 11 O1 UNL 1 -2.487 1.531 -2.092 1.00 0.00 O \nHETATM 12 C7 UNL 1 -2.232 0.242 -0.130 1.00 0.00 C \nHETATM 13 C8 UNL 1 -2.161 -1.057 -1.037 1.00 0.00 C \nHETATM 14 C9 UNL 1 -0.744 -1.111 -1.610 1.00 0.00 C \nHETATM 15 N2 UNL 1 0.290 -0.917 -0.628 1.00 0.00 N \nHETATM 16 S1 UNL 1 1.717 -1.597 -0.914 1.00 0.00 S \nHETATM 17 O2 UNL 1 1.960 -1.671 -2.338 1.00 0.00 O \nHETATM 18 O3 UNL 1 2.713 -0.968 -0.082 1.00 0.00 O \nHETATM 19 C10 UNL 1 1.425 -3.170 -0.345 1.00 0.00 C \nHETATM 20 C11 UNL 1 1.225 -4.400 -1.271 1.00 0.00 C \nHETATM 21 C12 UNL 1 1.314 -5.913 -0.895 1.00 0.00 C \nHETATM 22 C13 UNL 1 1.823 -6.229 0.386 1.00 0.00 C \nHETATM 23 C14 UNL 1 2.031 -5.110 1.365 1.00 0.00 C \nHETATM 24 N3 UNL 1 1.850 -5.267 2.712 1.00 0.00 N \nHETATM 25 O4 UNL 1 1.382 -4.029 3.126 1.00 0.00 O \nHETATM 26 N4 UNL 1 1.300 -3.023 2.154 1.00 0.00 N \nHETATM 27 C15 UNL 1 1.731 -3.672 1.032 1.00 0.00 C \nHETATM 28 H4 UNL 1 2.380 -6.874 0.436 1.00 0.00 H \nHETATM 29 H5 UNL 1 0.704 -6.526 -1.420 1.00 0.00 H \nHETATM 30 H6 UNL 1 1.144 -4.035 -2.291 1.00 0.00 H \nHETATM 31 C16 UNL 1 0.044 -0.371 0.685 1.00 0.00 C \nHETATM 32 C17 UNL 1 -1.352 -0.045 1.077 1.00 0.00 C \nHETATM 33 H7 
UNL 1 -1.395 0.770 1.768 1.00 0.00 H \nHETATM 34 H8 UNL 1 -1.792 -0.941 1.582 1.00 0.00 H \nHETATM 35 H9 UNL 1 0.583 -1.035 1.393 1.00 0.00 H \nHETATM 36 H10 UNL 1 0.664 0.613 0.663 1.00 0.00 H \nHETATM 37 H11 UNL 1 -0.631 -0.267 -2.335 1.00 0.00 H \nHETATM 38 H12 UNL 1 -0.571 -2.046 -2.098 1.00 0.00 H \nHETATM 39 H13 UNL 1 -2.872 -0.992 -1.826 1.00 0.00 H \nHETATM 40 H14 UNL 1 -2.370 -1.924 -0.444 1.00 0.00 H \nHETATM 41 H15 UNL 1 -3.258 0.364 0.197 1.00 0.00 H \nHETATM 42 C18 UNL 1 0.276 2.337 -0.078 1.00 0.00 C \nHETATM 43 H16 UNL 1 0.514 1.371 0.252 1.00 0.00 H \nHETATM 44 H17 UNL 1 0.988 2.413 -0.949 1.00 0.00 H \nHETATM 45 H18 UNL 1 -1.349 3.451 -2.379 1.00 0.00 H \nHETATM 46 H19 UNL 1 -2.224 4.055 -0.958 1.00 0.00 H \nHETATM 47 H20 UNL 1 0.793 5.486 0.669 1.00 0.00 H \nHETATM 48 H21 UNL 1 -0.849 4.974 0.937 1.00 0.00 H \nHETATM 49 H22 UNL 1 1.667 3.431 1.070 1.00 0.00 H \nHETATM 50 H23 UNL 1 0.379 2.143 2.689 1.00 0.00 H \nHETATM 51 H24 UNL 1 -1.094 2.983 2.223 1.00 0.00 H \nCONECT 1 2 3 50 51\nCONECT 3 4 42 49\nCONECT 4 5 47 48\nCONECT 5 6 7 8\nCONECT 8 9 45 46\nCONECT 9 10 42\nCONECT 10 11 11 12\nCONECT 12 13 32 41\nCONECT 13 14 39 40\nCONECT 14 15 37 38\nCONECT 15 16 31\nCONECT 16 17 17 18 18\nCONECT 16 19\nCONECT 19 20 20 27\nCONECT 20 21 30\nCONECT 21 22 22 29\nCONECT 22 23 28\nCONECT 23 24 24 27\nCONECT 24 25\nCONECT 25 26\nCONECT 26 27 27\nCONECT 31 32 35 36\nCONECT 32 33 34\nCONECT 42 43 44\nEND\n", + "binary": false + } + ], + "kwargs": { + "defaultRepresentation": true, + "ext": "pdb" + } + } + ], + "_ngl_original_stage_parameters": { + "impostor": true, + "quality": "medium", + "workerDefault": true, + "sampleLevel": 0, + "backgroundColor": "white", + "rotateSpeed": 2, + "zoomSpeed": 1.2, + "panSpeed": 1, + "clipNear": 0, + "clipFar": 100, + "clipDist": 10, + "fogNear": 50, + "fogFar": 100, + "cameraFov": 40, + "cameraEyeSep": 0.3, + "cameraType": "perspective", + "lightColor": 14540253, + "lightIntensity": 1, + "ambientColor": 14540253, + "ambientIntensity": 0.2, + "hoverTimeout": 0, + "tooltip": true, + "mousePreset": "default" + }, + "_ngl_repr_dict": { + "0": { + "0": { + "type": "ball+stick", + "params": { + "lazy": false, + "visible": true, + "quality": "high", + "sphereDetail": 2, + "radialSegments": 20, + "openEnded": true, + "disableImpostor": false, + "aspectRatio": 1.5, + "lineOnly": false, + "cylinderOnly": false, + "multipleBond": "off", + "bondScale": 0.3, + "bondSpacing": 0.75, + "linewidth": 2, + "radiusType": "size", + "radiusData": {}, + "radiusSize": 0.15, + "radiusScale": 2, + "assembly": "default", + "defaultAssembly": "", + "clipNear": 0, + "clipRadius": 0, + "clipCenter": { + "x": 0, + "y": 0, + "z": 0 + }, + "flatShaded": false, + "opacity": 1, + "depthWrite": true, + "side": "double", + "wireframe": false, + "colorScheme": "element", + "colorScale": "", + "colorReverse": false, + "colorValue": 9474192, + "colorMode": "hcl", + "roughness": 0.4, + "metalness": 0, + "diffuse": 16777215, + "diffuseInterior": false, + "useInteriorColor": true, + "interiorColor": 2236962, + "interiorDarkening": 0, + "matrix": { + "elements": [ + 1, + 0, + 0, + 0, + 0, + 1, + 0, + 0, + 0, + 0, + 1, + 0, + 0, + 0, + 0, + 1 + ] + }, + "disablePicking": false, + "sele": "" + } + } + }, + "1": { + "0": { + "type": "ball+stick", + "params": { + "lazy": false, + "visible": true, + "quality": "high", + "sphereDetail": 2, + "radialSegments": 20, + "openEnded": true, + "disableImpostor": false, + "aspectRatio": 1.5, + "lineOnly": false, + "cylinderOnly": false, + "multipleBond": 
"off", + "bondScale": 0.3, + "bondSpacing": 0.75, + "linewidth": 2, + "radiusType": "size", + "radiusData": {}, + "radiusSize": 0.15, + "radiusScale": 2, + "assembly": "default", + "defaultAssembly": "", + "clipNear": 0, + "clipRadius": 0, + "clipCenter": { + "x": 0, + "y": 0, + "z": 0 + }, + "flatShaded": false, + "opacity": 1, + "depthWrite": true, + "side": "double", + "wireframe": false, + "colorScheme": "element", + "colorScale": "", + "colorReverse": false, + "colorValue": 9474192, + "colorMode": "hcl", + "roughness": 0.4, + "metalness": 0, + "diffuse": 16777215, + "diffuseInterior": false, + "useInteriorColor": true, + "interiorColor": 2236962, + "interiorDarkening": 0, + "matrix": { + "elements": [ + 1, + 0, + 0, + 0, + 0, + 1, + 0, + 0, + 0, + 0, + 1, + 0, + 0, + 0, + 0, + 1 + ] + }, + "disablePicking": false, + "sele": "" + } + } + } + }, + "_ngl_serialize": false, + "_ngl_version": "", + "_ngl_view_id": [ + "FB989FD1-5B9C-446B-8914-6B58AF85446D" + ], + "_player_dict": {}, + "_scene_position": {}, + "_scene_rotation": {}, + "_synced_model_ids": [], + "_synced_repr_model_ids": [], + "_view_count": null, + "_view_height": "", + "_view_module": "nglview-js-widgets", + "_view_module_version": "3.0.1", + "_view_name": "NGLView", + "_view_width": "", + "background": "white", + "frame": 0, + "gui_style": null, + "layout": "IPY_MODEL_c6596896148b4a8a9c57963b67c7782f", + "max_frame": 0, + "n_components": 2, + "picked": {} + } + }, + "c6596896148b4a8a9c57963b67c7782f": { + "model_module": "@jupyter-widgets/base", + "model_name": "LayoutModel", + "model_module_version": "1.2.0", + "state": { + "_model_module": "@jupyter-widgets/base", + "_model_module_version": "1.2.0", + "_model_name": "LayoutModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "1.2.0", + "_view_name": "LayoutView", + "align_content": null, + "align_items": null, + "align_self": null, + "border": null, + "bottom": null, + "display": null, + "flex": null, + "flex_flow": null, + "grid_area": null, + "grid_auto_columns": null, + "grid_auto_flow": null, + "grid_auto_rows": null, + "grid_column": null, + "grid_gap": null, + "grid_row": null, + "grid_template_areas": null, + "grid_template_columns": null, + "grid_template_rows": null, + "height": null, + "justify_content": null, + "justify_items": null, + "left": null, + "margin": null, + "max_height": null, + "max_width": null, + "min_height": null, + "min_width": null, + "object_fit": null, + "object_position": null, + "order": null, + "overflow": null, + "overflow_x": null, + "overflow_y": null, + "padding": null, + "right": null, + "top": null, + "visibility": null, + "width": null + } + }, + "2489b5e5648541fbbdceadb05632a050": { + "model_module": "@jupyter-widgets/controls", + "model_name": "ButtonModel", + "model_module_version": "1.5.0", + "state": { + "_dom_classes": [], + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "_model_name": "ButtonModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/controls", + "_view_module_version": "1.5.0", + "_view_name": "ButtonView", + "button_style": "", + "description": "", + "disabled": false, + "icon": "compress", + "layout": "IPY_MODEL_abce2a80e6304df3899109c6d6cac199", + "style": "IPY_MODEL_65195cb7a4134f4887e9dd19f3676462", + "tooltip": "" + } + }, + "01e0ba4e5da04914b4652b8d58565d7b": { + "model_module": "@jupyter-widgets/controls", + "model_name": "HBoxModel", + "model_module_version": "1.5.0", + "state": { + "_dom_classes": [], + 
"_model_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "_model_name": "HBoxModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/controls", + "_view_module_version": "1.5.0", + "_view_name": "HBoxView", + "box_style": "", + "children": [ + "IPY_MODEL_e5c0d75eb5e1447abd560c8f2c6017e1", + "IPY_MODEL_5146907ef6764654ad7d598baebc8b58" + ], + "layout": "IPY_MODEL_144ec959b7604a2cabb5ca46ae5e5379" + } + }, + "c30e6c2f3e2a44dbbb3d63bd519acaa4": { + "model_module": "@jupyter-widgets/base", + "model_name": "LayoutModel", + "model_module_version": "1.2.0", + "state": { + "_model_module": "@jupyter-widgets/base", + "_model_module_version": "1.2.0", + "_model_name": "LayoutModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "1.2.0", + "_view_name": "LayoutView", + "align_content": null, + "align_items": null, + "align_self": null, + "border": null, + "bottom": null, + "display": null, + "flex": null, + "flex_flow": null, + "grid_area": null, + "grid_auto_columns": null, + "grid_auto_flow": null, + "grid_auto_rows": null, + "grid_column": null, + "grid_gap": null, + "grid_row": null, + "grid_template_areas": null, + "grid_template_columns": null, + "grid_template_rows": null, + "height": null, + "justify_content": null, + "justify_items": null, + "left": null, + "margin": null, + "max_height": null, + "max_width": null, + "min_height": null, + "min_width": null, + "object_fit": null, + "object_position": null, + "order": null, + "overflow": null, + "overflow_x": null, + "overflow_y": null, + "padding": null, + "right": null, + "top": null, + "visibility": null, + "width": null + } + }, + "f31c6e40e9b2466a9064a2669933ecd5": { + "model_module": "@jupyter-widgets/controls", + "model_name": "DescriptionStyleModel", + "model_module_version": "1.5.0", + "state": { + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "_model_name": "DescriptionStyleModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "1.2.0", + "_view_name": "StyleView", + "description_width": "" + } + }, + "19308ccac642498ab8b58462e3f1b0bb": { + "model_module": "@jupyter-widgets/base", + "model_name": "LayoutModel", + "model_module_version": "1.2.0", + "state": { + "_model_module": "@jupyter-widgets/base", + "_model_module_version": "1.2.0", + "_model_name": "LayoutModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "1.2.0", + "_view_name": "LayoutView", + "align_content": null, + "align_items": null, + "align_self": null, + "border": null, + "bottom": null, + "display": null, + "flex": null, + "flex_flow": null, + "grid_area": null, + "grid_auto_columns": null, + "grid_auto_flow": null, + "grid_auto_rows": null, + "grid_column": null, + "grid_gap": null, + "grid_row": null, + "grid_template_areas": null, + "grid_template_columns": null, + "grid_template_rows": null, + "height": null, + "justify_content": null, + "justify_items": null, + "left": null, + "margin": null, + "max_height": null, + "max_width": null, + "min_height": null, + "min_width": null, + "object_fit": null, + "object_position": null, + "order": null, + "overflow": null, + "overflow_x": null, + "overflow_y": null, + "padding": null, + "right": null, + "top": null, + "visibility": null, + "width": null + } + }, + "4a081cdc2ec3421ca79dd933b7e2b0c4": { + "model_module": "@jupyter-widgets/controls", + "model_name": "SliderStyleModel", + "model_module_version": 
"1.5.0", + "state": { + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "_model_name": "SliderStyleModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "1.2.0", + "_view_name": "StyleView", + "description_width": "", + "handle_color": null + } + }, + "e5c0d75eb5e1447abd560c8f2c6017e1": { + "model_module": "@jupyter-widgets/controls", + "model_name": "PlayModel", + "model_module_version": "1.5.0", + "state": { + "_dom_classes": [], + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "_model_name": "PlayModel", + "_playing": false, + "_repeat": false, + "_view_count": null, + "_view_module": "@jupyter-widgets/controls", + "_view_module_version": "1.5.0", + "_view_name": "PlayView", + "description": "", + "description_tooltip": null, + "disabled": false, + "interval": 100, + "layout": "IPY_MODEL_c30e6c2f3e2a44dbbb3d63bd519acaa4", + "max": 0, + "min": 0, + "show_repeat": true, + "step": 1, + "style": "IPY_MODEL_f31c6e40e9b2466a9064a2669933ecd5", + "value": 0 + } + }, + "5146907ef6764654ad7d598baebc8b58": { + "model_module": "@jupyter-widgets/controls", + "model_name": "IntSliderModel", + "model_module_version": "1.5.0", + "state": { + "_dom_classes": [], + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "_model_name": "IntSliderModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/controls", + "_view_module_version": "1.5.0", + "_view_name": "IntSliderView", + "continuous_update": true, + "description": "", + "description_tooltip": null, + "disabled": false, + "layout": "IPY_MODEL_19308ccac642498ab8b58462e3f1b0bb", + "max": 0, + "min": 0, + "orientation": "horizontal", + "readout": true, + "readout_format": "d", + "step": 1, + "style": "IPY_MODEL_4a081cdc2ec3421ca79dd933b7e2b0c4", + "value": 0 + } + }, + "144ec959b7604a2cabb5ca46ae5e5379": { + "model_module": "@jupyter-widgets/base", + "model_name": "LayoutModel", + "model_module_version": "1.2.0", + "state": { + "_model_module": "@jupyter-widgets/base", + "_model_module_version": "1.2.0", + "_model_name": "LayoutModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "1.2.0", + "_view_name": "LayoutView", + "align_content": null, + "align_items": null, + "align_self": null, + "border": null, + "bottom": null, + "display": null, + "flex": null, + "flex_flow": null, + "grid_area": null, + "grid_auto_columns": null, + "grid_auto_flow": null, + "grid_auto_rows": null, + "grid_column": null, + "grid_gap": null, + "grid_row": null, + "grid_template_areas": null, + "grid_template_columns": null, + "grid_template_rows": null, + "height": null, + "justify_content": null, + "justify_items": null, + "left": null, + "margin": null, + "max_height": null, + "max_width": null, + "min_height": null, + "min_width": null, + "object_fit": null, + "object_position": null, + "order": null, + "overflow": null, + "overflow_x": null, + "overflow_y": null, + "padding": null, + "right": null, + "top": null, + "visibility": null, + "width": null + } + }, + "abce2a80e6304df3899109c6d6cac199": { + "model_module": "@jupyter-widgets/base", + "model_name": "LayoutModel", + "model_module_version": "1.2.0", + "state": { + "_model_module": "@jupyter-widgets/base", + "_model_module_version": "1.2.0", + "_model_name": "LayoutModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "1.2.0", + "_view_name": "LayoutView", + 
"align_content": null, + "align_items": null, + "align_self": null, + "border": null, + "bottom": null, + "display": null, + "flex": null, + "flex_flow": null, + "grid_area": null, + "grid_auto_columns": null, + "grid_auto_flow": null, + "grid_auto_rows": null, + "grid_column": null, + "grid_gap": null, + "grid_row": null, + "grid_template_areas": null, + "grid_template_columns": null, + "grid_template_rows": null, + "height": null, + "justify_content": null, + "justify_items": null, + "left": null, + "margin": null, + "max_height": null, + "max_width": null, + "min_height": null, + "min_width": null, + "object_fit": null, + "object_position": null, + "order": null, + "overflow": null, + "overflow_x": null, + "overflow_y": null, + "padding": null, + "right": null, + "top": null, + "visibility": null, + "width": "34px" + } + }, + "65195cb7a4134f4887e9dd19f3676462": { + "model_module": "@jupyter-widgets/controls", + "model_name": "ButtonStyleModel", + "model_module_version": "1.5.0", + "state": { + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "_model_name": "ButtonStyleModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "1.2.0", + "_view_name": "StyleView", + "button_color": null, + "font_weight": "" + } + } + } } - } - } - }, - "nbformat": 4, - "nbformat_minor": 0 + }, + "nbformat": 4, + "nbformat_minor": 0 } \ No newline at end of file diff --git a/examples/research_projects/gligen/demo.ipynb b/examples/research_projects/gligen/demo.ipynb index 4930253ff66e..571f1a0323a2 100644 --- a/examples/research_projects/gligen/demo.ipynb +++ b/examples/research_projects/gligen/demo.ipynb @@ -26,7 +26,8 @@ "%load_ext autoreload\n", "%autoreload 2\n", "\n", - "from diffusers import StableDiffusionGLIGENPipeline" + "import torch\n", + "from diffusers import StableDiffusionGLIGENTextImagePipeline, StableDiffusionGLIGENPipeline" ] }, { @@ -35,17 +36,16 @@ "metadata": {}, "outputs": [], "source": [ - "from transformers import CLIPTextModel, CLIPTokenizer\n", - "\n", + "import os\n", "import diffusers\n", "from diffusers import (\n", " AutoencoderKL,\n", " DDPMScheduler,\n", - " EulerDiscreteScheduler,\n", " UNet2DConditionModel,\n", + " UniPCMultistepScheduler,\n", + " EulerDiscreteScheduler,\n", ")\n", - "\n", - "\n", + "from transformers import CLIPImageProcessor, CLIPTextModel, CLIPTokenizer\n", "# pretrained_model_name_or_path = 'masterful/gligen-1-4-generation-text-box'\n", "\n", "pretrained_model_name_or_path = '/root/data/zhizhonghuang/checkpoints/models--masterful--gligen-1-4-generation-text-box/snapshots/d2820dc1e9ba6ca082051ce79cfd3eb468ae2c83'\n", @@ -122,7 +122,6 @@ "\n", "import numpy as np\n", "\n", - "\n", "boxes = np.array([x[1] for x in gen_boxes])\n", "boxes = boxes / 512\n", "boxes[:, 2] = boxes[:, 0] + boxes[:, 2]\n", From af0effaa90787705ea921c106765a32be9465b82 Mon Sep 17 00:00:00 2001 From: shitao <2906698981@qq.com> Date: Thu, 19 Dec 2024 12:28:58 +0800 Subject: [PATCH 25/55] update OmniGen2DModel --- docs/source/en/using-diffusers/omnigen.md | 29 -- scripts/convert_omnigen_to_diffusers.py | 61 ++- src/diffusers/models/attention.py | 31 ++ src/diffusers/models/attention_processor.py | 76 ++++ src/diffusers/models/embeddings.py | 128 ++++--- .../transformers/transformer_omnigen.py | 350 ++++++++---------- .../pipelines/omnigen/kvcache_omnigen.py | 122 ------ .../pipelines/omnigen/pipeline_omnigen.py | 57 +-- .../omnigen/test_pipeline_omnigen.py | 271 ++++++-------- 9 files changed, 487 
insertions(+), 638 deletions(-) delete mode 100644 src/diffusers/pipelines/omnigen/kvcache_omnigen.py diff --git a/docs/source/en/using-diffusers/omnigen.md b/docs/source/en/using-diffusers/omnigen.md index 6742fef24b8d..45945090abbe 100644 --- a/docs/source/en/using-diffusers/omnigen.md +++ b/docs/source/en/using-diffusers/omnigen.md @@ -301,39 +301,10 @@ However, when using input images, the computational cost increases. Here are some guidelines to help you reduce computational costs when input multiple images. The experiments are conducted on A800 GPU and input two images to OmniGen. -### Inference speed - -| Parameter | Inference Time | -|--------------------------|----------------| -| use_kv_cache=True | 90s | -| use_kv_cache=False | 221s | -| max_input_image_size=1024| 90s | -| max_input_image_size=512 | 58s | - -- `use_kv_cache=True`: - `use_kv_cache` will store key and value states of the input conditions to compute attention without redundant computations. - The default value is True, and OmniGen will offload the kv cache to cpu default. - - `use_kv_cache=False`: the inference time is 3m21s. - - `use_kv_cache=True`: the inference time is 1m30s. - -- `max_input_image_size`: - the maximum size of input image, which will be used to crop the input image - - `max_input_image_size=1024`: the inference time is 1m30s. - - `max_input_image_size=512`: the inference time is 58s. - - - - - - -### Memory - - | Method | Memory Usage | |---------------------------------------------|--------------| | pipe.to("cuda") | 31GB | | pipe.enable_model_cpu_offload() | 28GB | -| pipe.enable_transformer_block_cpu_offload() | 25GB | | pipe.enable_sequential_cpu_offload() | 11GB | - `pipe.enable_model_cpu_offload()`: diff --git a/scripts/convert_omnigen_to_diffusers.py b/scripts/convert_omnigen_to_diffusers.py index cfa46c1afb0e..46593530f8ff 100644 --- a/scripts/convert_omnigen_to_diffusers.py +++ b/scripts/convert_omnigen_to_diffusers.py @@ -1,5 +1,6 @@ import argparse import os +os.environ['HF_HUB_CACHE'] = "/share/shitao/downloaded_models2" import torch from huggingface_hub import snapshot_download @@ -35,37 +36,34 @@ def main(args): "final_layer.adaLN_modulation.1.bias": "norm_out.linear.bias", "final_layer.linear.weight": "proj_out.weight", "final_layer.linear.bias": "proj_out.bias", + "time_token.mlp.0.weight": "time_token.linear_1.weight", + "time_token.mlp.0.bias": "time_token.linear_1.bias", + "time_token.mlp.2.weight": "time_token.linear_2.weight", + "time_token.mlp.2.bias": "time_token.linear_2.bias", + "t_embedder.mlp.0.weight": "t_embedder.linear_1.weight", + "t_embedder.mlp.0.bias": "t_embedder.linear_1.bias", + "t_embedder.mlp.2.weight": "t_embedder.linear_2.weight", + "t_embedder.mlp.2.bias": "t_embedder.linear_2.bias", + "llm.embed_tokens.weight": "embed_tokens.weight", + } converted_state_dict = {} for k, v in ckpt.items(): if k in mapping_dict: converted_state_dict[mapping_dict[k]] = v + elif "qkv" in k: + to_q, to_k, to_v = v.chunk(3) + converted_state_dict[f"layers.{k.split('.')[2]}.self_attn.to_q.weight"] = to_q + converted_state_dict[f"layers.{k.split('.')[2]}.self_attn.to_k.weight"] = to_k + converted_state_dict[f"layers.{k.split('.')[2]}.self_attn.to_v.weight"] = to_v + elif "o_proj" in k: + converted_state_dict[f"layers.{k.split('.')[2]}.self_attn.to_out.0.weight"] = v else: - converted_state_dict[k] = v - - # transformer_config = AutoConfig.from_pretrained(args.origin_ckpt_path) - # print(type(transformer_config.__dict__)) - # print(transformer_config.__dict__) - - 
transformer_config = { - "_name_or_path": "Phi-3-vision-128k-instruct", - "architectures": ["Phi3ForCausalLM"], - "attention_dropout": 0.0, - "bos_token_id": 1, - "eos_token_id": 2, - "hidden_act": "silu", - "hidden_size": 3072, - "initializer_range": 0.02, - "intermediate_size": 8192, - "max_position_embeddings": 131072, - "model_type": "phi3", - "num_attention_heads": 32, - "num_hidden_layers": 32, - "num_key_value_heads": 32, - "original_max_position_embeddings": 4096, - "rms_norm_eps": 1e-05, - "rope_scaling": { + converted_state_dict[k[4:]] = v + + transformer = OmniGenTransformer2DModel( + rope_scaling = { "long_factor": [ 1.0299999713897705, 1.0499999523162842, @@ -168,17 +166,6 @@ def main(args): ], "type": "su", }, - "rope_theta": 10000.0, - "sliding_window": 131072, - "tie_word_embeddings": False, - "torch_dtype": "bfloat16", - "transformers_version": "4.38.1", - "use_cache": True, - "vocab_size": 32064, - "_attn_implementation": "sdpa", - } - transformer = OmniGenTransformer2DModel( - transformer_config=transformer_config, patch_size=2, in_channels=4, pos_embed_max_size=192, @@ -189,7 +176,7 @@ def main(args): num_model_params = sum(p.numel() for p in transformer.parameters()) print(f"Total number of transformer parameters: {num_model_params}") - scheduler = FlowMatchEulerDiscreteScheduler() + scheduler = FlowMatchEulerDiscreteScheduler(invert_sigmas=True, num_train_timesteps=1) vae = AutoencoderKL.from_pretrained(os.path.join(args.origin_ckpt_path, "vae"), torch_dtype=torch.float32) @@ -211,7 +198,7 @@ def main(args): ) parser.add_argument( - "--dump_path", default="OmniGen-v1-diffusers", type=str, required=False, help="Path to the output pipeline." + "--dump_path", default="/share/shitao/repos/OmniGen-v1-diffusers2", type=str, required=False, help="Path to the output pipeline." ) args = parser.parse_args() diff --git a/src/diffusers/models/attention.py b/src/diffusers/models/attention.py index 02ed1f965abf..780f2e22b391 100644 --- a/src/diffusers/models/attention.py +++ b/src/diffusers/models/attention.py @@ -632,6 +632,37 @@ def forward(self, x): return self.linear_2(self.silu(self.linear_1(x)) * self.linear_3(x)) +class OmniGenFeedForward(nn.Module): + r""" + A feed-forward layer for OmniGen. + + Parameters: + hidden_size (`int`): + The dimensionality of the hidden layers in the model. This parameter determines the width of the model's + hidden representations. + intermediate_size (`int`): The intermediate dimension of the feedforward layer. 
+ """ + + def __init__( + self, + hidden_size: int, + intermediate_size: int, + ): + super().__init__() + self.gate_up_proj = nn.Linear(hidden_size, 2 * intermediate_size, bias=False) + self.down_proj = nn.Linear(intermediate_size, hidden_size, bias=False) + + self.activation_fn = nn.SiLU() + + def forward(self, hidden_states: torch.FloatTensor) -> torch.FloatTensor: + up_states = self.gate_up_proj(hidden_states) + + gate, up_states = up_states.chunk(2, dim=-1) + up_states = up_states * self.activation_fn(gate) + + return self.down_proj(up_states) + + @maybe_allow_in_graph class TemporalBasicTransformerBlock(nn.Module): r""" diff --git a/src/diffusers/models/attention_processor.py b/src/diffusers/models/attention_processor.py index faacc431c386..73c068a177d1 100644 --- a/src/diffusers/models/attention_processor.py +++ b/src/diffusers/models/attention_processor.py @@ -3414,6 +3414,82 @@ def __call__( return hidden_states + +class OmniGenAttnProcessor2_0: + r""" + Processor for implementing scaled dot-product attention (enabled by default if you're using PyTorch 2.0). This is + used in the OmniGen model. + """ + + def __init__(self): + if not hasattr(F, "scaled_dot_product_attention"): + raise ImportError("AttnProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0.") + + def __call__( + self, + attn: Attention, + hidden_states: torch.Tensor, + encoder_hidden_states: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + query_rotary_emb: Optional[torch.Tensor] = None, + key_rotary_emb: Optional[torch.Tensor] = None, + ) -> torch.Tensor: + from .embeddings import apply_rotary_emb + + input_ndim = hidden_states.ndim + + if input_ndim == 4: + batch_size, channel, height, width = hidden_states.shape + hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2) + + batch_size, sequence_length, _ = hidden_states.shape + + # Get Query-Key-Value Pair + query = attn.to_q(hidden_states) + key = attn.to_k(encoder_hidden_states) + value = attn.to_v(encoder_hidden_states) + + bsz, q_len, query_dim = query.size() + inner_dim = key.shape[-1] + head_dim = query_dim // attn.heads + dtype = query.dtype + + # Get key-value heads + kv_heads = inner_dim // head_dim + + query = query.view(batch_size, -1, attn.heads, head_dim) + key = key.view(batch_size, -1, kv_heads, head_dim) + value = value.view(batch_size, -1, kv_heads, head_dim) + + # Apply RoPE if needed + if query_rotary_emb is not None: + query = apply_rotary_emb(query, query_rotary_emb, revert_x_as_rotated=True) + if key_rotary_emb is not None: + key = apply_rotary_emb(key, key_rotary_emb, revert_x_as_rotated=True) + + query, key = query.to(dtype), key.to(dtype) + + + # perform Grouped-qurey Attention (GQA) + n_rep = attn.heads // kv_heads + if n_rep > 1: + key = key.unsqueeze(3).repeat(1, 1, 1, n_rep, 1).flatten(2, 3) + value = value.unsqueeze(3).repeat(1, 1, 1, n_rep, 1).flatten(2, 3) + + query = query.transpose(1, 2) + key = key.transpose(1, 2) + value = value.transpose(1, 2) + + # the output of sdp = (batch, num_heads, seq_len, head_dim) + hidden_states = F.scaled_dot_product_attention( + query, key, value, attn_mask=attention_mask + ) + hidden_states = hidden_states.transpose(1, 2).to(dtype) + hidden_states = hidden_states.reshape(bsz, q_len, attn.out_dim) + hidden_states = attn.to_out[0](hidden_states) + return hidden_states + + class PAGHunyuanAttnProcessor2_0: r""" Processor for implementing scaled dot-product attention (enabled by default if you're using PyTorch 2.0). 
This is diff --git a/src/diffusers/models/embeddings.py b/src/diffusers/models/embeddings.py index a9cb2a8ac865..d5fff5c055ae 100644 --- a/src/diffusers/models/embeddings.py +++ b/src/diffusers/models/embeddings.py @@ -990,6 +990,7 @@ def apply_rotary_emb( freqs_cis: Union[torch.Tensor, Tuple[torch.Tensor]], use_real: bool = True, use_real_unbind_dim: int = -1, + revert_x_as_rotated: bool = False, ) -> Tuple[torch.Tensor, torch.Tensor]: """ Apply rotary embeddings to input tensors using the given frequency tensor. This function applies rotary embeddings @@ -1007,21 +1008,33 @@ def apply_rotary_emb( """ if use_real: cos, sin = freqs_cis # [S, D] - cos = cos[None, None] - sin = sin[None, None] + if len(cos.shape) == 2: + cos = cos[None, None] + sin = sin[None, None] + elif len(cos.shape) == 3: + # Used for OmniGen + cos = cos[:, :, None, :,] + sin = sin[:, :, None, :,] cos, sin = cos.to(x.device), sin.to(x.device) - if use_real_unbind_dim == -1: - # Used for flux, cogvideox, hunyuan-dit - x_real, x_imag = x.reshape(*x.shape[:-1], -1, 2).unbind(-1) # [B, S, H, D//2] - x_rotated = torch.stack([-x_imag, x_real], dim=-1).flatten(3) - elif use_real_unbind_dim == -2: - # Used for Stable Audio - x_real, x_imag = x.reshape(*x.shape[:-1], 2, -1).unbind(-2) # [B, S, H, D//2] - x_rotated = torch.cat([-x_imag, x_real], dim=-1) + if revert_x_as_rotated: + # Rotates half the hidden dims of the input. this rorate function is widely used in LLM, e.g. Llama, Phi3, etc. + # Used for OmniGen + x1 = x[..., : x.shape[-1] // 2] + x2 = x[..., x.shape[-1] // 2 :] + x_rotated = torch.cat((-x2, x1), dim=-1) else: - raise ValueError(f"`use_real_unbind_dim={use_real_unbind_dim}` but should be -1 or -2.") - + if use_real_unbind_dim == -1: + # Used for flux, cogvideox, hunyuan-dit + x_real, x_imag = x.reshape(*x.shape[:-1], -1, 2).unbind(-1) # [B, S, H, D//2] + x_rotated = torch.stack([-x_imag, x_real], dim=-1).flatten(3) + elif use_real_unbind_dim == -2: + # Used for Stable Audio + x_real, x_imag = x.reshape(*x.shape[:-1], 2, -1).unbind(-2) # [B, S, H, D//2] + x_rotated = torch.cat([-x_imag, x_real], dim=-1) + else: + raise ValueError(f"`use_real_unbind_dim={use_real_unbind_dim}` but should be -1 or -2.") + out = (x.float() * cos + x_rotated.float() * sin).to(x.dtype) return out @@ -1034,6 +1047,8 @@ def apply_rotary_emb( return x_out.type_as(x) + + def apply_rotary_emb_allegro(x: torch.Tensor, freqs_cis, positions): # TODO(aryan): rewrite def apply_1d_rope(tokens, pos, cos, sin): @@ -1082,6 +1097,58 @@ def forward(self, ids: torch.Tensor) -> torch.Tensor: return freqs_cos, freqs_sin +class OmniGenSuScaledRotaryEmbedding(nn.Module): + def __init__(self, + dim, + max_position_embeddings=131072, + original_max_position_embeddings=4096, + base=10000, + rope_scaling=None): + super().__init__() + + self.dim = dim + self.max_position_embeddings = max_position_embeddings + self.base = base + + inv_freq = 1.0 / (self.base ** (torch.arange(0, self.dim, 2, dtype=torch.int64).float() / self.dim)) + self.register_buffer("inv_freq", tensor=inv_freq, persistent=False) + + self.short_factor = rope_scaling["short_factor"] + self.long_factor = rope_scaling["long_factor"] + self.original_max_position_embeddings = original_max_position_embeddings + + @torch.no_grad() + def forward(self, x, position_ids): + seq_len = torch.max(position_ids) + 1 + if seq_len > self.original_max_position_embeddings: + ext_factors = torch.tensor(self.long_factor, dtype=torch.float32, device=x.device) + else: + ext_factors = torch.tensor(self.short_factor, 
dtype=torch.float32, device=x.device) + + inv_freq_shape = torch.arange(0, self.dim, 2, dtype=torch.int64, device=x.device).float() / self.dim + self.inv_freq = 1.0 / (ext_factors * self.base**inv_freq_shape) + + inv_freq_expanded = self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1) + position_ids_expanded = position_ids[:, None, :].float() + + # Force float32 since bfloat16 loses precision on long contexts + # See https://github.com/huggingface/transformers/pull/29285 + device_type = x.device.type + device_type = device_type if isinstance(device_type, str) and device_type != "mps" else "cpu" + with torch.autocast(device_type=device_type, enabled=False): + freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2) + emb = torch.cat((freqs, freqs), dim=-1) + + scale = self.max_position_embeddings / self.original_max_position_embeddings + if scale <= 1.0: + scaling_factor = 1.0 + else: + scaling_factor = math.sqrt(1 + math.log(scale) / math.log(self.original_max_position_embeddings)) + + cos = emb.cos() * scaling_factor + sin = emb.sin() * scaling_factor + return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype) + class TimestepEmbedding(nn.Module): def __init__( self, @@ -1149,43 +1216,6 @@ def forward(self, timesteps): return t_emb -class OmniGenTimestepEmbed(nn.Module): - """ - Embeds scalar timesteps into vector representations for OmniGen - """ - - def __init__(self, hidden_size, frequency_embedding_size=256): - super().__init__() - self.mlp = nn.Sequential( - nn.Linear(frequency_embedding_size, hidden_size, bias=True), - nn.SiLU(), - nn.Linear(hidden_size, hidden_size, bias=True), - ) - self.frequency_embedding_size = frequency_embedding_size - - @staticmethod - def timestep_embedding(t, dim, max_period=10000): - """ - Create sinusoidal timestep embeddings. :param t: a 1-D Tensor of N indices, one per batch element. - These may be fractional. - :param dim: the dimension of the output. :param max_period: controls the minimum frequency of the embeddings. - :return: an (N, D) Tensor of positional embeddings. - """ - # https://github.com/openai/glide-text2im/blob/main/glide_text2im/nn.py - half = dim // 2 - freqs = torch.exp(-math.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32) / half).to( - device=t.device - ) - args = t[:, None].float() * freqs[None] - embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1) - if dim % 2: - embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1) - return embedding - - def forward(self, t, dtype=torch.float32): - t_freq = self.timestep_embedding(t, self.frequency_embedding_size).to(dtype) - t_emb = self.mlp(t_freq) - return t_emb class GaussianFourierProjection(nn.Module): diff --git a/src/diffusers/models/transformers/transformer_omnigen.py b/src/diffusers/models/transformers/transformer_omnigen.py index 209d6b27aaa6..4b1fca65a9cd 100644 --- a/src/diffusers/models/transformers/transformer_omnigen.py +++ b/src/diffusers/models/transformers/transformer_omnigen.py @@ -12,205 +12,101 @@ # See the License for the specific language governing permissions and # limitations under the License. 
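The `OmniGenSuScaledRotaryEmbedding` added above follows the Phi-3 "su" long-context recipe: the inverse frequencies are rescaled by per-dimension `short_factor`/`long_factor` lists, and a single scaling term is applied to the resulting cos/sin once the maximum context length exceeds the original training length. A minimal, self-contained sketch of that math (made-up sizes, placeholder factors, head axis omitted; not the library implementation):

```py
import math

import torch

dim, base = 8, 10000
max_pos, original_max_pos = 131072, 4096
ext_factors = torch.ones(dim // 2)        # stands in for short_factor / long_factor
position_ids = torch.arange(6)[None, :]   # (batch=1, seq_len=6)

# Rescale the inverse frequencies with the per-dimension extrapolation factors.
inv_freq = 1.0 / (ext_factors * base ** (torch.arange(0, dim, 2).float() / dim))
freqs = position_ids[:, :, None].float() * inv_freq[None, None, :]  # (1, 6, dim // 2)
emb = torch.cat((freqs, freqs), dim=-1)                             # (1, 6, dim)

# Single long-context scaling term, as in the forward() above.
scale = max_pos / original_max_pos
scaling_factor = 1.0 if scale <= 1.0 else math.sqrt(1 + math.log(scale) / math.log(original_max_pos))
cos, sin = emb.cos() * scaling_factor, emb.sin() * scaling_factor

# "Rotate half" convention, i.e. what `revert_x_as_rotated=True` selects in apply_rotary_emb.
x = torch.randn(1, 6, dim)                # (batch, seq, head_dim); the head axis is omitted here
x1, x2 = x[..., : dim // 2], x[..., dim // 2 :]
out = x * cos + torch.cat((-x2, x1), dim=-1) * sin
print(out.shape)                          # torch.Size([1, 6, 8])
```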
-from dataclasses import dataclass from typing import Any, Dict, List, Optional, Tuple, Union import torch import torch.utils.checkpoint from torch import nn -from transformers import Phi3Config, Phi3Model -from transformers.cache_utils import Cache, DynamicCache -from transformers.modeling_outputs import BaseModelOutputWithPast from ...configuration_utils import ConfigMixin, register_to_config from ...loaders import PeftAdapterMixin from ...utils import USE_PEFT_BACKEND, logging, scale_lora_layers -from ..attention_processor import AttentionProcessor -from ..embeddings import OmniGenPatchEmbed, OmniGenTimestepEmbed +from ..attention import OmniGenFeedForward +from ..attention_processor import Attention, OmniGenAttnProcessor2_0 +from ..embeddings import OmniGenPatchEmbed, TimestepEmbedding, Timesteps, OmniGenSuScaledRotaryEmbedding from ..modeling_outputs import Transformer2DModelOutput from ..modeling_utils import ModelMixin -from ..normalization import AdaLayerNorm +from ..normalization import AdaLayerNorm, RMSNorm logger = logging.get_logger(__name__) # pylint: disable=invalid-name -@dataclass -class OmniGen2DModelOutput(Transformer2DModelOutput): +class OmniGenBlock(nn.Module): """ - The output of [`Transformer2DModel`]. - - Args: - sample (`torch.Tensor` of shape `(batch_size, num_channels, height, width)` or `(batch size, num_vector_embeds - 1, num_latent_pixels)` if [`Transformer2DModel`] is discrete): - The hidden states output conditioned on the `encoder_hidden_states` input. If discrete, returns probability - distributions for the unnoised latent pixels. - past_key_values (`transformers.cache_utils.DynamicCache`) - """ - - sample: "torch.Tensor" # noqa: F821 - past_key_values: "DynamicCache" - - -class OmniGenBaseTransformer(Phi3Model): - """ - Transformer used in OmniGen. The transformer block is from Ph3, and only modify the attention mask. References: - [OmniGen](https://arxiv.org/pdf/2409.11340) + A LuminaNextDiTBlock for LuminaNextDiT2DModel. Parameters: - config: Phi3Config + hidden_size (`int`): Embedding dimension of the input features. + num_attention_heads (`int`): Number of attention heads. + num_key_value_heads (`int`): + Number of attention heads in key and value features (if using GQA), or set to None for the same as query. + intermediate_size (`int`): size of intermediate layer. + rms_norm_eps (`float`): The eps for norm layer. 
""" - def prefetch_layer(self, layer_idx: int, device: torch.device): - "Starts prefetching the next layer cache" - with torch.cuda.stream(self.prefetch_stream): - # Prefetch next layer tensors to GPU - for name, param in self.layers[layer_idx].named_parameters(): - param.data = param.data.to(device, non_blocking=True) - - def evict_previous_layer(self, layer_idx: int): - "Moves the previous layer cache to the CPU" - prev_layer_idx = layer_idx - 1 - for name, param in self.layers[prev_layer_idx].named_parameters(): - param.data = param.data.to("cpu", non_blocking=True) - - def get_offload_layer(self, layer_idx: int, device: torch.device): - # init stream - if not hasattr(self, "prefetch_stream"): - self.prefetch_stream = torch.cuda.Stream() - - # delete previous layer - torch.cuda.current_stream().synchronize() - self.evict_previous_layer(layer_idx) + def __init__( + self, + hidden_size: int, + num_attention_heads: int, + num_key_value_heads: int, + intermediate_size: int, + rms_norm_eps: float, + ) -> None: + super().__init__() - # make sure the current layer is ready - torch.cuda.synchronize(self.prefetch_stream) + self.input_layernorm = RMSNorm(hidden_size, eps=rms_norm_eps) + self.self_attn = Attention( + query_dim=hidden_size, + cross_attention_dim=hidden_size, + dim_head=hidden_size // num_attention_heads, + heads=num_attention_heads, + kv_heads=num_key_value_heads, + bias=False, + out_dim=hidden_size, + out_bias=False, + processor=OmniGenAttnProcessor2_0(), + ) + self.post_attention_layernorm = RMSNorm(hidden_size, eps=rms_norm_eps) + self.mlp = OmniGenFeedForward(hidden_size, intermediate_size) - # load next layer - self.prefetch_layer((layer_idx + 1) % len(self.layers), device) def forward( self, - input_ids: torch.LongTensor = None, - attention_mask: Optional[torch.Tensor] = None, - position_ids: Optional[torch.LongTensor] = None, - past_key_values: Optional[List[torch.FloatTensor]] = None, - inputs_embeds: Optional[torch.FloatTensor] = None, - use_cache: Optional[bool] = None, - output_attentions: Optional[bool] = None, - output_hidden_states: Optional[bool] = None, - return_dict: Optional[bool] = None, - cache_position: Optional[torch.LongTensor] = None, - offload_transformer_block: Optional[bool] = False, - ) -> Union[Tuple, BaseModelOutputWithPast]: - output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions - output_hidden_states = ( - output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states - ) - use_cache = use_cache if use_cache is not None else self.config.use_cache - - return_dict = return_dict if return_dict is not None else self.config.use_return_dict - - if (input_ids is None) ^ (inputs_embeds is not None): - raise ValueError("You must specify exactly one of input_ids or inputs_embeds") - - if self.gradient_checkpointing and self.training: - if use_cache: - logger.warning_once( - "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..." - ) - use_cache = False - - # kept for BC (non `Cache` `past_key_values` inputs) - return_legacy_cache = False - if use_cache and not isinstance(past_key_values, Cache): - return_legacy_cache = True - if past_key_values is None: - past_key_values = DynamicCache() - else: - past_key_values = DynamicCache.from_legacy_cache(past_key_values) - logger.warning_once( - "We detected that you are passing `past_key_values` as a tuple of tuples. This is deprecated and " - "will be removed in v4.47. 
Please convert your cache or use an appropriate `Cache` class " - "(https://huggingface.co/docs/transformers/kv_cache#legacy-cache-format)" - ) - - if attention_mask is not None and attention_mask.dim() == 3: - dtype = inputs_embeds.dtype - min_dtype = torch.finfo(dtype).min - attention_mask = (1 - attention_mask) * min_dtype - attention_mask = attention_mask.unsqueeze(1).to(inputs_embeds.dtype) - else: - raise Exception("attention_mask parameter was unavailable or invalid") - - hidden_states = inputs_embeds - - # decoder layers - all_hidden_states = () if output_hidden_states else None - all_self_attns = () if output_attentions else None - next_decoder_cache = None - - layer_idx = -1 - for decoder_layer in self.layers: - layer_idx += 1 - - if output_hidden_states: - all_hidden_states += (hidden_states,) - - if self.gradient_checkpointing and self.training: - layer_outputs = self._gradient_checkpointing_func( - decoder_layer.__call__, - hidden_states, - attention_mask, - position_ids, - past_key_values, - output_attentions, - use_cache, - cache_position, - ) - else: - if offload_transformer_block and not self.training: - if not torch.cuda.is_available(): - logger.warning_once( - "We don't detecte any available GPU, so diable `offload_transformer_block`" - ) - else: - self.get_offload_layer(layer_idx, device=inputs_embeds.device) - layer_outputs = decoder_layer( - hidden_states, - attention_mask=attention_mask, - position_ids=position_ids, - past_key_value=past_key_values, - output_attentions=output_attentions, - use_cache=use_cache, - cache_position=cache_position, - ) + hidden_states: torch.Tensor, + attention_mask: torch.Tensor, + rotary_emb: torch.Tensor, + ): + """ + Perform a forward pass through the LuminaNextDiTBlock. - hidden_states = layer_outputs[0] + Parameters: + hidden_states (`torch.Tensor`): The input of hidden_states for LuminaNextDiTBlock. + attention_mask (`torch.Tensor): The input of hidden_states corresponse attention mask. + rotary_emb (`torch.Tensor`): Precomputed cosine and sine frequencies. 
+ """ + residual = hidden_states - if use_cache: - next_decoder_cache = layer_outputs[2 if output_attentions else 1] + hidden_states = self.input_layernorm(hidden_states) - if output_attentions: - all_self_attns += (layer_outputs[1],) + # Self Attention + attn_outputs = self.self_attn( + hidden_states=hidden_states, + encoder_hidden_states=hidden_states, + attention_mask=attention_mask, + query_rotary_emb=rotary_emb, + key_rotary_emb=rotary_emb, + ) - hidden_states = self.norm(hidden_states) + hidden_states = residual + attn_outputs - # add hidden states from the last decoder layer - if output_hidden_states: - all_hidden_states += (hidden_states,) + residual = hidden_states + hidden_states = self.post_attention_layernorm(hidden_states) + hidden_states = self.mlp(hidden_states) + hidden_states = residual + hidden_states - next_cache = next_decoder_cache if use_cache else None - if return_legacy_cache: - next_cache = next_cache.to_legacy_cache() + return hidden_states - if not return_dict: - return tuple(v for v in [hidden_states, next_cache, all_hidden_states, all_self_attns] if v is not None) - return BaseModelOutputWithPast( - last_hidden_state=hidden_states, - past_key_values=next_cache, - hidden_states=all_hidden_states, - attentions=all_self_attns, - ) class OmniGenTransformer2DModel(ModelMixin, ConfigMixin, PeftAdapterMixin): @@ -220,7 +116,23 @@ class OmniGenTransformer2DModel(ModelMixin, ConfigMixin, PeftAdapterMixin): Reference: https://arxiv.org/pdf/2409.11340 Parameters: - transformer_config (`dict`): config for transformer layers. OmniGen-v1 use Phi3 as transformer backbone + hidden_size (`int`, *optional*, defaults to 3072): + The dimensionality of the hidden layers in the model. This parameter determines the width of the model's + hidden representations. + rms_norm_eps (`float`, *optional*, defaults to 1e-5): eps for RMSNorm layer. + num_attention_heads (`int`, *optional*, defaults to 32): + The number of attention heads in each attention layer. This parameter specifies how many separate attention + mechanisms are used. + num_kv_heads (`int`, *optional*, defaults to 32): + The number of key-value heads in the attention mechanism, if different from the number of attention heads. + If None, it defaults to num_attention_heads. + intermediate_size (`int`, *optional*, defaults to 8192): dimension of the intermediate layer in FFN + num_layers (`int`, *optional*, default to 32): + The number of layers in the model. This defines the depth of the neural network. + pad_token_id (`int`, *optional*, default to 32000): + id for pad token + vocab_size (`int`, *optional*, default to 32064): + size of vocabulary patch_size (`int`, defaults to 2): Patch size to turn the input data into small patches. in_channels (`int`, *optional*, defaults to 4): The number of channels in the input. pos_embed_max_size (`int`, *optional*, defaults to 192): The max size of pos emb. 
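`OmniGenBlock` above is a standard pre-norm decoder block: normalize, self-attend, add the residual, then normalize again, run the gated feed-forward, and add a second residual. A rough sketch of that data flow with stand-in PyTorch modules (plain `LayerNorm` and `MultiheadAttention` instead of `RMSNorm` and the diffusers `Attention` class):

```py
import torch
from torch import nn


class TinyBlock(nn.Module):
    def __init__(self, hidden_size: int = 16):
        super().__init__()
        self.input_layernorm = nn.LayerNorm(hidden_size)  # RMSNorm in the real block
        self.self_attn = nn.MultiheadAttention(hidden_size, num_heads=4, batch_first=True)
        self.post_attention_layernorm = nn.LayerNorm(hidden_size)
        self.mlp = nn.Sequential(nn.Linear(hidden_size, 32), nn.SiLU(), nn.Linear(32, hidden_size))

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        # Pre-norm self-attention with a residual connection.
        residual = hidden_states
        normed = self.input_layernorm(hidden_states)
        attn_out, _ = self.self_attn(normed, normed, normed)
        hidden_states = residual + attn_out

        # Pre-norm feed-forward with a second residual connection.
        residual = hidden_states
        hidden_states = residual + self.mlp(self.post_attention_layernorm(hidden_states))
        return hidden_states


print(TinyBlock()(torch.randn(2, 5, 16)).shape)  # torch.Size([2, 5, 16])
```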
@@ -231,10 +143,26 @@ class OmniGenTransformer2DModel(ModelMixin, ConfigMixin, PeftAdapterMixin): @register_to_config def __init__( self, - transformer_config: Dict, + hidden_size: int = 3072, + rms_norm_eps: float = 1e-05, + num_attention_heads: int = 32, + num_key_value_heads: int = 32, + intermediate_size: int = 8192, + num_layers: int = 32, + pad_token_id: int = 32000, + vocab_size: int = 32064, + max_position_embeddings: int = 131072, + original_max_position_embeddings: int = 4096, + rope_base: int = 10000, + rope_scaling: Dict = None, patch_size=2, in_channels=4, pos_embed_max_size: int = 192, + time_step_dim: int = 256, + flip_sin_to_cos: bool = True, + downscale_freq_shift: int = 0, + timestep_activation_fn: str = 'silu', + ): super().__init__() self.in_channels = in_channels @@ -242,9 +170,6 @@ def __init__( self.patch_size = patch_size self.pos_embed_max_size = pos_embed_max_size - transformer_config = Phi3Config(**transformer_config) - hidden_size = transformer_config.hidden_size - self.patch_embedding = OmniGenPatchEmbed( patch_size=patch_size, in_channels=in_channels, @@ -252,14 +177,33 @@ def __init__( pos_embed_max_size=pos_embed_max_size, ) - self.time_token = OmniGenTimestepEmbed(hidden_size) - self.t_embedder = OmniGenTimestepEmbed(hidden_size) + self.time_proj = Timesteps(time_step_dim, flip_sin_to_cos, downscale_freq_shift) + self.time_token = TimestepEmbedding(time_step_dim, hidden_size, timestep_activation_fn) + self.t_embedder = TimestepEmbedding(time_step_dim, hidden_size, timestep_activation_fn) self.norm_out = AdaLayerNorm(hidden_size, norm_elementwise_affine=False, norm_eps=1e-6, chunk_dim=1) self.proj_out = nn.Linear(hidden_size, patch_size * patch_size * self.out_channels, bias=True) - self.llm = OmniGenBaseTransformer(config=transformer_config) - self.llm.config.use_cache = False + self.embed_tokens = nn.Embedding(vocab_size, hidden_size, pad_token_id) + self.rotary_emb = OmniGenSuScaledRotaryEmbedding(hidden_size // num_attention_heads, + max_position_embeddings=max_position_embeddings, + original_max_position_embeddings=original_max_position_embeddings, + base=rope_base, + rope_scaling=rope_scaling) + + self.layers = nn.ModuleList( + [ + OmniGenBlock( + hidden_size, + num_attention_heads, + num_key_value_heads, + intermediate_size, + rms_norm_eps, + ) + for _ in range(num_layers) + ] + ) + self.norm = RMSNorm(hidden_size, eps=rms_norm_eps) def unpatchify(self, x, h, w): """ @@ -276,7 +220,7 @@ def unpatchify(self, x, h, w): @property # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.attn_processors - def attn_processors(self) -> Dict[str, AttentionProcessor]: + def attn_processors(self) -> Dict[str, OmniGenAttnProcessor2_0]: r""" Returns: `dict` of attention processors: A dictionary containing all attention processors used in the model with @@ -300,7 +244,7 @@ def fn_recursive_add_processors(name: str, module: torch.nn.Module, processors: return processors # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.set_attn_processor - def set_attn_processor(self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]]): + def set_attn_processor(self, processor: Union[OmniGenAttnProcessor2_0, Dict[str, OmniGenAttnProcessor2_0]]): r""" Sets the attention processor to use to compute attention. 
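The refactored `__init__` above swaps the OmniGen-specific timestep embedder for the generic `Timesteps` plus `TimestepEmbedding` pair, and the resulting vector is later unsqueezed so it can be prepended to the sequence as a single conditioning token. A hedged sketch of that idea (sizes made up; the cos/sin ordering in the real `Timesteps` class depends on `flip_sin_to_cos`):

```py
import math

import torch
from torch import nn


def sinusoidal_timesteps(t: torch.Tensor, dim: int = 256, max_period: int = 10000) -> torch.Tensor:
    # Sinusoidal frequencies, analogous to what `Timesteps` produces.
    half = dim // 2
    freqs = torch.exp(-math.log(max_period) * torch.arange(half, dtype=torch.float32) / half)
    args = t[:, None].float() * freqs[None]
    return torch.cat([torch.cos(args), torch.sin(args)], dim=-1)  # (batch, dim)


hidden_size = 64
# Mirrors TimestepEmbedding: linear -> SiLU -> linear.
time_embed = nn.Sequential(nn.Linear(256, hidden_size), nn.SiLU(), nn.Linear(hidden_size, hidden_size))

timestep = torch.tensor([999.0, 500.0])                               # one timestep per sample
time_token = time_embed(sinusoidal_timesteps(timestep)).unsqueeze(1)  # (batch, 1, hidden_size)
latent_tokens = torch.randn(2, 16, hidden_size)
inputs_embeds = torch.cat([time_token, latent_tokens], dim=1)         # timestep prepended as a token
print(inputs_embeds.shape)                                            # torch.Size([2, 17, 64])
```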
@@ -358,7 +302,7 @@ def get_multimodal_embeddings( input_img_latents = [x.to(self.dtype) for x in input_img_latents] condition_tokens = None if input_ids is not None: - condition_tokens = self.llm.embed_tokens(input_ids) + condition_tokens = self.embed_tokens(input_ids) input_img_inx = 0 if input_img_latents is not None: input_image_tokens = self.patch_embedding(input_img_latents, is_input_image=True) @@ -382,8 +326,6 @@ def forward( input_image_sizes: Dict[int, List[int]], attention_mask: torch.Tensor, position_ids: torch.Tensor, - past_key_values: DynamicCache = None, - offload_transformer_block: bool = False, attention_kwargs: Optional[Dict[str, Any]] = None, return_dict: bool = True, ): @@ -439,8 +381,8 @@ def forward( height, width = hidden_states.size()[-2:] hidden_states = self.patch_embedding(hidden_states, is_input_image=False) num_tokens_for_output_image = hidden_states.size(1) - - time_token = self.time_token(timestep, dtype=hidden_states.dtype).unsqueeze(1) + + time_token = self.time_token(self.time_proj(timestep).to(hidden_states.dtype)).unsqueeze(1) condition_tokens = self.get_multimodal_embeddings( input_ids=input_ids, @@ -448,23 +390,39 @@ def forward( input_image_sizes=input_image_sizes, ) if condition_tokens is not None: - input_emb = torch.cat([condition_tokens, time_token, hidden_states], dim=1) + inputs_embeds = torch.cat([condition_tokens, time_token, hidden_states], dim=1) else: - input_emb = torch.cat([time_token, hidden_states], dim=1) - output = self.llm( - inputs_embeds=input_emb, - attention_mask=attention_mask, - position_ids=position_ids, - past_key_values=past_key_values, - offload_transformer_block=offload_transformer_block, - ) - output, past_key_values = output.last_hidden_state, output.past_key_values + inputs_embeds = torch.cat([time_token, hidden_states], dim=1) + + + batch_size, seq_length = inputs_embeds.shape[:2] + position_ids = position_ids.view(-1, seq_length).long() - image_embedding = output[:, -num_tokens_for_output_image:] - time_emb = self.t_embedder(timestep, dtype=hidden_states.dtype) + if attention_mask is not None and attention_mask.dim() == 3: + dtype = inputs_embeds.dtype + min_dtype = torch.finfo(dtype).min + attention_mask = (1 - attention_mask) * min_dtype + attention_mask = attention_mask.unsqueeze(1).to(inputs_embeds.dtype) + else: + raise Exception("attention_mask parameter was unavailable or invalid") + + hidden_states = inputs_embeds + + cos, sin = self.rotary_emb(hidden_states, position_ids) + for decoder_layer in self.layers: + hidden_states = decoder_layer( + hidden_states, + attention_mask=attention_mask, + rotary_emb=[cos, sin] + ) + + hidden_states = self.norm(hidden_states) + + image_embedding = hidden_states[:, -num_tokens_for_output_image:] + time_emb = self.t_embedder(self.time_proj(timestep).to(hidden_states.dtype)) x = self.proj_out(self.norm_out(image_embedding, temb=time_emb)) output = self.unpatchify(x, height, width) if not return_dict: - return (output, past_key_values) - return OmniGen2DModelOutput(sample=output, past_key_values=past_key_values) + return (output, ) + return Transformer2DModelOutput(sample=output) diff --git a/src/diffusers/pipelines/omnigen/kvcache_omnigen.py b/src/diffusers/pipelines/omnigen/kvcache_omnigen.py deleted file mode 100644 index ef2ca19e4455..000000000000 --- a/src/diffusers/pipelines/omnigen/kvcache_omnigen.py +++ /dev/null @@ -1,122 +0,0 @@ -# Copyright 2024 OmniGen team and The HuggingFace Team. All rights reserved. 
-# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -from typing import Any, Dict, List, Optional, Tuple - -import torch -from transformers.cache_utils import DynamicCache - - -class OmniGenCache(DynamicCache): - def __init__(self, num_tokens_for_img: int, offload_kv_cache: bool = False) -> None: - if not torch.cuda.is_available(): - # print("No avaliable GPU, offload_kv_cache wiil be set to False, which will result in large memory usage and time cost when input multiple images!!!") - # offload_kv_cache = False - raise RuntimeError( - "OffloadedCache can only be used with a GPU. If there is no GPU, you need to set use_kv_cache=False, which will result in longer inference time!" - ) - super().__init__() - self.original_device = [] - self.prefetch_stream = torch.cuda.Stream() - self.num_tokens_for_img = num_tokens_for_img - self.offload_kv_cache = offload_kv_cache - - def prefetch_layer(self, layer_idx: int): - "Starts prefetching the next layer cache" - if layer_idx < len(self): - with torch.cuda.stream(self.prefetch_stream): - # Prefetch next layer tensors to GPU - device = self.original_device[layer_idx] - self.key_cache[layer_idx] = self.key_cache[layer_idx].to(device, non_blocking=True) - self.value_cache[layer_idx] = self.value_cache[layer_idx].to(device, non_blocking=True) - - def evict_previous_layer(self, layer_idx: int): - "Moves the previous layer cache to the CPU" - if len(self) > 2: - # We do it on the default stream so it occurs after all earlier computations on these tensors are done - if layer_idx == 0: - prev_layer_idx = -1 - else: - prev_layer_idx = (layer_idx - 1) % len(self) - self.key_cache[prev_layer_idx] = self.key_cache[prev_layer_idx].to("cpu", non_blocking=True) - self.value_cache[prev_layer_idx] = self.value_cache[prev_layer_idx].to("cpu", non_blocking=True) - - def __getitem__(self, layer_idx: int) -> List[Tuple[torch.Tensor]]: - "Gets the cache for this layer to the device. Prefetches the next and evicts the previous layer." - if layer_idx < len(self): - if self.offload_kv_cache: - # Evict the previous layer if necessary - torch.cuda.current_stream().synchronize() - self.evict_previous_layer(layer_idx) - # Load current layer cache to its original device if not already there - # self.prefetch_stream.synchronize(original_device) - torch.cuda.synchronize(self.prefetch_stream) - key_tensor = self.key_cache[layer_idx] - value_tensor = self.value_cache[layer_idx] - - # Prefetch the next layer - self.prefetch_layer((layer_idx + 1) % len(self)) - else: - key_tensor = self.key_cache[layer_idx] - value_tensor = self.value_cache[layer_idx] - return (key_tensor, value_tensor) - else: - raise KeyError(f"Cache only has {len(self)} layers, attempted to access layer with index {layer_idx}") - - def update( - self, - key_states: torch.Tensor, - value_states: torch.Tensor, - layer_idx: int, - cache_kwargs: Optional[Dict[str, Any]] = None, - ) -> Tuple[torch.Tensor, torch.Tensor]: - """ - Updates the cache with the new `key_states` and `value_states` for the layer `layer_idx`. 
- - Parameters: - key_states (`torch.Tensor`): - The new key states to cache. - value_states (`torch.Tensor`): - The new value states to cache. - layer_idx (`int`): - The index of the layer to cache the states for. - cache_kwargs (`Dict[str, Any]`, `optional`): - Additional arguments for the cache subclass. No additional arguments are used in `OffloadedCache`. - Return: - A tuple containing the updated key and value states. - """ - # Update the cache - if len(self.key_cache) < layer_idx: - raise ValueError("OffloadedCache does not support model usage where layers are skipped. Use DynamicCache.") - elif len(self.key_cache) == layer_idx: - # only cache the states for condition tokens - key_states = key_states[..., : -(self.num_tokens_for_img + 1), :] - value_states = value_states[..., : -(self.num_tokens_for_img + 1), :] - - # Update the number of seen tokens - if layer_idx == 0: - self._seen_tokens += key_states.shape[-2] - - self.key_cache.append(key_states) - self.value_cache.append(value_states) - self.original_device.append(key_states.device) - if self.offload_kv_cache: - self.evict_previous_layer(layer_idx) - return self.key_cache[layer_idx], self.value_cache[layer_idx] - else: - # only cache the states for condition tokens - key_tensor, value_tensor = self[layer_idx] - k = torch.cat([key_tensor, key_states], dim=-2) - v = torch.cat([value_tensor, value_states], dim=-2) - return k, v diff --git a/src/diffusers/pipelines/omnigen/pipeline_omnigen.py b/src/diffusers/pipelines/omnigen/pipeline_omnigen.py index f7f8827b2b7c..1e4ef5304f87 100644 --- a/src/diffusers/pipelines/omnigen/pipeline_omnigen.py +++ b/src/diffusers/pipelines/omnigen/pipeline_omnigen.py @@ -32,7 +32,6 @@ ) from ...utils.torch_utils import randn_tensor from ..pipeline_utils import DiffusionPipeline, ImagePipelineOutput -from .kvcache_omnigen import OmniGenCache from .processor_omnigen import OmniGenMultiModalProcessor @@ -203,8 +202,6 @@ def check_inputs( input_images, height, width, - use_kv_cache, - offload_kv_cache, callback_on_step_end_tensor_inputs=None, max_sequence_length=None, ): @@ -225,12 +222,6 @@ def check_inputs( f"`height` and `width` have to be divisible by {self.vae_scale_factor * 2} but are {height} and {width}. Dimensions will be resized accordingly" ) - if use_kv_cache and offload_kv_cache: - if not torch.cuda.is_available(): - raise ValueError( - "Don't fine avaliable GPUs. `offload_kv_cache` can't be used when there is no GPU. 
please set it to False: `use_kv_cache=False, offload_kv_cache=False`" - ) - if callback_on_step_end_tensor_inputs is not None and not all( k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs ): @@ -324,17 +315,6 @@ def num_timesteps(self): def interrupt(self): return self._interrupt - def enable_transformer_block_cpu_offload(self, device: Union[torch.device, str] = "cuda"): - torch_device = torch.device(device) - for name, param in self.transformer.named_parameters(): - if "layers" in name and "layers.0" not in name: - param.data = param.data.cpu() - else: - param.data = param.data.to(torch_device) - for buffer_name, buffer in self.transformer.patch_embedding.named_buffers(): - setattr(self.transformer.patch_embedding, buffer_name, buffer.to(torch_device)) - self.vae.to(torch_device) - self.offload_transformer_block = True @torch.no_grad() @replace_example_docstring(EXAMPLE_DOC_STRING) @@ -351,9 +331,6 @@ def __call__( timesteps: List[int] = None, guidance_scale: float = 2.5, img_guidance_scale: float = 1.6, - use_kv_cache: bool = True, - offload_kv_cache: bool = True, - offload_transformer_block: bool = False, use_input_image_size_as_output: bool = False, num_images_per_prompt: Optional[int] = 1, generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, @@ -395,12 +372,6 @@ def __call__( usually at the expense of lower image quality. img_guidance_scale (`float`, *optional*, defaults to 1.6): Defined as equation 3 in [Instrucpix2pix](https://arxiv.org/pdf/2211.09800). - use_kv_cache (`bool`, *optional*, defaults to True): - enable kv cache to speed up the inference - offload_kv_cache (`bool`, *optional*, defaults to True): - offload the cached key and value to cpu, which can save memory but slow down the generation silightly - offload_transformer_block (`bool`, *optional*, defaults to False): - offload the transformer layers to cpu, which can save memory but slow down the generation use_input_image_size_as_output (bool, defaults to False): whether to use the input image size as the output image size, which can be used for single-image input, e.g., image editing task @@ -451,17 +422,12 @@ def __call__( # using Float32 for the VAE doesn't take up much memory but can prevent potential black image outputs. self.vae.to(torch.float32) - if offload_transformer_block: - self.enable_transformer_block_cpu_offload() - # 1. Check inputs. Raise error if not correct self.check_inputs( prompt, input_images, height, width, - use_kv_cache, - offload_kv_cache, callback_on_step_end_tensor_inputs=callback_on_step_end_tensor_inputs, max_sequence_length=max_sequence_length, ) @@ -513,10 +479,6 @@ def __call__( latents, ) - # 7. Prepare OmniGenCache - num_tokens_for_output_img = latents.size(-1) * latents.size(-2) // (self.transformer.patch_size**2) - cache = OmniGenCache(num_tokens_for_output_img, offload_kv_cache) if use_kv_cache else None - self.transformer.llm.config.use_cache = use_kv_cache # 8. 
Denoising loop with self.progress_bar(total=num_inference_steps) as progress_bar: @@ -527,7 +489,7 @@ def __call__( # broadcast to batch dimension in a way that's compatible with ONNX/Core ML timestep = t.expand(latent_model_input.shape[0]) - noise_pred, cache = self.transformer( + noise_pred = self.transformer( hidden_states=latent_model_input, timestep=timestep, input_ids=processed_data["input_ids"], @@ -536,23 +498,8 @@ def __call__( attention_mask=processed_data["attention_mask"], position_ids=processed_data["position_ids"], attention_kwargs=attention_kwargs, - past_key_values=cache, - offload_transformer_block=self.offload_transformer_block - if hasattr(self, "offload_transformer_block") - else offload_transformer_block, return_dict=False, - ) - - # if use kv cache, don't need attention mask and position ids of condition tokens for next step - if use_kv_cache: - if processed_data["input_ids"] is not None: - processed_data["input_ids"] = None - processed_data["attention_mask"] = processed_data["attention_mask"][ - ..., -(num_tokens_for_output_img + 1) :, : - ] # +1 is for the timestep token - processed_data["position_ids"] = processed_data["position_ids"][ - :, -(num_tokens_for_output_img + 1) : - ] + )[0] if num_cfg == 2: cond, uncond, img_cond = torch.split(noise_pred, len(noise_pred) // 3, dim=0) diff --git a/tests/pipelines/omnigen/test_pipeline_omnigen.py b/tests/pipelines/omnigen/test_pipeline_omnigen.py index 3edaf9cf3110..a542a8432aeb 100644 --- a/tests/pipelines/omnigen/test_pipeline_omnigen.py +++ b/tests/pipelines/omnigen/test_pipeline_omnigen.py @@ -33,141 +33,114 @@ class OmniGenPipelineFastTests(unittest.TestCase, PipelineTesterMixin): def get_dummy_components(self): torch.manual_seed(0) - transformer_config = { - "_name_or_path": "Phi-3-vision-128k-instruct", - "architectures": ["Phi3ForCausalLM"], - "attention_dropout": 0.0, - "bos_token_id": 1, - "eos_token_id": 2, - "hidden_act": "silu", - "hidden_size": 3072, - "initializer_range": 0.02, - "intermediate_size": 8192, - "max_position_embeddings": 131072, - "model_type": "phi3", - "num_attention_heads": 32, - "num_hidden_layers": 32, - "num_key_value_heads": 32, - "original_max_position_embeddings": 4096, - "rms_norm_eps": 1e-05, - "rope_scaling": { - "long_factor": [ - 1.0299999713897705, - 1.0499999523162842, - 1.0499999523162842, - 1.0799999237060547, - 1.2299998998641968, - 1.2299998998641968, - 1.2999999523162842, - 1.4499999284744263, - 1.5999999046325684, - 1.6499998569488525, - 1.8999998569488525, - 2.859999895095825, - 3.68999981880188, - 5.419999599456787, - 5.489999771118164, - 5.489999771118164, - 9.09000015258789, - 11.579999923706055, - 15.65999984741211, - 15.769999504089355, - 15.789999961853027, - 18.360000610351562, - 21.989999771118164, - 23.079999923706055, - 30.009998321533203, - 32.35000228881836, - 32.590003967285156, - 35.56000518798828, - 39.95000457763672, - 53.840003967285156, - 56.20000457763672, - 57.95000457763672, - 59.29000473022461, - 59.77000427246094, - 59.920005798339844, - 61.190006256103516, - 61.96000671386719, - 62.50000762939453, - 63.3700065612793, - 63.48000717163086, - 63.48000717163086, - 63.66000747680664, - 63.850006103515625, - 64.08000946044922, - 64.760009765625, - 64.80001068115234, - 64.81001281738281, - 64.81001281738281, - ], - "short_factor": [ - 1.05, - 1.05, - 1.05, - 1.1, - 1.1, - 1.1, - 1.2500000000000002, - 1.2500000000000002, - 1.4000000000000004, - 1.4500000000000004, - 1.5500000000000005, - 1.8500000000000008, - 1.9000000000000008, - 2.000000000000001, - 
2.000000000000001, - 2.000000000000001, - 2.000000000000001, - 2.000000000000001, - 2.000000000000001, - 2.000000000000001, - 2.000000000000001, - 2.000000000000001, - 2.000000000000001, - 2.000000000000001, - 2.000000000000001, - 2.000000000000001, - 2.000000000000001, - 2.000000000000001, - 2.000000000000001, - 2.000000000000001, - 2.000000000000001, - 2.000000000000001, - 2.1000000000000005, - 2.1000000000000005, - 2.2, - 2.3499999999999996, - 2.3499999999999996, - 2.3499999999999996, - 2.3499999999999996, - 2.3999999999999995, - 2.3999999999999995, - 2.6499999999999986, - 2.6999999999999984, - 2.8999999999999977, - 2.9499999999999975, - 3.049999999999997, - 3.049999999999997, - 3.049999999999997, - ], - "type": "su", - }, - "rope_theta": 10000.0, - "sliding_window": 131072, - "tie_word_embeddings": False, - "torch_dtype": "bfloat16", - "transformers_version": "4.38.1", - "use_cache": True, - "vocab_size": 32064, - "_attn_implementation": "sdpa", - } transformer = OmniGenTransformer2DModel( - transformer_config=transformer_config, - patch_size=2, - in_channels=4, - pos_embed_max_size=192, - ) + rope_scaling = { + "long_factor": [ + 1.0299999713897705, + 1.0499999523162842, + 1.0499999523162842, + 1.0799999237060547, + 1.2299998998641968, + 1.2299998998641968, + 1.2999999523162842, + 1.4499999284744263, + 1.5999999046325684, + 1.6499998569488525, + 1.8999998569488525, + 2.859999895095825, + 3.68999981880188, + 5.419999599456787, + 5.489999771118164, + 5.489999771118164, + 9.09000015258789, + 11.579999923706055, + 15.65999984741211, + 15.769999504089355, + 15.789999961853027, + 18.360000610351562, + 21.989999771118164, + 23.079999923706055, + 30.009998321533203, + 32.35000228881836, + 32.590003967285156, + 35.56000518798828, + 39.95000457763672, + 53.840003967285156, + 56.20000457763672, + 57.95000457763672, + 59.29000473022461, + 59.77000427246094, + 59.920005798339844, + 61.190006256103516, + 61.96000671386719, + 62.50000762939453, + 63.3700065612793, + 63.48000717163086, + 63.48000717163086, + 63.66000747680664, + 63.850006103515625, + 64.08000946044922, + 64.760009765625, + 64.80001068115234, + 64.81001281738281, + 64.81001281738281, + ], + "short_factor": [ + 1.05, + 1.05, + 1.05, + 1.1, + 1.1, + 1.1, + 1.2500000000000002, + 1.2500000000000002, + 1.4000000000000004, + 1.4500000000000004, + 1.5500000000000005, + 1.8500000000000008, + 1.9000000000000008, + 2.000000000000001, + 2.000000000000001, + 2.000000000000001, + 2.000000000000001, + 2.000000000000001, + 2.000000000000001, + 2.000000000000001, + 2.000000000000001, + 2.000000000000001, + 2.000000000000001, + 2.000000000000001, + 2.000000000000001, + 2.000000000000001, + 2.000000000000001, + 2.000000000000001, + 2.000000000000001, + 2.000000000000001, + 2.000000000000001, + 2.000000000000001, + 2.1000000000000005, + 2.1000000000000005, + 2.2, + 2.3499999999999996, + 2.3499999999999996, + 2.3499999999999996, + 2.3499999999999996, + 2.3999999999999995, + 2.3999999999999995, + 2.6499999999999986, + 2.6999999999999984, + 2.8999999999999977, + 2.9499999999999975, + 3.049999999999997, + 3.049999999999997, + 3.049999999999997, + ], + "type": "su", + }, + patch_size=2, + in_channels=4, + pos_embed_max_size=192, + ) torch.manual_seed(0) vae = AutoencoderKL( @@ -181,7 +154,7 @@ def get_dummy_components(self): up_block_types=["UpDecoderBlock2D", "UpDecoderBlock2D", "UpDecoderBlock2D", "UpDecoderBlock2D"], ) - scheduler = FlowMatchEulerDiscreteScheduler() + scheduler = FlowMatchEulerDiscreteScheduler(invert_sigmas=True, num_train_timesteps=1) 
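For reference, the keyword-argument interface exercised by this test can also be used to build a deliberately tiny model. The configuration below is hypothetical (all sizes made up) and assumes this PR's module is importable from the file path shown in the diff; the `short_factor`/`long_factor` lists need half as many entries as the attention head dimension, given how they scale the inverse frequencies above.

```py
from diffusers.models.transformers.transformer_omnigen import OmniGenTransformer2DModel

tiny_transformer = OmniGenTransformer2DModel(
    hidden_size=64,              # head_dim = 64 / 4 = 16, so 8 rope factors per list
    num_attention_heads=4,
    num_key_value_heads=4,
    intermediate_size=128,
    num_layers=2,
    rope_scaling={"short_factor": [1.0] * 8, "long_factor": [1.0] * 8, "type": "su"},
    patch_size=2,
    in_channels=4,
    pos_embed_max_size=192,
)
print(sum(p.numel() for p in tiny_transformer.parameters()))
```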
tokenizer = AutoTokenizer.from_pretrained("hf-internal-testing/llama-tokenizer") # tokenizer = AutoTokenizer.from_pretrained("Shitao/OmniGen-v1") @@ -207,8 +180,6 @@ def get_dummy_inputs(self, device, seed=0): "output_type": "np", "height": 16, "width": 16, - "use_kv_cache": False, - "offload_kv_cache": False, } return inputs @@ -225,7 +196,7 @@ def test_inference(self): @require_torch_gpu class OmniGenPipelineSlowTests(unittest.TestCase): pipeline_class = OmniGenPipeline - repo_id = "Shitao/OmniGen-v1-diffusers" + repo_id = "shitao/OmniGen-v1-diffusers" def setUp(self): super().setUp() @@ -259,21 +230,21 @@ def test_omnigen_inference(self): image = pipe(**inputs).images[0] image_slice = image[0, :10, :10] - + print(image_slice) expected_slice = np.array( [ - [0.25806782, 0.28012177, 0.27807158], - [0.25740036, 0.2677201, 0.26857468], - [0.258638, 0.27035138, 0.26633185], - [0.2541029, 0.2636156, 0.26373306], - [0.24975497, 0.2608987, 0.2617477], - [0.25102, 0.26192215, 0.262023], - [0.24452701, 0.25664824, 0.259144], - [0.2419573, 0.2574909, 0.25996095], - [0.23953134, 0.25292695, 0.25652167], - [0.23523712, 0.24710432, 0.25460982], + [0.1783447, 0.16772744, 0.14339337], + [0.17066911, 0.15521264, 0.13757327], + [0.17072496, 0.15531206, 0.13524258], + [0.16746324, 0.1564025, 0.13794944], + [0.16490817, 0.15258026, 0.13697758], + [0.16971767, 0.15826806, 0.13928896], + [0.16782972, 0.15547255, 0.13783783], + [0.16464645, 0.15281534, 0.13522372], + [0.16535294, 0.15301755, 0.13526791], + [0.16365296, 0.15092957, 0.13443318], ], - dtype=np.float32, + dtype = np.float32, ) max_diff = numpy_cosine_similarity_distance(expected_slice.flatten(), image_slice.flatten()) From a0cd392801bf08450d9fb87ad2ac9e38165998be Mon Sep 17 00:00:00 2001 From: shitao <2906698981@qq.com> Date: Thu, 19 Dec 2024 12:37:24 +0800 Subject: [PATCH 26/55] make style --- scripts/convert_omnigen_to_diffusers.py | 6 +- src/diffusers/models/attention_processor.py | 6 +- src/diffusers/models/embeddings.py | 30 ++- .../transformers/transformer_omnigen.py | 40 ++- .../pipelines/omnigen/pipeline_omnigen.py | 2 - .../omnigen/test_pipeline_omnigen.py | 236 +++++++++--------- 6 files changed, 155 insertions(+), 165 deletions(-) diff --git a/scripts/convert_omnigen_to_diffusers.py b/scripts/convert_omnigen_to_diffusers.py index 46593530f8ff..96bc935633f0 100644 --- a/scripts/convert_omnigen_to_diffusers.py +++ b/scripts/convert_omnigen_to_diffusers.py @@ -1,6 +1,5 @@ import argparse import os -os.environ['HF_HUB_CACHE'] = "/share/shitao/downloaded_models2" import torch from huggingface_hub import snapshot_download @@ -45,7 +44,6 @@ def main(args): "t_embedder.mlp.2.weight": "t_embedder.linear_2.weight", "t_embedder.mlp.2.bias": "t_embedder.linear_2.bias", "llm.embed_tokens.weight": "embed_tokens.weight", - } converted_state_dict = {} @@ -63,7 +61,7 @@ def main(args): converted_state_dict[k[4:]] = v transformer = OmniGenTransformer2DModel( - rope_scaling = { + rope_scaling={ "long_factor": [ 1.0299999713897705, 1.0499999523162842, @@ -198,7 +196,7 @@ def main(args): ) parser.add_argument( - "--dump_path", default="/share/shitao/repos/OmniGen-v1-diffusers2", type=str, required=False, help="Path to the output pipeline." + "--dump_path", default="OmniGen-v1-diffusers", type=str, required=False, help="Path to the output pipeline." 
) args = parser.parse_args() diff --git a/src/diffusers/models/attention_processor.py b/src/diffusers/models/attention_processor.py index 73c068a177d1..2baaa5fcc15f 100644 --- a/src/diffusers/models/attention_processor.py +++ b/src/diffusers/models/attention_processor.py @@ -3414,7 +3414,6 @@ def __call__( return hidden_states - class OmniGenAttnProcessor2_0: r""" Processor for implementing scaled dot-product attention (enabled by default if you're using PyTorch 2.0). This is @@ -3469,7 +3468,6 @@ def __call__( query, key = query.to(dtype), key.to(dtype) - # perform Grouped-qurey Attention (GQA) n_rep = attn.heads // kv_heads if n_rep > 1: @@ -3481,9 +3479,7 @@ def __call__( value = value.transpose(1, 2) # the output of sdp = (batch, num_heads, seq_len, head_dim) - hidden_states = F.scaled_dot_product_attention( - query, key, value, attn_mask=attention_mask - ) + hidden_states = F.scaled_dot_product_attention(query, key, value, attn_mask=attention_mask) hidden_states = hidden_states.transpose(1, 2).to(dtype) hidden_states = hidden_states.reshape(bsz, q_len, attn.out_dim) hidden_states = attn.to_out[0](hidden_states) diff --git a/src/diffusers/models/embeddings.py b/src/diffusers/models/embeddings.py index d5fff5c055ae..60a183212be7 100644 --- a/src/diffusers/models/embeddings.py +++ b/src/diffusers/models/embeddings.py @@ -1013,8 +1013,18 @@ def apply_rotary_emb( sin = sin[None, None] elif len(cos.shape) == 3: # Used for OmniGen - cos = cos[:, :, None, :,] - sin = sin[:, :, None, :,] + cos = cos[ + :, + :, + None, + :, + ] + sin = sin[ + :, + :, + None, + :, + ] cos, sin = cos.to(x.device), sin.to(x.device) if revert_x_as_rotated: @@ -1034,7 +1044,7 @@ def apply_rotary_emb( x_rotated = torch.cat([-x_imag, x_real], dim=-1) else: raise ValueError(f"`use_real_unbind_dim={use_real_unbind_dim}` but should be -1 or -2.") - + out = (x.float() * cos + x_rotated.float() * sin).to(x.dtype) return out @@ -1047,8 +1057,6 @@ def apply_rotary_emb( return x_out.type_as(x) - - def apply_rotary_emb_allegro(x: torch.Tensor, freqs_cis, positions): # TODO(aryan): rewrite def apply_1d_rope(tokens, pos, cos, sin): @@ -1098,12 +1106,9 @@ def forward(self, ids: torch.Tensor) -> torch.Tensor: class OmniGenSuScaledRotaryEmbedding(nn.Module): - def __init__(self, - dim, - max_position_embeddings=131072, - original_max_position_embeddings=4096, - base=10000, - rope_scaling=None): + def __init__( + self, dim, max_position_embeddings=131072, original_max_position_embeddings=4096, base=10000, rope_scaling=None + ): super().__init__() self.dim = dim @@ -1149,6 +1154,7 @@ def forward(self, x, position_ids): sin = emb.sin() * scaling_factor return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype) + class TimestepEmbedding(nn.Module): def __init__( self, @@ -1216,8 +1222,6 @@ def forward(self, timesteps): return t_emb - - class GaussianFourierProjection(nn.Module): """Gaussian Fourier embeddings for noise levels.""" diff --git a/src/diffusers/models/transformers/transformer_omnigen.py b/src/diffusers/models/transformers/transformer_omnigen.py index 4b1fca65a9cd..aeb9bb656eaf 100644 --- a/src/diffusers/models/transformers/transformer_omnigen.py +++ b/src/diffusers/models/transformers/transformer_omnigen.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. 
-from typing import Any, Dict, List, Optional, Tuple, Union +from typing import Any, Dict, List, Optional, Union import torch import torch.utils.checkpoint @@ -22,8 +22,8 @@ from ...loaders import PeftAdapterMixin from ...utils import USE_PEFT_BACKEND, logging, scale_lora_layers from ..attention import OmniGenFeedForward -from ..attention_processor import Attention, OmniGenAttnProcessor2_0 -from ..embeddings import OmniGenPatchEmbed, TimestepEmbedding, Timesteps, OmniGenSuScaledRotaryEmbedding +from ..attention_processor import Attention, AttentionProcessor, OmniGenAttnProcessor2_0 +from ..embeddings import OmniGenPatchEmbed, OmniGenSuScaledRotaryEmbedding, TimestepEmbedding, Timesteps from ..modeling_outputs import Transformer2DModelOutput from ..modeling_utils import ModelMixin from ..normalization import AdaLayerNorm, RMSNorm @@ -70,7 +70,6 @@ def __init__( self.post_attention_layernorm = RMSNorm(hidden_size, eps=rms_norm_eps) self.mlp = OmniGenFeedForward(hidden_size, intermediate_size) - def forward( self, hidden_states: torch.Tensor, @@ -108,7 +107,6 @@ def forward( return hidden_states - class OmniGenTransformer2DModel(ModelMixin, ConfigMixin, PeftAdapterMixin): """ The Transformer model introduced in OmniGen. @@ -161,8 +159,7 @@ def __init__( time_step_dim: int = 256, flip_sin_to_cos: bool = True, downscale_freq_shift: int = 0, - timestep_activation_fn: str = 'silu', - + timestep_activation_fn: str = "silu", ): super().__init__() self.in_channels = in_channels @@ -185,11 +182,13 @@ def __init__( self.proj_out = nn.Linear(hidden_size, patch_size * patch_size * self.out_channels, bias=True) self.embed_tokens = nn.Embedding(vocab_size, hidden_size, pad_token_id) - self.rotary_emb = OmniGenSuScaledRotaryEmbedding(hidden_size // num_attention_heads, - max_position_embeddings=max_position_embeddings, - original_max_position_embeddings=original_max_position_embeddings, - base=rope_base, - rope_scaling=rope_scaling) + self.rotary_emb = OmniGenSuScaledRotaryEmbedding( + hidden_size // num_attention_heads, + max_position_embeddings=max_position_embeddings, + original_max_position_embeddings=original_max_position_embeddings, + base=rope_base, + rope_scaling=rope_scaling, + ) self.layers = nn.ModuleList( [ @@ -220,7 +219,7 @@ def unpatchify(self, x, h, w): @property # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.attn_processors - def attn_processors(self) -> Dict[str, OmniGenAttnProcessor2_0]: + def attn_processors(self) -> Dict[str, AttentionProcessor]: r""" Returns: `dict` of attention processors: A dictionary containing all attention processors used in the model with @@ -244,7 +243,7 @@ def fn_recursive_add_processors(name: str, module: torch.nn.Module, processors: return processors # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.set_attn_processor - def set_attn_processor(self, processor: Union[OmniGenAttnProcessor2_0, Dict[str, OmniGenAttnProcessor2_0]]): + def set_attn_processor(self, processor: Union[OmniGenAttnProcessor2_0, Dict[str, AttentionProcessor]]): r""" Sets the attention processor to use to compute attention. 
@@ -381,7 +380,7 @@ def forward( height, width = hidden_states.size()[-2:] hidden_states = self.patch_embedding(hidden_states, is_input_image=False) num_tokens_for_output_image = hidden_states.size(1) - + time_token = self.time_token(self.time_proj(timestep).to(hidden_states.dtype)).unsqueeze(1) condition_tokens = self.get_multimodal_embeddings( @@ -393,7 +392,6 @@ def forward( inputs_embeds = torch.cat([condition_tokens, time_token, hidden_states], dim=1) else: inputs_embeds = torch.cat([time_token, hidden_states], dim=1) - batch_size, seq_length = inputs_embeds.shape[:2] position_ids = position_ids.view(-1, seq_length).long() @@ -410,19 +408,15 @@ def forward( cos, sin = self.rotary_emb(hidden_states, position_ids) for decoder_layer in self.layers: - hidden_states = decoder_layer( - hidden_states, - attention_mask=attention_mask, - rotary_emb=[cos, sin] - ) + hidden_states = decoder_layer(hidden_states, attention_mask=attention_mask, rotary_emb=[cos, sin]) hidden_states = self.norm(hidden_states) - + image_embedding = hidden_states[:, -num_tokens_for_output_image:] time_emb = self.t_embedder(self.time_proj(timestep).to(hidden_states.dtype)) x = self.proj_out(self.norm_out(image_embedding, temb=time_emb)) output = self.unpatchify(x, height, width) if not return_dict: - return (output, ) + return (output,) return Transformer2DModelOutput(sample=output) diff --git a/src/diffusers/pipelines/omnigen/pipeline_omnigen.py b/src/diffusers/pipelines/omnigen/pipeline_omnigen.py index 1e4ef5304f87..524949c25e0d 100644 --- a/src/diffusers/pipelines/omnigen/pipeline_omnigen.py +++ b/src/diffusers/pipelines/omnigen/pipeline_omnigen.py @@ -315,7 +315,6 @@ def num_timesteps(self): def interrupt(self): return self._interrupt - @torch.no_grad() @replace_example_docstring(EXAMPLE_DOC_STRING) def __call__( @@ -479,7 +478,6 @@ def __call__( latents, ) - # 8. 
Denoising loop with self.progress_bar(total=num_inference_steps) as progress_bar: for i, t in enumerate(timesteps): diff --git a/tests/pipelines/omnigen/test_pipeline_omnigen.py b/tests/pipelines/omnigen/test_pipeline_omnigen.py index a542a8432aeb..b7a89eb54068 100644 --- a/tests/pipelines/omnigen/test_pipeline_omnigen.py +++ b/tests/pipelines/omnigen/test_pipeline_omnigen.py @@ -34,113 +34,113 @@ def get_dummy_components(self): torch.manual_seed(0) transformer = OmniGenTransformer2DModel( - rope_scaling = { - "long_factor": [ - 1.0299999713897705, - 1.0499999523162842, - 1.0499999523162842, - 1.0799999237060547, - 1.2299998998641968, - 1.2299998998641968, - 1.2999999523162842, - 1.4499999284744263, - 1.5999999046325684, - 1.6499998569488525, - 1.8999998569488525, - 2.859999895095825, - 3.68999981880188, - 5.419999599456787, - 5.489999771118164, - 5.489999771118164, - 9.09000015258789, - 11.579999923706055, - 15.65999984741211, - 15.769999504089355, - 15.789999961853027, - 18.360000610351562, - 21.989999771118164, - 23.079999923706055, - 30.009998321533203, - 32.35000228881836, - 32.590003967285156, - 35.56000518798828, - 39.95000457763672, - 53.840003967285156, - 56.20000457763672, - 57.95000457763672, - 59.29000473022461, - 59.77000427246094, - 59.920005798339844, - 61.190006256103516, - 61.96000671386719, - 62.50000762939453, - 63.3700065612793, - 63.48000717163086, - 63.48000717163086, - 63.66000747680664, - 63.850006103515625, - 64.08000946044922, - 64.760009765625, - 64.80001068115234, - 64.81001281738281, - 64.81001281738281, - ], - "short_factor": [ - 1.05, - 1.05, - 1.05, - 1.1, - 1.1, - 1.1, - 1.2500000000000002, - 1.2500000000000002, - 1.4000000000000004, - 1.4500000000000004, - 1.5500000000000005, - 1.8500000000000008, - 1.9000000000000008, - 2.000000000000001, - 2.000000000000001, - 2.000000000000001, - 2.000000000000001, - 2.000000000000001, - 2.000000000000001, - 2.000000000000001, - 2.000000000000001, - 2.000000000000001, - 2.000000000000001, - 2.000000000000001, - 2.000000000000001, - 2.000000000000001, - 2.000000000000001, - 2.000000000000001, - 2.000000000000001, - 2.000000000000001, - 2.000000000000001, - 2.000000000000001, - 2.1000000000000005, - 2.1000000000000005, - 2.2, - 2.3499999999999996, - 2.3499999999999996, - 2.3499999999999996, - 2.3499999999999996, - 2.3999999999999995, - 2.3999999999999995, - 2.6499999999999986, - 2.6999999999999984, - 2.8999999999999977, - 2.9499999999999975, - 3.049999999999997, - 3.049999999999997, - 3.049999999999997, - ], - "type": "su", - }, - patch_size=2, - in_channels=4, - pos_embed_max_size=192, - ) + rope_scaling={ + "long_factor": [ + 1.0299999713897705, + 1.0499999523162842, + 1.0499999523162842, + 1.0799999237060547, + 1.2299998998641968, + 1.2299998998641968, + 1.2999999523162842, + 1.4499999284744263, + 1.5999999046325684, + 1.6499998569488525, + 1.8999998569488525, + 2.859999895095825, + 3.68999981880188, + 5.419999599456787, + 5.489999771118164, + 5.489999771118164, + 9.09000015258789, + 11.579999923706055, + 15.65999984741211, + 15.769999504089355, + 15.789999961853027, + 18.360000610351562, + 21.989999771118164, + 23.079999923706055, + 30.009998321533203, + 32.35000228881836, + 32.590003967285156, + 35.56000518798828, + 39.95000457763672, + 53.840003967285156, + 56.20000457763672, + 57.95000457763672, + 59.29000473022461, + 59.77000427246094, + 59.920005798339844, + 61.190006256103516, + 61.96000671386719, + 62.50000762939453, + 63.3700065612793, + 63.48000717163086, + 63.48000717163086, + 63.66000747680664, + 
63.850006103515625, + 64.08000946044922, + 64.760009765625, + 64.80001068115234, + 64.81001281738281, + 64.81001281738281, + ], + "short_factor": [ + 1.05, + 1.05, + 1.05, + 1.1, + 1.1, + 1.1, + 1.2500000000000002, + 1.2500000000000002, + 1.4000000000000004, + 1.4500000000000004, + 1.5500000000000005, + 1.8500000000000008, + 1.9000000000000008, + 2.000000000000001, + 2.000000000000001, + 2.000000000000001, + 2.000000000000001, + 2.000000000000001, + 2.000000000000001, + 2.000000000000001, + 2.000000000000001, + 2.000000000000001, + 2.000000000000001, + 2.000000000000001, + 2.000000000000001, + 2.000000000000001, + 2.000000000000001, + 2.000000000000001, + 2.000000000000001, + 2.000000000000001, + 2.000000000000001, + 2.000000000000001, + 2.1000000000000005, + 2.1000000000000005, + 2.2, + 2.3499999999999996, + 2.3499999999999996, + 2.3499999999999996, + 2.3499999999999996, + 2.3999999999999995, + 2.3999999999999995, + 2.6499999999999986, + 2.6999999999999984, + 2.8999999999999977, + 2.9499999999999975, + 3.049999999999997, + 3.049999999999997, + 3.049999999999997, + ], + "type": "su", + }, + patch_size=2, + in_channels=4, + pos_embed_max_size=192, + ) torch.manual_seed(0) vae = AutoencoderKL( @@ -233,18 +233,18 @@ def test_omnigen_inference(self): print(image_slice) expected_slice = np.array( [ - [0.1783447, 0.16772744, 0.14339337], - [0.17066911, 0.15521264, 0.13757327], - [0.17072496, 0.15531206, 0.13524258], - [0.16746324, 0.1564025, 0.13794944], - [0.16490817, 0.15258026, 0.13697758], - [0.16971767, 0.15826806, 0.13928896], - [0.16782972, 0.15547255, 0.13783783], - [0.16464645, 0.15281534, 0.13522372], - [0.16535294, 0.15301755, 0.13526791], - [0.16365296, 0.15092957, 0.13443318], + [0.1783447, 0.16772744, 0.14339337], + [0.17066911, 0.15521264, 0.13757327], + [0.17072496, 0.15531206, 0.13524258], + [0.16746324, 0.1564025, 0.13794944], + [0.16490817, 0.15258026, 0.13697758], + [0.16971767, 0.15826806, 0.13928896], + [0.16782972, 0.15547255, 0.13783783], + [0.16464645, 0.15281534, 0.13522372], + [0.16535294, 0.15301755, 0.13526791], + [0.16365296, 0.15092957, 0.13443318], ], - dtype = np.float32, + dtype=np.float32, ) max_diff = numpy_cosine_similarity_distance(expected_slice.flatten(), image_slice.flatten()) From 48fd390b0163e466fc32ed7e7fd59235f878cbfc Mon Sep 17 00:00:00 2001 From: shitao <2906698981@qq.com> Date: Thu, 19 Dec 2024 13:42:29 +0800 Subject: [PATCH 27/55] update test cases --- tests/pipelines/omnigen/test_pipeline_omnigen.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/pipelines/omnigen/test_pipeline_omnigen.py b/tests/pipelines/omnigen/test_pipeline_omnigen.py index b7a89eb54068..43842c346551 100644 --- a/tests/pipelines/omnigen/test_pipeline_omnigen.py +++ b/tests/pipelines/omnigen/test_pipeline_omnigen.py @@ -230,7 +230,7 @@ def test_omnigen_inference(self): image = pipe(**inputs).images[0] image_slice = image[0, :10, :10] - print(image_slice) + expected_slice = np.array( [ [0.1783447, 0.16772744, 0.14339337], From 78431e1c42b081129544680e3e359383772f51c4 Mon Sep 17 00:00:00 2001 From: Shitao Xiao <2906698981@qq.com> Date: Fri, 20 Dec 2024 10:36:04 +0800 Subject: [PATCH 28/55] Update docs/source/en/api/pipelines/omnigen.md Co-authored-by: Steven Liu <59462357+stevhliu@users.noreply.github.com> --- docs/source/en/api/pipelines/omnigen.md | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/docs/source/en/api/pipelines/omnigen.md b/docs/source/en/api/pipelines/omnigen.md index f6b45b1b014d..0b826f182edd 100644 --- 
a/docs/source/en/api/pipelines/omnigen.md +++ b/docs/source/en/api/pipelines/omnigen.md @@ -79,9 +79,9 @@ image = pipe( image ``` -OmniGen supports for multimodal inputs. +OmniGen supports multimodal inputs. When the input includes an image, you need to add a placeholder `<|image_1|>` in the text prompt to represent the image. -It is recommended to enable 'use_input_image_size_as_output' to keep the edited image the same size as the original image. +It is recommended to enable `use_input_image_size_as_output` to keep the edited image the same size as the original image. ```py prompt="<|image_1|> Remove the woman's earrings. Replace the mug with a clear glass filled with sparkling iced cola." From d99a9f866f58144d640061b40feb34dbf0c9283a Mon Sep 17 00:00:00 2001 From: Shitao Xiao <2906698981@qq.com> Date: Fri, 20 Dec 2024 10:36:38 +0800 Subject: [PATCH 29/55] Update docs/source/en/using-diffusers/omnigen.md Co-authored-by: Steven Liu <59462357+stevhliu@users.noreply.github.com> --- docs/source/en/using-diffusers/omnigen.md | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/docs/source/en/using-diffusers/omnigen.md b/docs/source/en/using-diffusers/omnigen.md index 45945090abbe..1928ab11517d 100644 --- a/docs/source/en/using-diffusers/omnigen.md +++ b/docs/source/en/using-diffusers/omnigen.md @@ -63,9 +63,9 @@ image ## Image edit -OmniGen supports for multimodal inputs. +OmniGen supports multimodal inputs. When the input includes an image, you need to add a placeholder `<|image_1|>` in the text prompt to represent the image. -It is recommended to enable 'use_input_image_size_as_output' to keep the edited image the same size as the original image. +It is recommended to enable `use_input_image_size_as_output` to keep the edited image the same size as the original image. ```py import torch From 61d802a947ab51a9983d1c444a2f8f870030c1e0 Mon Sep 17 00:00:00 2001 From: Shitao Xiao <2906698981@qq.com> Date: Fri, 20 Dec 2024 10:36:50 +0800 Subject: [PATCH 30/55] Update docs/source/en/using-diffusers/omnigen.md Co-authored-by: Steven Liu <59462357+stevhliu@users.noreply.github.com> --- docs/source/en/using-diffusers/omnigen.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docs/source/en/using-diffusers/omnigen.md b/docs/source/en/using-diffusers/omnigen.md index 1928ab11517d..e86404a1aea2 100644 --- a/docs/source/en/using-diffusers/omnigen.md +++ b/docs/source/en/using-diffusers/omnigen.md @@ -295,7 +295,7 @@ image ## Optimization when inputting multiple images -For text-to-image task, OmniGen requires minimal memory and time costs (9GB memory and 31s for a 1024*1024 image on A800 GPU). +For text-to-image task, OmniGen requires minimal memory and time costs (9GB memory and 31s for a 1024x1024 image on A800 GPU). However, when using input images, the computational cost increases. Here are some guidelines to help you reduce computational costs when input multiple images. The experiments are conducted on A800 GPU and input two images to OmniGen. 
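A hedged end-to-end sketch of the editing workflow these docs describe: the `<|image_1|>` placeholder refers to the first entry of `input_images`, and `use_input_image_size_as_output=True` keeps the output at the input resolution. The model id and image URL are taken from the tests and docs in this PR; the `bfloat16` dtype and CUDA device are assumptions, and running this requires downloading the checkpoint.

```py
import torch
from diffusers import OmniGenPipeline
from diffusers.utils import load_image

pipe = OmniGenPipeline.from_pretrained("Shitao/OmniGen-v1-diffusers", torch_dtype=torch.bfloat16)
pipe.to("cuda")

prompt = "<|image_1|> Remove the woman's earrings. Replace the mug with a clear glass filled with sparkling iced cola."
input_images = [load_image("https://raw.githubusercontent.com/VectorSpaceLab/OmniGen/main/imgs/docs_img/edit.png")]

image = pipe(
    prompt=prompt,
    input_images=input_images,
    guidance_scale=2.5,
    img_guidance_scale=1.6,
    use_input_image_size_as_output=True,
).images[0]
image.save("edited.png")
```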
From 8d6a35e48a297665a930822e2c8ab43cfbfb9ec5 Mon Sep 17 00:00:00 2001 From: Shitao Xiao <2906698981@qq.com> Date: Fri, 20 Dec 2024 10:37:05 +0800 Subject: [PATCH 31/55] Update docs/source/en/using-diffusers/omnigen.md Co-authored-by: Steven Liu <59462357+stevhliu@users.noreply.github.com> --- docs/source/en/using-diffusers/omnigen.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docs/source/en/using-diffusers/omnigen.md b/docs/source/en/using-diffusers/omnigen.md index e86404a1aea2..26677694ae65 100644 --- a/docs/source/en/using-diffusers/omnigen.md +++ b/docs/source/en/using-diffusers/omnigen.md @@ -298,7 +298,7 @@ image For text-to-image task, OmniGen requires minimal memory and time costs (9GB memory and 31s for a 1024x1024 image on A800 GPU). However, when using input images, the computational cost increases. -Here are some guidelines to help you reduce computational costs when input multiple images. The experiments are conducted on A800 GPU and input two images to OmniGen. +Here are some guidelines to help you reduce computational costs when inputting multiple images. The experiments are conducted on an A800 GPU with two input images. | Method | Memory Usage | From f8e645ba5b98d1b7c6b97a74653a7a27d567c113 Mon Sep 17 00:00:00 2001 From: Shitao Xiao <2906698981@qq.com> Date: Fri, 20 Dec 2024 10:37:27 +0800 Subject: [PATCH 32/55] Update docs/source/en/using-diffusers/omnigen.md Co-authored-by: Steven Liu <59462357+stevhliu@users.noreply.github.com> --- docs/source/en/using-diffusers/omnigen.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docs/source/en/using-diffusers/omnigen.md b/docs/source/en/using-diffusers/omnigen.md index 26677694ae65..c9a53c4c8e95 100644 --- a/docs/source/en/using-diffusers/omnigen.md +++ b/docs/source/en/using-diffusers/omnigen.md @@ -100,7 +100,7 @@ image
-OmniGen has some interesting features, such as the ability to infer user needs, as shown in the example below. +OmniGen has some interesting features, such as visual reasoning, as shown in the example below. ```py prompt="If the woman is thirsty, what should she take? Find it in the image and highlight it in blue. <|image_1|>" input_images=[load_image("https://raw.githubusercontent.com/VectorSpaceLab/OmniGen/main/imgs/docs_img/edit.png")] From 85cdeb94808a03717c4f3eebce552eaf104229c0 Mon Sep 17 00:00:00 2001 From: Shitao Xiao <2906698981@qq.com> Date: Fri, 20 Dec 2024 10:37:56 +0800 Subject: [PATCH 33/55] Update docs/source/en/using-diffusers/omnigen.md Co-authored-by: Steven Liu <59462357+stevhliu@users.noreply.github.com> --- docs/source/en/using-diffusers/omnigen.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docs/source/en/using-diffusers/omnigen.md b/docs/source/en/using-diffusers/omnigen.md index c9a53c4c8e95..8779ad69715f 100644 --- a/docs/source/en/using-diffusers/omnigen.md +++ b/docs/source/en/using-diffusers/omnigen.md @@ -11,7 +11,7 @@ specific language governing permissions and limitations under the License. --> # OmniGen -OmniGen is an image generation model. Unlike existing text-to-image models, OmniGen is designed to handle a variety of tasks (e.g., text-to-image, image editing, controllable generation) within a single model. It has the following features: +OmniGen is an image generation model. Unlike existing text-to-image models, OmniGen is a single model designed to handle a variety of tasks (e.g., text-to-image, image editing, controllable generation). It has the following features: - Minimalist model architecture, consisting of only a VAE and a transformer module, for joint modeling of text and images. - Support for multimodal inputs. It can process any text-image mixed data as instructions for image generation, rather than relying solely on text. From 3565837faaba7c1282741422eec99682004e2daa Mon Sep 17 00:00:00 2001 From: shitao <2906698981@qq.com> Date: Fri, 20 Dec 2024 15:56:03 +0800 Subject: [PATCH 34/55] update docs --- docs/source/en/using-diffusers/omnigen.md | 24 ++++++++--------------- 1 file changed, 8 insertions(+), 16 deletions(-) diff --git a/docs/source/en/using-diffusers/omnigen.md b/docs/source/en/using-diffusers/omnigen.md index 8779ad69715f..a3d98e4e60cc 100644 --- a/docs/source/en/using-diffusers/omnigen.md +++ b/docs/source/en/using-diffusers/omnigen.md @@ -300,23 +300,15 @@ However, when using input images, the computational cost increases. Here are some guidelines to help you reduce computational costs when inputting multiple images. The experiments are conducted on an A800 GPU with two input images. +Like other pipelines, you can reduce memory usage by offloading the model: `pipe.enable_model_cpu_offload()` or `pipe.enable_sequential_cpu_offload() `. +In OmniGen, you can also decrease computational overhead by reducing the `max_input_image_size`. 
+The memory consumption for different image sizes is shown in the table below: -| Method | Memory Usage | -|---------------------------------------------|--------------| -| pipe.to("cuda") | 31GB | -| pipe.enable_model_cpu_offload() | 28GB | -| pipe.enable_sequential_cpu_offload() | 11GB | +| Method | Memory Usage | +|---------------------------|--------------| +| max_input_image_size=1024 | 40GB | +| max_input_image_size=512 | 17GB | +| max_input_image_size=256 | 14GB | -- `pipe.enable_model_cpu_offload()`: - - Without enabling cpu offloading, memory usage is `31 GB` - - With enabling cpu offloading, memory usage is `28 GB` - -- `pipe.enable_transformer_block_cpu_offload()`: - - Offload transformer block to reduce memory usage - - When enabled, memory usage is under `25 GB` - -- `pipe.enable_sequential_cpu_offload()`: - - Significantly reduce memory usage at the cost of slow inference - - When enabled, memory usage is under `11 GB` From 0ccca15c757aa581ada212d440e05c1399ab37ce Mon Sep 17 00:00:00 2001 From: shitao <2906698981@qq.com> Date: Sun, 29 Dec 2024 15:34:59 +0800 Subject: [PATCH 35/55] typo --- src/diffusers/pipelines/omnigen/pipeline_omnigen.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/src/diffusers/pipelines/omnigen/pipeline_omnigen.py b/src/diffusers/pipelines/omnigen/pipeline_omnigen.py index 524949c25e0d..342e9b180ed1 100644 --- a/src/diffusers/pipelines/omnigen/pipeline_omnigen.py +++ b/src/diffusers/pipelines/omnigen/pipeline_omnigen.py @@ -173,14 +173,14 @@ def __init__( ) self.default_sample_size = 128 - def encod_input_iamges( + def encod_input_images( self, input_pixel_values: List[torch.Tensor], device: Optional[torch.device] = None, dtype: Optional[torch.dtype] = None, ): """ - get the continues embedding of input images by VAE + get the continue embedding of input images by VAE Args: input_pixel_values: normlized pixel of input images @@ -455,7 +455,7 @@ def __call__( processed_data["position_ids"] = processed_data["position_ids"].to(device) # 4. Encode input images - input_img_latents = self.encod_input_iamges(processed_data["input_pixel_values"], device=device) + input_img_latents = self.encod_input_images(processed_data["input_pixel_values"], device=device) # 5. 
Prepare timesteps sigmas = np.linspace(1, 0, num_inference_steps + 1)[:num_inference_steps] From d014f9578d49f0bcd1e0b772a8b8fafeb5428238 Mon Sep 17 00:00:00 2001 From: Shitao Xiao <2906698981@qq.com> Date: Sat, 8 Feb 2025 12:22:52 +0800 Subject: [PATCH 36/55] Update src/diffusers/models/embeddings.py Co-authored-by: hlky --- src/diffusers/models/embeddings.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/diffusers/models/embeddings.py b/src/diffusers/models/embeddings.py index 60a183212be7..038d9fb32cfb 100644 --- a/src/diffusers/models/embeddings.py +++ b/src/diffusers/models/embeddings.py @@ -376,9 +376,9 @@ def __init__( self.pos_embed_max_size = pos_embed_max_size pos_embed = get_2d_sincos_pos_embed( - embed_dim, self.pos_embed_max_size, base_size=base_size, interpolation_scale=self.interpolation_scale + embed_dim, self.pos_embed_max_size, base_size=base_size, interpolation_scale=self.interpolation_scale, output_type="pt" ) - self.register_buffer("pos_embed", torch.from_numpy(pos_embed).float().unsqueeze(0), persistent=True) + self.register_buffer("pos_embed", pos_embed.float().unsqueeze(0), persistent=True) def cropped_pos_embed(self, height, width): """Crops positional embeddings for SD3 compatibility.""" From 753daec841d23355f4b312566f30e4bb6711ee3d Mon Sep 17 00:00:00 2001 From: Shitao Xiao <2906698981@qq.com> Date: Sat, 8 Feb 2025 12:23:15 +0800 Subject: [PATCH 37/55] Update src/diffusers/models/attention.py Co-authored-by: hlky --- src/diffusers/models/attention.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/diffusers/models/attention.py b/src/diffusers/models/attention.py index 780f2e22b391..2e0785d1df2c 100644 --- a/src/diffusers/models/attention.py +++ b/src/diffusers/models/attention.py @@ -654,7 +654,7 @@ def __init__( self.activation_fn = nn.SiLU() - def forward(self, hidden_states: torch.FloatTensor) -> torch.FloatTensor: + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: up_states = self.gate_up_proj(hidden_states) gate, up_states = up_states.chunk(2, dim=-1) From 3d30a2ad9e84968ca8be5351ad39a54036ffe1c2 Mon Sep 17 00:00:00 2001 From: Shitao Xiao <2906698981@qq.com> Date: Sat, 8 Feb 2025 12:23:28 +0800 Subject: [PATCH 38/55] Update src/diffusers/models/transformers/transformer_omnigen.py Co-authored-by: hlky --- src/diffusers/models/transformers/transformer_omnigen.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/diffusers/models/transformers/transformer_omnigen.py b/src/diffusers/models/transformers/transformer_omnigen.py index aeb9bb656eaf..f35305d5b748 100644 --- a/src/diffusers/models/transformers/transformer_omnigen.py +++ b/src/diffusers/models/transformers/transformer_omnigen.py @@ -332,7 +332,7 @@ def forward( The [`OmniGenTransformer2DModel`] forward method. Args: - hidden_states (`torch.FloatTensor` of shape `(batch size, channel, height, width)`): + hidden_states (`torch.Tensor` of shape `(batch size, channel, height, width)`): Input `hidden_states`. timestep (`torch.LongTensor`): Used to indicate denoising step. 
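Note: the feed-forward whose type hints are updated above uses a SiLU-gated projection: a fused gate/up linear layer, an elementwise gated product, and a down projection back to the hidden size. A standalone sketch of that pattern, with illustrative shapes:

```py
import torch
from torch import nn

hidden_size, intermediate_size = 16, 32
gate_up_proj = nn.Linear(hidden_size, 2 * intermediate_size, bias=False)
down_proj = nn.Linear(intermediate_size, hidden_size, bias=False)

hidden_states = torch.randn(2, 24, hidden_size)        # (batch, seq_len, hidden_size)
up_states = gate_up_proj(hidden_states)                # (batch, seq_len, 2 * intermediate_size)
gate, up_states = up_states.chunk(2, dim=-1)           # split into gate and value halves
out = down_proj(up_states * nn.functional.silu(gate))  # (batch, seq_len, hidden_size)
print(out.shape)  # torch.Size([2, 24, 16])
```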
From 9d1580af0daba76508b9322e86dc30bebb0aeafc Mon Sep 17 00:00:00 2001 From: Shitao Xiao <2906698981@qq.com> Date: Sat, 8 Feb 2025 12:23:48 +0800 Subject: [PATCH 39/55] Update src/diffusers/models/transformers/transformer_omnigen.py Co-authored-by: hlky --- src/diffusers/models/transformers/transformer_omnigen.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/diffusers/models/transformers/transformer_omnigen.py b/src/diffusers/models/transformers/transformer_omnigen.py index f35305d5b748..b163a9933899 100644 --- a/src/diffusers/models/transformers/transformer_omnigen.py +++ b/src/diffusers/models/transformers/transformer_omnigen.py @@ -338,7 +338,7 @@ def forward( Used to indicate denoising step. input_ids (`torch.LongTensor`): token ids - input_img_latents (`torch.FloatTensor`): + input_img_latents (`torch.Tensor`): encoded image latents by VAE input_image_sizes (`dict`): the indices of the input_img_latents in the input_ids From 6a587464eb6f085da968519e53718c8aa1bac92c Mon Sep 17 00:00:00 2001 From: Shitao Xiao <2906698981@qq.com> Date: Sat, 8 Feb 2025 12:24:00 +0800 Subject: [PATCH 40/55] Update src/diffusers/models/transformers/transformer_omnigen.py Co-authored-by: hlky --- src/diffusers/models/transformers/transformer_omnigen.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/diffusers/models/transformers/transformer_omnigen.py b/src/diffusers/models/transformers/transformer_omnigen.py index b163a9933899..78bd97cb4fa5 100644 --- a/src/diffusers/models/transformers/transformer_omnigen.py +++ b/src/diffusers/models/transformers/transformer_omnigen.py @@ -342,7 +342,7 @@ def forward( encoded image latents by VAE input_image_sizes (`dict`): the indices of the input_img_latents in the input_ids - attention_mask (`torch.FloatTensor`): + attention_mask (`torch.Tensor`): mask for self-attention position_ids (`torch.LongTensor`): id to represent position From 2b464c88e9a2b152f337011d2930dafea991b573 Mon Sep 17 00:00:00 2001 From: Shitao Xiao <2906698981@qq.com> Date: Sat, 8 Feb 2025 12:25:21 +0800 Subject: [PATCH 41/55] Update src/diffusers/pipelines/omnigen/pipeline_omnigen.py Co-authored-by: hlky --- src/diffusers/pipelines/omnigen/pipeline_omnigen.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/diffusers/pipelines/omnigen/pipeline_omnigen.py b/src/diffusers/pipelines/omnigen/pipeline_omnigen.py index 342e9b180ed1..9bc9f356db6a 100644 --- a/src/diffusers/pipelines/omnigen/pipeline_omnigen.py +++ b/src/diffusers/pipelines/omnigen/pipeline_omnigen.py @@ -321,7 +321,7 @@ def __call__( self, prompt: Union[str, List[str]], input_images: Optional[ - Union[List[str], List[PIL.Image.Image], List[List[str]], List[List[PIL.Image.Image]]] + PipelineImageInput ] = None, height: Optional[int] = None, width: Optional[int] = None, From 78881196cc84449476569b1254cb471425d0a160 Mon Sep 17 00:00:00 2001 From: Shitao Xiao <2906698981@qq.com> Date: Sat, 8 Feb 2025 12:25:35 +0800 Subject: [PATCH 42/55] Update src/diffusers/pipelines/omnigen/pipeline_omnigen.py Co-authored-by: hlky --- src/diffusers/pipelines/omnigen/pipeline_omnigen.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/diffusers/pipelines/omnigen/pipeline_omnigen.py b/src/diffusers/pipelines/omnigen/pipeline_omnigen.py index 9bc9f356db6a..3c0ef1729ac0 100644 --- a/src/diffusers/pipelines/omnigen/pipeline_omnigen.py +++ b/src/diffusers/pipelines/omnigen/pipeline_omnigen.py @@ -20,7 +20,7 @@ import torch from transformers import LlamaTokenizer -from 
...image_processor import VaeImageProcessor +from ...image_processor import PipelineImageInput, VaeImageProcessor from ...loaders import FromSingleFileMixin, TextualInversionLoaderMixin from ...models.autoencoders import AutoencoderKL from ...models.transformers import OmniGenTransformer2DModel From 39148c3ac3241adeb711ec1e4de4f2ce23bae337 Mon Sep 17 00:00:00 2001 From: Shitao Xiao <2906698981@qq.com> Date: Sat, 8 Feb 2025 12:25:54 +0800 Subject: [PATCH 43/55] Update src/diffusers/pipelines/omnigen/pipeline_omnigen.py Co-authored-by: hlky --- src/diffusers/pipelines/omnigen/pipeline_omnigen.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/diffusers/pipelines/omnigen/pipeline_omnigen.py b/src/diffusers/pipelines/omnigen/pipeline_omnigen.py index 3c0ef1729ac0..a53fee292f48 100644 --- a/src/diffusers/pipelines/omnigen/pipeline_omnigen.py +++ b/src/diffusers/pipelines/omnigen/pipeline_omnigen.py @@ -333,7 +333,7 @@ def __call__( use_input_image_size_as_output: bool = False, num_images_per_prompt: Optional[int] = 1, generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, - latents: Optional[torch.FloatTensor] = None, + latents: Optional[torch.Tensor] = None, output_type: Optional[str] = "pil", return_dict: bool = True, attention_kwargs: Optional[Dict[str, Any]] = None, From 52a6f9e00cdbf853462792c1b4d3c6c3bc186b2a Mon Sep 17 00:00:00 2001 From: Shitao Xiao <2906698981@qq.com> Date: Sat, 8 Feb 2025 12:33:03 +0800 Subject: [PATCH 44/55] Update tests/pipelines/omnigen/test_pipeline_omnigen.py Co-authored-by: hlky --- tests/pipelines/omnigen/test_pipeline_omnigen.py | 1 - 1 file changed, 1 deletion(-) diff --git a/tests/pipelines/omnigen/test_pipeline_omnigen.py b/tests/pipelines/omnigen/test_pipeline_omnigen.py index 43842c346551..73283de0a0e3 100644 --- a/tests/pipelines/omnigen/test_pipeline_omnigen.py +++ b/tests/pipelines/omnigen/test_pipeline_omnigen.py @@ -156,7 +156,6 @@ def get_dummy_components(self): scheduler = FlowMatchEulerDiscreteScheduler(invert_sigmas=True, num_train_timesteps=1) tokenizer = AutoTokenizer.from_pretrained("hf-internal-testing/llama-tokenizer") - # tokenizer = AutoTokenizer.from_pretrained("Shitao/OmniGen-v1") components = { "transformer": transformer.eval(), From aeea57a0d6f156c8337622b4f81499771f28e7c7 Mon Sep 17 00:00:00 2001 From: Shitao Xiao <2906698981@qq.com> Date: Sat, 8 Feb 2025 12:33:23 +0800 Subject: [PATCH 45/55] Update tests/pipelines/omnigen/test_pipeline_omnigen.py Co-authored-by: hlky --- tests/pipelines/omnigen/test_pipeline_omnigen.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tests/pipelines/omnigen/test_pipeline_omnigen.py b/tests/pipelines/omnigen/test_pipeline_omnigen.py index 73283de0a0e3..70e9cd0a53b2 100644 --- a/tests/pipelines/omnigen/test_pipeline_omnigen.py +++ b/tests/pipelines/omnigen/test_pipeline_omnigen.py @@ -158,8 +158,8 @@ def get_dummy_components(self): tokenizer = AutoTokenizer.from_pretrained("hf-internal-testing/llama-tokenizer") components = { - "transformer": transformer.eval(), - "vae": vae.eval(), + "transformer": transformer, + "vae": vae, "scheduler": scheduler, "tokenizer": tokenizer, } From 6b1177bc15c0299a0bb26f9c38801540ff3cf959 Mon Sep 17 00:00:00 2001 From: Shitao Xiao <2906698981@qq.com> Date: Sat, 8 Feb 2025 12:33:47 +0800 Subject: [PATCH 46/55] Update src/diffusers/pipelines/omnigen/pipeline_omnigen.py Co-authored-by: hlky --- src/diffusers/pipelines/omnigen/pipeline_omnigen.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git 
a/src/diffusers/pipelines/omnigen/pipeline_omnigen.py b/src/diffusers/pipelines/omnigen/pipeline_omnigen.py index a53fee292f48..9431b77fa88c 100644 --- a/src/diffusers/pipelines/omnigen/pipeline_omnigen.py +++ b/src/diffusers/pipelines/omnigen/pipeline_omnigen.py @@ -348,7 +348,7 @@ def __call__( prompt (`str` or `List[str]`, *optional*): The prompt or prompts to guide the image generation. If the input includes images, need to add placeholders `<|image_i|>` in the prompt to indicate the position of the i-th images. - input_images (`List[str]` or `List[List[str]]`, *optional*): + input_images (`PipelineImageInput`, *optional*): The list of input images. We will replace the "<|image_i|>" in prompt with the i-th image in list. height (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor): The height in pixels of the generated image. This is set to 1024 by default for the best results. From 7003a80245a91bad55921fd8b82b786d02cd6fb1 Mon Sep 17 00:00:00 2001 From: Shitao Xiao <2906698981@qq.com> Date: Sat, 8 Feb 2025 12:34:37 +0800 Subject: [PATCH 47/55] Update src/diffusers/pipelines/omnigen/pipeline_omnigen.py Co-authored-by: hlky --- src/diffusers/pipelines/omnigen/pipeline_omnigen.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/diffusers/pipelines/omnigen/pipeline_omnigen.py b/src/diffusers/pipelines/omnigen/pipeline_omnigen.py index 9431b77fa88c..d82519b80bfd 100644 --- a/src/diffusers/pipelines/omnigen/pipeline_omnigen.py +++ b/src/diffusers/pipelines/omnigen/pipeline_omnigen.py @@ -173,7 +173,7 @@ def __init__( ) self.default_sample_size = 128 - def encod_input_images( + def encode_input_images( self, input_pixel_values: List[torch.Tensor], device: Optional[torch.device] = None, From 792c3e6b4a42b1f04061341051b9e9a43665976f Mon Sep 17 00:00:00 2001 From: Shitao Xiao <2906698981@qq.com> Date: Sat, 8 Feb 2025 12:34:46 +0800 Subject: [PATCH 48/55] Update src/diffusers/pipelines/omnigen/pipeline_omnigen.py Co-authored-by: hlky --- src/diffusers/pipelines/omnigen/pipeline_omnigen.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/diffusers/pipelines/omnigen/pipeline_omnigen.py b/src/diffusers/pipelines/omnigen/pipeline_omnigen.py index d82519b80bfd..fe8ebabc0daf 100644 --- a/src/diffusers/pipelines/omnigen/pipeline_omnigen.py +++ b/src/diffusers/pipelines/omnigen/pipeline_omnigen.py @@ -455,7 +455,7 @@ def __call__( processed_data["position_ids"] = processed_data["position_ids"].to(device) # 4. Encode input images - input_img_latents = self.encod_input_images(processed_data["input_pixel_values"], device=device) + input_img_latents = self.encode_input_images(processed_data["input_pixel_values"], device=device) # 5. 
Prepare timesteps sigmas = np.linspace(1, 0, num_inference_steps + 1)[:num_inference_steps] From f5e3f0b64c0ca905f5a3547fd261fd8a63939103 Mon Sep 17 00:00:00 2001 From: staoxiao <2906698981@qq.com> Date: Sat, 8 Feb 2025 14:50:36 +0800 Subject: [PATCH 49/55] consistent attention processor --- src/diffusers/models/attention_processor.py | 16 +++----------- src/diffusers/models/embeddings.py | 15 ++----------- .../pipelines/omnigen/pipeline_omnigen.py | 21 +++---------------- 3 files changed, 8 insertions(+), 44 deletions(-) diff --git a/src/diffusers/models/attention_processor.py b/src/diffusers/models/attention_processor.py index 54d5f09e61d7..c02cfa628275 100644 --- a/src/diffusers/models/attention_processor.py +++ b/src/diffusers/models/attention_processor.py @@ -3993,9 +3993,9 @@ def __call__( # Get key-value heads kv_heads = inner_dim // head_dim - query = query.view(batch_size, -1, attn.heads, head_dim) - key = key.view(batch_size, -1, kv_heads, head_dim) - value = value.view(batch_size, -1, kv_heads, head_dim) + query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) + key = key.view(batch_size, -1, kv_heads, head_dim).transpose(1, 2) + value = value.view(batch_size, -1, kv_heads, head_dim).transpose(1, 2) # Apply RoPE if needed if query_rotary_emb is not None: @@ -4005,16 +4005,6 @@ def __call__( query, key = query.to(dtype), key.to(dtype) - # perform Grouped-qurey Attention (GQA) - n_rep = attn.heads // kv_heads - if n_rep > 1: - key = key.unsqueeze(3).repeat(1, 1, 1, n_rep, 1).flatten(2, 3) - value = value.unsqueeze(3).repeat(1, 1, 1, n_rep, 1).flatten(2, 3) - - query = query.transpose(1, 2) - key = key.transpose(1, 2) - value = value.transpose(1, 2) - # the output of sdp = (batch, num_heads, seq_len, head_dim) hidden_states = F.scaled_dot_product_attention(query, key, value, attn_mask=attention_mask) hidden_states = hidden_states.transpose(1, 2).to(dtype) diff --git a/src/diffusers/models/embeddings.py b/src/diffusers/models/embeddings.py index dad9e989ce62..0a5d29762ef6 100644 --- a/src/diffusers/models/embeddings.py +++ b/src/diffusers/models/embeddings.py @@ -1290,19 +1290,8 @@ def apply_rotary_emb( cos = cos[None, None] sin = sin[None, None] elif len(cos.shape) == 3: - # Used for OmniGen - cos = cos[ - :, - :, - None, - :, - ] - sin = sin[ - :, - :, - None, - :, - ] + cos = cos[:, None] + sin = sin[:, None] cos, sin = cos.to(x.device), sin.to(x.device) if revert_x_as_rotated: diff --git a/src/diffusers/pipelines/omnigen/pipeline_omnigen.py b/src/diffusers/pipelines/omnigen/pipeline_omnigen.py index fe8ebabc0daf..609155e377a6 100644 --- a/src/diffusers/pipelines/omnigen/pipeline_omnigen.py +++ b/src/diffusers/pipelines/omnigen/pipeline_omnigen.py @@ -229,20 +229,6 @@ def check_inputs( f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}" ) - @staticmethod - def _prepare_latent_image_ids(batch_size, height, width, device, dtype): - latent_image_ids = torch.zeros(height, width, 3) - latent_image_ids[..., 1] = latent_image_ids[..., 1] + torch.arange(height)[:, None] - latent_image_ids[..., 2] = latent_image_ids[..., 2] + torch.arange(width)[None, :] - - latent_image_id_height, latent_image_id_width, latent_image_id_channels = latent_image_ids.shape - - latent_image_ids = latent_image_ids.reshape( - latent_image_id_height * latent_image_id_width, latent_image_id_channels - ) - - return latent_image_ids.to(device=device, dtype=dtype) - 
def enable_vae_slicing(self): r""" Enable sliced VAE decoding. When this option is enabled, the VAE will split the input tensor in slices to @@ -272,6 +258,7 @@ def disable_vae_tiling(self): """ self.vae.disable_tiling() + # Copied from diffusers.pipelines.stable_diffusion_3.pipeline_stable_diffusion_3.StableDiffusion3Pipeline.prepare_latents def prepare_latents( self, batch_size, @@ -320,9 +307,7 @@ def interrupt(self): def __call__( self, prompt: Union[str, List[str]], - input_images: Optional[ - PipelineImageInput - ] = None, + input_images: Union[PipelineImageInput, List[PipelineImageInput]] = None, height: Optional[int] = None, width: Optional[int] = None, num_inference_steps: int = 50, @@ -348,7 +333,7 @@ def __call__( prompt (`str` or `List[str]`, *optional*): The prompt or prompts to guide the image generation. If the input includes images, need to add placeholders `<|image_i|>` in the prompt to indicate the position of the i-th images. - input_images (`PipelineImageInput`, *optional*): + input_images (`PipelineImageInput` or `List[PipelineImageInput]`, *optional*): The list of input images. We will replace the "<|image_i|>" in prompt with the i-th image in list. height (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor): The height in pixels of the generated image. This is set to 1024 by default for the best results. From 3541ab86b7bb0bb216972da9fd772e57a821aac3 Mon Sep 17 00:00:00 2001 From: staoxiao <2906698981@qq.com> Date: Sat, 8 Feb 2025 19:32:27 +0800 Subject: [PATCH 50/55] updata --- src/diffusers/models/attention.py | 31 -- src/diffusers/models/attention_processor.py | 62 ---- src/diffusers/models/embeddings.py | 179 +--------- .../transformers/transformer_omnigen.py | 310 ++++++++++++++++-- .../pipelines/omnigen/pipeline_omnigen.py | 10 +- .../pipelines/omnigen/processor_omnigen.py | 10 + .../test_models_transformer_omnigen.py | 88 +++++ 7 files changed, 401 insertions(+), 289 deletions(-) create mode 100644 tests/models/transformers/test_models_transformer_omnigen.py diff --git a/src/diffusers/models/attention.py b/src/diffusers/models/attention.py index d7aceff2376d..4d1dae879f11 100644 --- a/src/diffusers/models/attention.py +++ b/src/diffusers/models/attention.py @@ -639,37 +639,6 @@ def forward(self, x): return self.linear_2(self.silu(self.linear_1(x)) * self.linear_3(x)) -class OmniGenFeedForward(nn.Module): - r""" - A feed-forward layer for OmniGen. - - Parameters: - hidden_size (`int`): - The dimensionality of the hidden layers in the model. This parameter determines the width of the model's - hidden representations. - intermediate_size (`int`): The intermediate dimension of the feedforward layer. 
- """ - - def __init__( - self, - hidden_size: int, - intermediate_size: int, - ): - super().__init__() - self.gate_up_proj = nn.Linear(hidden_size, 2 * intermediate_size, bias=False) - self.down_proj = nn.Linear(intermediate_size, hidden_size, bias=False) - - self.activation_fn = nn.SiLU() - - def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: - up_states = self.gate_up_proj(hidden_states) - - gate, up_states = up_states.chunk(2, dim=-1) - up_states = up_states * self.activation_fn(gate) - - return self.down_proj(up_states) - - @maybe_allow_in_graph class TemporalBasicTransformerBlock(nn.Module): r""" diff --git a/src/diffusers/models/attention_processor.py b/src/diffusers/models/attention_processor.py index c02cfa628275..5d873baf8fbb 100644 --- a/src/diffusers/models/attention_processor.py +++ b/src/diffusers/models/attention_processor.py @@ -3951,68 +3951,6 @@ def __call__( return hidden_states -class OmniGenAttnProcessor2_0: - r""" - Processor for implementing scaled dot-product attention (enabled by default if you're using PyTorch 2.0). This is - used in the OmniGen model. - """ - - def __init__(self): - if not hasattr(F, "scaled_dot_product_attention"): - raise ImportError("AttnProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0.") - - def __call__( - self, - attn: Attention, - hidden_states: torch.Tensor, - encoder_hidden_states: torch.Tensor, - attention_mask: Optional[torch.Tensor] = None, - query_rotary_emb: Optional[torch.Tensor] = None, - key_rotary_emb: Optional[torch.Tensor] = None, - ) -> torch.Tensor: - from .embeddings import apply_rotary_emb - - input_ndim = hidden_states.ndim - - if input_ndim == 4: - batch_size, channel, height, width = hidden_states.shape - hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2) - - batch_size, sequence_length, _ = hidden_states.shape - - # Get Query-Key-Value Pair - query = attn.to_q(hidden_states) - key = attn.to_k(encoder_hidden_states) - value = attn.to_v(encoder_hidden_states) - - bsz, q_len, query_dim = query.size() - inner_dim = key.shape[-1] - head_dim = query_dim // attn.heads - dtype = query.dtype - - # Get key-value heads - kv_heads = inner_dim // head_dim - - query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) - key = key.view(batch_size, -1, kv_heads, head_dim).transpose(1, 2) - value = value.view(batch_size, -1, kv_heads, head_dim).transpose(1, 2) - - # Apply RoPE if needed - if query_rotary_emb is not None: - query = apply_rotary_emb(query, query_rotary_emb, revert_x_as_rotated=True) - if key_rotary_emb is not None: - key = apply_rotary_emb(key, key_rotary_emb, revert_x_as_rotated=True) - - query, key = query.to(dtype), key.to(dtype) - - # the output of sdp = (batch, num_heads, seq_len, head_dim) - hidden_states = F.scaled_dot_product_attention(query, key, value, attn_mask=attention_mask) - hidden_states = hidden_states.transpose(1, 2).to(dtype) - hidden_states = hidden_states.reshape(bsz, q_len, attn.out_dim) - hidden_states = attn.to_out[0](hidden_states) - return hidden_states - - class PAGHunyuanAttnProcessor2_0: r""" Processor for implementing scaled dot-product attention (enabled by default if you're using PyTorch 2.0). 
This is diff --git a/src/diffusers/models/embeddings.py b/src/diffusers/models/embeddings.py index 0a5d29762ef6..bd3237c24c1c 100644 --- a/src/diffusers/models/embeddings.py +++ b/src/diffusers/models/embeddings.py @@ -569,101 +569,6 @@ def forward(self, latent): return (latent + pos_embed).to(latent.dtype) -class OmniGenPatchEmbed(nn.Module): - """2D Image to Patch Embedding with support for OmniGen.""" - - def __init__( - self, - patch_size: int = 2, - in_channels: int = 4, - embed_dim: int = 768, - bias: bool = True, - interpolation_scale: float = 1, - pos_embed_max_size: int = 192, - base_size: int = 64, - ): - super().__init__() - - self.output_image_proj = nn.Conv2d( - in_channels, embed_dim, kernel_size=(patch_size, patch_size), stride=patch_size, bias=bias - ) - self.input_image_proj = nn.Conv2d( - in_channels, embed_dim, kernel_size=(patch_size, patch_size), stride=patch_size, bias=bias - ) - - self.patch_size = patch_size - self.interpolation_scale = interpolation_scale - self.pos_embed_max_size = pos_embed_max_size - - pos_embed = get_2d_sincos_pos_embed( - embed_dim, self.pos_embed_max_size, base_size=base_size, interpolation_scale=self.interpolation_scale, output_type="pt" - ) - self.register_buffer("pos_embed", pos_embed.float().unsqueeze(0), persistent=True) - - def cropped_pos_embed(self, height, width): - """Crops positional embeddings for SD3 compatibility.""" - if self.pos_embed_max_size is None: - raise ValueError("`pos_embed_max_size` must be set for cropping.") - - height = height // self.patch_size - width = width // self.patch_size - if height > self.pos_embed_max_size: - raise ValueError( - f"Height ({height}) cannot be greater than `pos_embed_max_size`: {self.pos_embed_max_size}." - ) - if width > self.pos_embed_max_size: - raise ValueError( - f"Width ({width}) cannot be greater than `pos_embed_max_size`: {self.pos_embed_max_size}." - ) - - top = (self.pos_embed_max_size - height) // 2 - left = (self.pos_embed_max_size - width) // 2 - spatial_pos_embed = self.pos_embed.reshape(1, self.pos_embed_max_size, self.pos_embed_max_size, -1) - spatial_pos_embed = spatial_pos_embed[:, top : top + height, left : left + width, :] - spatial_pos_embed = spatial_pos_embed.reshape(1, -1, spatial_pos_embed.shape[-1]) - return spatial_pos_embed - - def patch_embeddings(self, latent, is_input_image: bool): - if is_input_image: - latent = self.input_image_proj(latent) - else: - latent = self.output_image_proj(latent) - latent = latent.flatten(2).transpose(1, 2) - return latent - - def forward(self, latent: torch.Tensor, is_input_image: bool, padding_latent: torch.Tensor = None): - """ - Args: - latent: encoded image latents - is_input_image: use input_image_proj or output_image_proj - padding_latent: - When sizes of target images are inconsistent, use `padding_latent` to maintain consistent sequence - length. 
- - Returns: torch.Tensor - - """ - if isinstance(latent, list): - if padding_latent is None: - padding_latent = [None] * len(latent) - patched_latents = [] - for sub_latent, padding in zip(latent, padding_latent): - height, width = sub_latent.shape[-2:] - sub_latent = self.patch_embeddings(sub_latent, is_input_image) - pos_embed = self.cropped_pos_embed(height, width) - sub_latent = sub_latent + pos_embed - if padding is not None: - sub_latent = torch.cat([sub_latent, padding.to(sub_latent.device)], dim=-2) - patched_latents.append(sub_latent) - else: - height, width = latent.shape[-2:] - pos_embed = self.cropped_pos_embed(height, width) - latent = self.patch_embeddings(latent, is_input_image) - patched_latents = latent + pos_embed - - return patched_latents - - class LuminaPatchEmbed(nn.Module): """ 2D Image to Patch Embedding with support for Lumina-T2X @@ -1268,7 +1173,6 @@ def apply_rotary_emb( freqs_cis: Union[torch.Tensor, Tuple[torch.Tensor]], use_real: bool = True, use_real_unbind_dim: int = -1, - revert_x_as_rotated: bool = False, ) -> Tuple[torch.Tensor, torch.Tensor]: """ Apply rotary embeddings to input tensors using the given frequency tensor. This function applies rotary embeddings @@ -1286,31 +1190,20 @@ def apply_rotary_emb( """ if use_real: cos, sin = freqs_cis # [S, D] - if len(cos.shape) == 2: - cos = cos[None, None] - sin = sin[None, None] - elif len(cos.shape) == 3: - cos = cos[:, None] - sin = sin[:, None] + cos = cos[None, None] + sin = sin[None, None] cos, sin = cos.to(x.device), sin.to(x.device) - if revert_x_as_rotated: - # Rotates half the hidden dims of the input. this rorate function is widely used in LLM, e.g. Llama, Phi3, etc. - # Used for OmniGen - x1 = x[..., : x.shape[-1] // 2] - x2 = x[..., x.shape[-1] // 2 :] - x_rotated = torch.cat((-x2, x1), dim=-1) + if use_real_unbind_dim == -1: + # Used for flux, cogvideox, hunyuan-dit + x_real, x_imag = x.reshape(*x.shape[:-1], -1, 2).unbind(-1) # [B, S, H, D//2] + x_rotated = torch.stack([-x_imag, x_real], dim=-1).flatten(3) + elif use_real_unbind_dim == -2: + # Used for Stable Audio + x_real, x_imag = x.reshape(*x.shape[:-1], 2, -1).unbind(-2) # [B, S, H, D//2] + x_rotated = torch.cat([-x_imag, x_real], dim=-1) else: - if use_real_unbind_dim == -1: - # Used for flux, cogvideox, hunyuan-dit - x_real, x_imag = x.reshape(*x.shape[:-1], -1, 2).unbind(-1) # [B, S, H, D//2] - x_rotated = torch.stack([-x_imag, x_real], dim=-1).flatten(3) - elif use_real_unbind_dim == -2: - # Used for Stable Audio - x_real, x_imag = x.reshape(*x.shape[:-1], 2, -1).unbind(-2) # [B, S, H, D//2] - x_rotated = torch.cat([-x_imag, x_real], dim=-1) - else: - raise ValueError(f"`use_real_unbind_dim={use_real_unbind_dim}` but should be -1 or -2.") + raise ValueError(f"`use_real_unbind_dim={use_real_unbind_dim}` but should be -1 or -2.") out = (x.float() * cos + x_rotated.float() * sin).to(x.dtype) @@ -1373,56 +1266,6 @@ def forward(self, ids: torch.Tensor) -> torch.Tensor: return freqs_cos, freqs_sin -class OmniGenSuScaledRotaryEmbedding(nn.Module): - def __init__( - self, dim, max_position_embeddings=131072, original_max_position_embeddings=4096, base=10000, rope_scaling=None - ): - super().__init__() - - self.dim = dim - self.max_position_embeddings = max_position_embeddings - self.base = base - - inv_freq = 1.0 / (self.base ** (torch.arange(0, self.dim, 2, dtype=torch.int64).float() / self.dim)) - self.register_buffer("inv_freq", tensor=inv_freq, persistent=False) - - self.short_factor = rope_scaling["short_factor"] - self.long_factor = 
rope_scaling["long_factor"] - self.original_max_position_embeddings = original_max_position_embeddings - - @torch.no_grad() - def forward(self, x, position_ids): - seq_len = torch.max(position_ids) + 1 - if seq_len > self.original_max_position_embeddings: - ext_factors = torch.tensor(self.long_factor, dtype=torch.float32, device=x.device) - else: - ext_factors = torch.tensor(self.short_factor, dtype=torch.float32, device=x.device) - - inv_freq_shape = torch.arange(0, self.dim, 2, dtype=torch.int64, device=x.device).float() / self.dim - self.inv_freq = 1.0 / (ext_factors * self.base**inv_freq_shape) - - inv_freq_expanded = self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1) - position_ids_expanded = position_ids[:, None, :].float() - - # Force float32 since bfloat16 loses precision on long contexts - # See https://github.com/huggingface/transformers/pull/29285 - device_type = x.device.type - device_type = device_type if isinstance(device_type, str) and device_type != "mps" else "cpu" - with torch.autocast(device_type=device_type, enabled=False): - freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2) - emb = torch.cat((freqs, freqs), dim=-1) - - scale = self.max_position_embeddings / self.original_max_position_embeddings - if scale <= 1.0: - scaling_factor = 1.0 - else: - scaling_factor = math.sqrt(1 + math.log(scale) / math.log(self.original_max_position_embeddings)) - - cos = emb.cos() * scaling_factor - sin = emb.sin() * scaling_factor - return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype) - - class TimestepEmbedding(nn.Module): def __init__( self, diff --git a/src/diffusers/models/transformers/transformer_omnigen.py b/src/diffusers/models/transformers/transformer_omnigen.py index 78bd97cb4fa5..3a02a21dcf11 100644 --- a/src/diffusers/models/transformers/transformer_omnigen.py +++ b/src/diffusers/models/transformers/transformer_omnigen.py @@ -12,18 +12,19 @@ # See the License for the specific language governing permissions and # limitations under the License. -from typing import Any, Dict, List, Optional, Union +import math +from typing import Any, Dict, List, Optional, Tuple, Union import torch import torch.utils.checkpoint from torch import nn +import torch.nn.functional as F from ...configuration_utils import ConfigMixin, register_to_config from ...loaders import PeftAdapterMixin from ...utils import USE_PEFT_BACKEND, logging, scale_lora_layers -from ..attention import OmniGenFeedForward -from ..attention_processor import Attention, AttentionProcessor, OmniGenAttnProcessor2_0 -from ..embeddings import OmniGenPatchEmbed, OmniGenSuScaledRotaryEmbedding, TimestepEmbedding, Timesteps +from ..attention_processor import Attention, AttentionProcessor +from ..embeddings import TimestepEmbedding, Timesteps, get_2d_sincos_pos_embed from ..modeling_outputs import Transformer2DModelOutput from ..modeling_utils import ModelMixin from ..normalization import AdaLayerNorm, RMSNorm @@ -32,6 +33,272 @@ logger = logging.get_logger(__name__) # pylint: disable=invalid-name +class OmniGenFeedForward(nn.Module): + r""" + A feed-forward layer for OmniGen. + + Parameters: + hidden_size (`int`): + The dimensionality of the hidden layers in the model. This parameter determines the width of the model's + hidden representations. + intermediate_size (`int`): The intermediate dimension of the feedforward layer. 
+ """ + + def __init__( + self, + hidden_size: int, + intermediate_size: int, + ): + super().__init__() + self.gate_up_proj = nn.Linear(hidden_size, 2 * intermediate_size, bias=False) + self.down_proj = nn.Linear(intermediate_size, hidden_size, bias=False) + + self.activation_fn = nn.SiLU() + + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: + up_states = self.gate_up_proj(hidden_states) + + gate, up_states = up_states.chunk(2, dim=-1) + up_states = up_states * self.activation_fn(gate) + + return self.down_proj(up_states) + + +class OmniGenPatchEmbed(nn.Module): + """2D Image to Patch Embedding with support for OmniGen.""" + + def __init__( + self, + patch_size: int = 2, + in_channels: int = 4, + embed_dim: int = 768, + bias: bool = True, + interpolation_scale: float = 1, + pos_embed_max_size: int = 192, + base_size: int = 64, + ): + super().__init__() + + self.output_image_proj = nn.Conv2d( + in_channels, embed_dim, kernel_size=(patch_size, patch_size), stride=patch_size, bias=bias + ) + self.input_image_proj = nn.Conv2d( + in_channels, embed_dim, kernel_size=(patch_size, patch_size), stride=patch_size, bias=bias + ) + + self.patch_size = patch_size + self.interpolation_scale = interpolation_scale + self.pos_embed_max_size = pos_embed_max_size + + pos_embed = get_2d_sincos_pos_embed( + embed_dim, self.pos_embed_max_size, base_size=base_size, interpolation_scale=self.interpolation_scale, output_type="pt" + ) + self.register_buffer("pos_embed", pos_embed.float().unsqueeze(0), persistent=True) + + def cropped_pos_embed(self, height, width): + """Crops positional embeddings for SD3 compatibility.""" + if self.pos_embed_max_size is None: + raise ValueError("`pos_embed_max_size` must be set for cropping.") + + height = height // self.patch_size + width = width // self.patch_size + if height > self.pos_embed_max_size: + raise ValueError( + f"Height ({height}) cannot be greater than `pos_embed_max_size`: {self.pos_embed_max_size}." + ) + if width > self.pos_embed_max_size: + raise ValueError( + f"Width ({width}) cannot be greater than `pos_embed_max_size`: {self.pos_embed_max_size}." + ) + + top = (self.pos_embed_max_size - height) // 2 + left = (self.pos_embed_max_size - width) // 2 + spatial_pos_embed = self.pos_embed.reshape(1, self.pos_embed_max_size, self.pos_embed_max_size, -1) + spatial_pos_embed = spatial_pos_embed[:, top : top + height, left : left + width, :] + spatial_pos_embed = spatial_pos_embed.reshape(1, -1, spatial_pos_embed.shape[-1]) + return spatial_pos_embed + + def patch_embeddings(self, latent, is_input_image: bool): + if is_input_image: + latent = self.input_image_proj(latent) + else: + latent = self.output_image_proj(latent) + latent = latent.flatten(2).transpose(1, 2) + return latent + + def forward(self, latent: torch.Tensor, is_input_image: bool, padding_latent: torch.Tensor = None): + """ + Args: + latent: encoded image latents + is_input_image: use input_image_proj or output_image_proj + padding_latent: + When sizes of target images are inconsistent, use `padding_latent` to maintain consistent sequence + length. 
+ + Returns: torch.Tensor + + """ + if isinstance(latent, list): + if padding_latent is None: + padding_latent = [None] * len(latent) + patched_latents = [] + for sub_latent, padding in zip(latent, padding_latent): + height, width = sub_latent.shape[-2:] + sub_latent = self.patch_embeddings(sub_latent, is_input_image) + pos_embed = self.cropped_pos_embed(height, width) + sub_latent = sub_latent + pos_embed + if padding is not None: + sub_latent = torch.cat([sub_latent, padding.to(sub_latent.device)], dim=-2) + patched_latents.append(sub_latent) + else: + height, width = latent.shape[-2:] + pos_embed = self.cropped_pos_embed(height, width) + latent = self.patch_embeddings(latent, is_input_image) + patched_latents = latent + pos_embed + + return patched_latents + + +class OmniGenSuScaledRotaryEmbedding(nn.Module): + def __init__( + self, dim, max_position_embeddings=131072, original_max_position_embeddings=4096, base=10000, rope_scaling=None + ): + super().__init__() + + self.dim = dim + self.max_position_embeddings = max_position_embeddings + self.base = base + + inv_freq = 1.0 / (self.base ** (torch.arange(0, self.dim, 2, dtype=torch.int64).float() / self.dim)) + self.register_buffer("inv_freq", tensor=inv_freq, persistent=False) + + self.short_factor = rope_scaling["short_factor"] + self.long_factor = rope_scaling["long_factor"] + self.original_max_position_embeddings = original_max_position_embeddings + + @torch.no_grad() + def forward(self, x, position_ids): + seq_len = torch.max(position_ids) + 1 + if seq_len > self.original_max_position_embeddings: + ext_factors = torch.tensor(self.long_factor, dtype=torch.float32, device=x.device) + else: + ext_factors = torch.tensor(self.short_factor, dtype=torch.float32, device=x.device) + + inv_freq_shape = torch.arange(0, self.dim, 2, dtype=torch.int64, device=x.device).float() / self.dim + self.inv_freq = 1.0 / (ext_factors * self.base**inv_freq_shape) + + inv_freq_expanded = self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1) + position_ids_expanded = position_ids[:, None, :].float() + + # Force float32 since bfloat16 loses precision on long contexts + # See https://github.com/huggingface/transformers/pull/29285 + device_type = x.device.type + device_type = device_type if isinstance(device_type, str) and device_type != "mps" else "cpu" + with torch.autocast(device_type=device_type, enabled=False): + freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2) + emb = torch.cat((freqs, freqs), dim=-1) + + scale = self.max_position_embeddings / self.original_max_position_embeddings + if scale <= 1.0: + scaling_factor = 1.0 + else: + scaling_factor = math.sqrt(1 + math.log(scale) / math.log(self.original_max_position_embeddings)) + + cos = emb.cos() * scaling_factor + sin = emb.sin() * scaling_factor + return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype) + + +def apply_rotary_emb( + x: torch.Tensor, + freqs_cis: Union[torch.Tensor, Tuple[torch.Tensor]], +) -> Tuple[torch.Tensor, torch.Tensor]: + """ + Apply rotary embeddings to input tensors using the given frequency tensor. This function applies rotary embeddings + to the given query or key 'x' tensors using the provided frequency tensor 'freqs_cis'. The input tensors are + reshaped as complex numbers, and the frequency tensor is reshaped for broadcasting compatibility. The resulting + tensors contain rotary embeddings and are returned as real tensors. + + Args: + x (`torch.Tensor`): + Query or key tensor to apply rotary embeddings. 
[B, H, S, D] xk (torch.Tensor): Key tensor to apply + freqs_cis (`Tuple[torch.Tensor]`): Precomputed frequency tensor for complex exponentials. ([S, D], [S, D],) + + Returns: + Tuple[torch.Tensor, torch.Tensor]: Tuple of modified query tensor and key tensor with rotary embeddings. + """ + + cos, sin = freqs_cis # [S, D] + if len(cos.shape) == 2: + cos = cos[None, None] + sin = sin[None, None] + elif len(cos.shape) == 3: + cos = cos[:, None] + sin = sin[:, None] + cos, sin = cos.to(x.device), sin.to(x.device) + + # Rotates half the hidden dims of the input. this rorate function is widely used in LLM, e.g. Llama, Phi3, etc. + x1 = x[..., : x.shape[-1] // 2] + x2 = x[..., x.shape[-1] // 2 :] + x_rotated = torch.cat((-x2, x1), dim=-1) + + out = (x.float() * cos + x_rotated.float() * sin).to(x.dtype) + return out + + +class OmniGenAttnProcessor2_0: + r""" + Processor for implementing scaled dot-product attention (enabled by default if you're using PyTorch 2.0). This is + used in the OmniGen model. + """ + + def __init__(self): + if not hasattr(F, "scaled_dot_product_attention"): + raise ImportError("AttnProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0.") + + def __call__( + self, + attn: Attention, + hidden_states: torch.Tensor, + encoder_hidden_states: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + image_rotary_emb: Optional[torch.Tensor] = None, + ) -> torch.Tensor: + + batch_size, sequence_length, _ = hidden_states.shape + + # Get Query-Key-Value Pair + query = attn.to_q(hidden_states) + key = attn.to_k(encoder_hidden_states) + value = attn.to_v(encoder_hidden_states) + + bsz, q_len, query_dim = query.size() + inner_dim = key.shape[-1] + head_dim = query_dim // attn.heads + dtype = query.dtype + + # Get key-value heads + kv_heads = inner_dim // head_dim + + query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) + key = key.view(batch_size, -1, kv_heads, head_dim).transpose(1, 2) + value = value.view(batch_size, -1, kv_heads, head_dim).transpose(1, 2) + + # Apply RoPE if needed + if image_rotary_emb is not None: + query = apply_rotary_emb(query, image_rotary_emb) + key = apply_rotary_emb(key, image_rotary_emb) + + query, key = query.to(dtype), key.to(dtype) + + # the output of sdp = (batch, num_heads, seq_len, head_dim) + hidden_states = F.scaled_dot_product_attention(query, key, value, attn_mask=attention_mask) + hidden_states = hidden_states.transpose(1, 2).to(dtype) + hidden_states = hidden_states.reshape(bsz, q_len, attn.out_dim) + hidden_states = attn.to_out[0](hidden_states) + return hidden_states + + class OmniGenBlock(nn.Module): """ A LuminaNextDiTBlock for LuminaNextDiT2DModel. @@ -74,7 +341,7 @@ def forward( self, hidden_states: torch.Tensor, attention_mask: torch.Tensor, - rotary_emb: torch.Tensor, + image_rotary_emb: torch.Tensor, ): """ Perform a forward pass through the LuminaNextDiTBlock. @@ -82,7 +349,7 @@ def forward( Parameters: hidden_states (`torch.Tensor`): The input of hidden_states for LuminaNextDiTBlock. attention_mask (`torch.Tensor): The input of hidden_states corresponse attention mask. - rotary_emb (`torch.Tensor`): Precomputed cosine and sine frequencies. + image_rotary_emb (`torch.Tensor`): Precomputed cosine and sine frequencies. 
""" residual = hidden_states @@ -93,8 +360,7 @@ def forward( hidden_states=hidden_states, encoder_hidden_states=hidden_states, attention_mask=attention_mask, - query_rotary_emb=rotary_emb, - key_rotary_emb=rotary_emb, + image_rotary_emb=image_rotary_emb, ) hidden_states = residual + attn_outputs @@ -137,6 +403,7 @@ class OmniGenTransformer2DModel(ModelMixin, ConfigMixin, PeftAdapterMixin): """ _supports_gradient_checkpointing = True + _no_split_modules = ["OmniGenBlock"] @register_to_config def __init__( @@ -204,6 +471,8 @@ def __init__( ) self.norm = RMSNorm(hidden_size, eps=rms_norm_eps) + self.gradient_checkpointing = False + def unpatchify(self, x, h, w): """ x: (N, T, patch_size**2 * C) imgs: (N, H, W, C) @@ -277,10 +546,6 @@ def fn_recursive_attn_processor(name: str, module: torch.nn.Module, processor): for name, module in self.named_children(): fn_recursive_attn_processor(name, module, processor) - def _set_gradient_checkpointing(self, module, value=False): - if hasattr(module, "gradient_checkpointing"): - module.gradient_checkpointing = value - def get_multimodal_embeddings( self, input_ids: torch.Tensor, @@ -319,7 +584,7 @@ def get_multimodal_embeddings( def forward( self, hidden_states: torch.Tensor, - timestep: Union[int, float, torch.LongTensor], + timestep: Union[int, float, torch.FloatTensor], input_ids: torch.Tensor, input_img_latents: List[torch.Tensor], input_image_sizes: Dict[int, List[int]], @@ -334,7 +599,7 @@ def forward( Args: hidden_states (`torch.Tensor` of shape `(batch size, channel, height, width)`): Input `hidden_states`. - timestep (`torch.LongTensor`): + timestep (`torch.FloatTensor`): Used to indicate denoising step. input_ids (`torch.LongTensor`): token ids @@ -406,16 +671,21 @@ def forward( hidden_states = inputs_embeds - cos, sin = self.rotary_emb(hidden_states, position_ids) + image_rotary_emb = self.rotary_emb(hidden_states, position_ids) for decoder_layer in self.layers: - hidden_states = decoder_layer(hidden_states, attention_mask=attention_mask, rotary_emb=[cos, sin]) + if torch.is_grad_enabled() and self.gradient_checkpointing: + hidden_states = self._gradient_checkpointing_func(decoder_layer, hidden_states, attention_mask, image_rotary_emb) + else: + hidden_states = decoder_layer(hidden_states, attention_mask=attention_mask, image_rotary_emb=image_rotary_emb) hidden_states = self.norm(hidden_states) - image_embedding = hidden_states[:, -num_tokens_for_output_image:] - time_emb = self.t_embedder(self.time_proj(timestep).to(hidden_states.dtype)) - x = self.proj_out(self.norm_out(image_embedding, temb=time_emb)) - output = self.unpatchify(x, height, width) + hidden_states = hidden_states[:, -num_tokens_for_output_image:] + timestep_proj = self.time_proj(timestep) + temb = self.t_embedder(timestep_proj.type_as(hidden_states)) + hidden_states = self.norm_out(hidden_states, temb=temb) + hidden_states = self.proj_out(hidden_states) + output = self.unpatchify(hidden_states, height, width) if not return_dict: return (output,) diff --git a/src/diffusers/pipelines/omnigen/pipeline_omnigen.py b/src/diffusers/pipelines/omnigen/pipeline_omnigen.py index 609155e377a6..25ca933524b5 100644 --- a/src/diffusers/pipelines/omnigen/pipeline_omnigen.py +++ b/src/diffusers/pipelines/omnigen/pipeline_omnigen.py @@ -21,7 +21,6 @@ from transformers import LlamaTokenizer from ...image_processor import PipelineImageInput, VaeImageProcessor -from ...loaders import FromSingleFileMixin, TextualInversionLoaderMixin from ...models.autoencoders import AutoencoderKL from 
...models.transformers import OmniGenTransformer2DModel from ...schedulers import FlowMatchEulerDiscreteScheduler @@ -121,8 +120,6 @@ def retrieve_timesteps( class OmniGenPipeline( DiffusionPipeline, - FromSingleFileMixin, - TextualInversionLoaderMixin, ): r""" The OmniGen pipeline for multimodal-to-image generation. @@ -161,7 +158,7 @@ def __init__( scheduler=scheduler, ) self.vae_scale_factor = ( - 2 ** (len(self.vae.config.block_out_channels) - 1) if hasattr(self, "vae") and self.vae is not None else 8 + 2 ** (len(self.vae.config.block_out_channels) - 1) if getattr(self, "vae", None) is not None else 8 ) # OmniGen latents are turned into 2x2 patches and packed. This means the latent width and height has to be divisible # by the patch size. So the vae scale factor is multiplied by the patch size to account for this @@ -403,9 +400,6 @@ def __call__( prompt = [prompt] input_images = [input_images] - # using Float32 for the VAE doesn't take up much memory but can prevent potential black image outputs. - self.vae.to(torch.float32) - # 1. Check inputs. Raise error if not correct self.check_inputs( prompt, @@ -426,7 +420,7 @@ def __call__( # 3. process multi-modal instructions if max_input_image_size != self.multimodal_processor.max_image_size: - self.multimodal_processor = OmniGenMultiModalProcessor(self.tokenizer, max_image_size=max_input_image_size) + self.multimodal_processor.reset_max_image_size(max_image_size=max_input_image_size) processed_data = self.multimodal_processor( prompt, input_images, diff --git a/src/diffusers/pipelines/omnigen/processor_omnigen.py b/src/diffusers/pipelines/omnigen/processor_omnigen.py index d13e7a742379..9faf25629d9b 100644 --- a/src/diffusers/pipelines/omnigen/processor_omnigen.py +++ b/src/diffusers/pipelines/omnigen/processor_omnigen.py @@ -62,6 +62,16 @@ def __init__(self, text_tokenizer, max_image_size: int = 1024): ) self.collator = OmniGenCollator() + + def reset_max_image_size(self, max_image_size): + self.max_image_size = max_image_size + self.image_transform = transforms.Compose( + [ + transforms.Lambda(lambda pil_image: crop_image(pil_image, max_image_size)), + transforms.ToTensor(), + transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5], inplace=True), + ] + ) def process_image(self, image): if isinstance(image, str): diff --git a/tests/models/transformers/test_models_transformer_omnigen.py b/tests/models/transformers/test_models_transformer_omnigen.py new file mode 100644 index 000000000000..cd77b2289074 --- /dev/null +++ b/tests/models/transformers/test_models_transformer_omnigen.py @@ -0,0 +1,88 @@ +# coding=utf-8 +# Copyright 2024 HuggingFace Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +import unittest + +import torch + +from diffusers import OmniGenTransformer2DModel +from diffusers.utils.testing_utils import enable_full_determinism, torch_device + +from ..test_modeling_common import ModelTesterMixin + + +enable_full_determinism() + +class OmniGenTransformerTests(ModelTesterMixin, unittest.TestCase): + model_class = OmniGenTransformer2DModel + main_input_name = "hidden_states" + uses_custom_attn_processor = True + + @property + def dummy_input(self): + batch_size = 2 + num_channels = 4 + height = 8 + width = 8 + sequence_length = 24 + + hidden_states = torch.randn((batch_size, num_channels, height, width)).to(torch_device) + timestep = torch.rand(size=(batch_size,), dtype=hidden_states.dtype).to(torch_device) + input_ids = torch.randint(0, 10, (batch_size, sequence_length)).to(torch_device) + input_img_latents = [torch.randn((1, num_channels, height, width)).to(torch_device)] + input_image_sizes = {0: [[0, 0+height*width//2//2]]} + + attn_seq_length = sequence_length + 1 + height*width//2//2 + attention_mask = torch.ones((batch_size, attn_seq_length, attn_seq_length)).to(torch_device) + position_ids = torch.LongTensor([list(range(attn_seq_length))]*batch_size).to(torch_device) + + return { + "hidden_states": hidden_states, + "timestep": timestep, + "input_ids": input_ids, + "input_img_latents": input_img_latents, + "input_image_sizes": input_image_sizes, + "attention_mask": attention_mask, + "position_ids": position_ids, + } + + @property + def input_shape(self): + return (4, 8, 8) + + @property + def output_shape(self): + return (4, 8, 8) + + def prepare_init_args_and_inputs_for_common(self): + init_dict = { + "hidden_size": 16, + "num_attention_heads": 4, + "num_key_value_heads": 4, + "intermediate_size": 32, + "num_layers": 1, + "pad_token_id": 0, + "vocab_size": 100, + "in_channels": 4, + "time_step_dim": 4, + "rope_scaling": {"long_factor": list(range(1, 3)), "short_factor": list(range(1, 3))} + } + inputs_dict = self.dummy_input + return init_dict, inputs_dict + + def test_gradient_checkpointing_is_applied(self): + expected_set = {"OmniGenTransformer2DModel"} + super().test_gradient_checkpointing_is_applied(expected_set=expected_set) + From 4e9850a6afa09bcdc5eb653db84873f6c8960ef6 Mon Sep 17 00:00:00 2001 From: staoxiao <2906698981@qq.com> Date: Sat, 8 Feb 2025 19:45:18 +0800 Subject: [PATCH 51/55] update --- src/diffusers/pipelines/omnigen/pipeline_omnigen.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/diffusers/pipelines/omnigen/pipeline_omnigen.py b/src/diffusers/pipelines/omnigen/pipeline_omnigen.py index 25ca933524b5..58b8256df7c0 100644 --- a/src/diffusers/pipelines/omnigen/pipeline_omnigen.py +++ b/src/diffusers/pipelines/omnigen/pipeline_omnigen.py @@ -361,7 +361,7 @@ def __call__( generator (`torch.Generator` or `List[torch.Generator]`, *optional*): One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make generation deterministic. - latents (`torch.FloatTensor`, *optional*): + latents (`torch.Tensor`, *optional*): Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image generation. Can be used to tweak the same generation with different prompts. If not provided, a latents tensor will ge generated by sampling using the supplied random `generator`. 
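Note: the OmniGen attention processor introduced earlier in this series applies rotary position embeddings with the rotate-half formulation (as in Llama/Phi3) before scaled dot-product attention. A simplified sketch of that step, omitting the long/short `su` scaling factors of `OmniGenSuScaledRotaryEmbedding`:

```py
import torch

batch, heads, seq_len, head_dim = 2, 4, 8, 16
x = torch.randn(batch, heads, seq_len, head_dim)  # a query or key tensor, (B, H, S, D)

# Plain position-frequency tables; the real module additionally rescales `inv_freq`
# with its long/short factors and a context-length-dependent scaling factor.
inv_freq = 1.0 / (10000 ** (torch.arange(0, head_dim, 2).float() / head_dim))
freqs = torch.outer(torch.arange(seq_len).float(), inv_freq)  # (S, D // 2)
emb = torch.cat((freqs, freqs), dim=-1)                       # (S, D)
cos, sin = emb.cos()[None, None], emb.sin()[None, None]       # broadcast over batch and heads

# Rotate half the head dimension, then combine with the cos/sin tables.
x1, x2 = x[..., : head_dim // 2], x[..., head_dim // 2 :]
x_rotated = torch.cat((-x2, x1), dim=-1)
out = (x.float() * cos + x_rotated.float() * sin).to(x.dtype)
print(out.shape)  # torch.Size([2, 4, 8, 16])
```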
From f91cfcf7c644dd727b625765d24baa808213deea Mon Sep 17 00:00:00 2001 From: staoxiao <2906698981@qq.com> Date: Mon, 10 Feb 2025 00:02:42 +0800 Subject: [PATCH 52/55] check_inputs --- src/diffusers/pipelines/omnigen/pipeline_omnigen.py | 8 ++++++++ 1 file changed, 8 insertions(+) diff --git a/src/diffusers/pipelines/omnigen/pipeline_omnigen.py b/src/diffusers/pipelines/omnigen/pipeline_omnigen.py index 58b8256df7c0..1862777e7af5 100644 --- a/src/diffusers/pipelines/omnigen/pipeline_omnigen.py +++ b/src/diffusers/pipelines/omnigen/pipeline_omnigen.py @@ -199,6 +199,7 @@ def check_inputs( input_images, height, width, + use_input_image_size_as_output, callback_on_step_end_tensor_inputs=None, max_sequence_length=None, ): @@ -219,6 +220,12 @@ def check_inputs( f"`height` and `width` have to be divisible by {self.vae_scale_factor * 2} but are {height} and {width}. Dimensions will be resized accordingly" ) + if use_input_image_size_as_output: + if input_images is None or input_images[0] is None: + raise ValueError( + f"`use_input_image_size_as_output` is set to True, but no input image was found. If you are performing a text-to-image task, please set it to False." + ) + if callback_on_step_end_tensor_inputs is not None and not all( k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs ): @@ -406,6 +413,7 @@ def __call__( input_images, height, width, + use_input_image_size_as_output, callback_on_step_end_tensor_inputs=callback_on_step_end_tensor_inputs, max_sequence_length=max_sequence_length, ) From 711ddeddd1779a222843d15e20302e89cbed7203 Mon Sep 17 00:00:00 2001 From: shitao <2906698981@qq.com> Date: Tue, 11 Feb 2025 10:46:42 +0800 Subject: [PATCH 53/55] make style --- .../transformers/transformer_omnigen.py | 23 ++++++++++++------- .../pipelines/consisid/pipeline_consisid.py | 11 ++++++--- .../pipelines/omnigen/pipeline_omnigen.py | 7 +++--- .../pipelines/omnigen/processor_omnigen.py | 2 +- .../test_models_transformer_omnigen.py | 10 ++++---- 5 files changed, 32 insertions(+), 21 deletions(-) diff --git a/src/diffusers/models/transformers/transformer_omnigen.py b/src/diffusers/models/transformers/transformer_omnigen.py index 3a02a21dcf11..0774a3f2a6ee 100644 --- a/src/diffusers/models/transformers/transformer_omnigen.py +++ b/src/diffusers/models/transformers/transformer_omnigen.py @@ -16,9 +16,9 @@ from typing import Any, Dict, List, Optional, Tuple, Union import torch +import torch.nn.functional as F import torch.utils.checkpoint from torch import nn -import torch.nn.functional as F from ...configuration_utils import ConfigMixin, register_to_config from ...loaders import PeftAdapterMixin @@ -91,7 +91,11 @@ def __init__( self.pos_embed_max_size = pos_embed_max_size pos_embed = get_2d_sincos_pos_embed( - embed_dim, self.pos_embed_max_size, base_size=base_size, interpolation_scale=self.interpolation_scale, output_type="pt" + embed_dim, + self.pos_embed_max_size, + base_size=base_size, + interpolation_scale=self.interpolation_scale, + output_type="pt", ) self.register_buffer("pos_embed", pos_embed.float().unsqueeze(0), persistent=True) @@ -227,7 +231,7 @@ def apply_rotary_emb( Returns: Tuple[torch.Tensor, torch.Tensor]: Tuple of modified query tensor and key tensor with rotary embeddings. 
""" - + cos, sin = freqs_cis # [S, D] if len(cos.shape) == 2: cos = cos[None, None] @@ -241,10 +245,10 @@ def apply_rotary_emb( x1 = x[..., : x.shape[-1] // 2] x2 = x[..., x.shape[-1] // 2 :] x_rotated = torch.cat((-x2, x1), dim=-1) - + out = (x.float() * cos + x_rotated.float() * sin).to(x.dtype) return out - + class OmniGenAttnProcessor2_0: r""" @@ -264,7 +268,6 @@ def __call__( attention_mask: Optional[torch.Tensor] = None, image_rotary_emb: Optional[torch.Tensor] = None, ) -> torch.Tensor: - batch_size, sequence_length, _ = hidden_states.shape # Get Query-Key-Value Pair @@ -674,9 +677,13 @@ def forward( image_rotary_emb = self.rotary_emb(hidden_states, position_ids) for decoder_layer in self.layers: if torch.is_grad_enabled() and self.gradient_checkpointing: - hidden_states = self._gradient_checkpointing_func(decoder_layer, hidden_states, attention_mask, image_rotary_emb) + hidden_states = self._gradient_checkpointing_func( + decoder_layer, hidden_states, attention_mask, image_rotary_emb + ) else: - hidden_states = decoder_layer(hidden_states, attention_mask=attention_mask, image_rotary_emb=image_rotary_emb) + hidden_states = decoder_layer( + hidden_states, attention_mask=attention_mask, image_rotary_emb=image_rotary_emb + ) hidden_states = self.norm(hidden_states) diff --git a/src/diffusers/pipelines/consisid/pipeline_consisid.py b/src/diffusers/pipelines/consisid/pipeline_consisid.py index 0d4891cf17d7..1a99c2a0e9ee 100644 --- a/src/diffusers/pipelines/consisid/pipeline_consisid.py +++ b/src/diffusers/pipelines/consisid/pipeline_consisid.py @@ -48,9 +48,14 @@ >>> from huggingface_hub import snapshot_download >>> snapshot_download(repo_id="BestWishYsh/ConsisID-preview", local_dir="BestWishYsh/ConsisID-preview") - >>> face_helper_1, face_helper_2, face_clip_model, face_main_model, eva_transform_mean, eva_transform_std = ( - ... prepare_face_models("BestWishYsh/ConsisID-preview", device="cuda", dtype=torch.bfloat16) - ... ) + >>> ( + ... face_helper_1, + ... face_helper_2, + ... face_clip_model, + ... face_main_model, + ... eva_transform_mean, + ... eva_transform_std, + ... ) = prepare_face_models("BestWishYsh/ConsisID-preview", device="cuda", dtype=torch.bfloat16) >>> pipe = ConsisIDPipeline.from_pretrained("BestWishYsh/ConsisID-preview", torch_dtype=torch.bfloat16) >>> pipe.to("cuda") diff --git a/src/diffusers/pipelines/omnigen/pipeline_omnigen.py b/src/diffusers/pipelines/omnigen/pipeline_omnigen.py index 1862777e7af5..faee7eaba691 100644 --- a/src/diffusers/pipelines/omnigen/pipeline_omnigen.py +++ b/src/diffusers/pipelines/omnigen/pipeline_omnigen.py @@ -16,7 +16,6 @@ from typing import Any, Callable, Dict, List, Optional, Union import numpy as np -import PIL import torch from transformers import LlamaTokenizer @@ -223,9 +222,9 @@ def check_inputs( if use_input_image_size_as_output: if input_images is None or input_images[0] is None: raise ValueError( - f"`use_input_image_size_as_output` is set to True, but no input image was found. If you are performing a text-to-image task, please set it to False." - ) - + "`use_input_image_size_as_output` is set to True, but no input image was found. If you are performing a text-to-image task, please set it to False." 
+ ) + if callback_on_step_end_tensor_inputs is not None and not all( k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs ): diff --git a/src/diffusers/pipelines/omnigen/processor_omnigen.py b/src/diffusers/pipelines/omnigen/processor_omnigen.py index 9faf25629d9b..ad107f662abc 100644 --- a/src/diffusers/pipelines/omnigen/processor_omnigen.py +++ b/src/diffusers/pipelines/omnigen/processor_omnigen.py @@ -62,7 +62,7 @@ def __init__(self, text_tokenizer, max_image_size: int = 1024): ) self.collator = OmniGenCollator() - + def reset_max_image_size(self, max_image_size): self.max_image_size = max_image_size self.image_transform = transforms.Compose( diff --git a/tests/models/transformers/test_models_transformer_omnigen.py b/tests/models/transformers/test_models_transformer_omnigen.py index cd77b2289074..a7653f1f9d6d 100644 --- a/tests/models/transformers/test_models_transformer_omnigen.py +++ b/tests/models/transformers/test_models_transformer_omnigen.py @@ -25,6 +25,7 @@ enable_full_determinism() + class OmniGenTransformerTests(ModelTesterMixin, unittest.TestCase): model_class = OmniGenTransformer2DModel main_input_name = "hidden_states" @@ -42,11 +43,11 @@ def dummy_input(self): timestep = torch.rand(size=(batch_size,), dtype=hidden_states.dtype).to(torch_device) input_ids = torch.randint(0, 10, (batch_size, sequence_length)).to(torch_device) input_img_latents = [torch.randn((1, num_channels, height, width)).to(torch_device)] - input_image_sizes = {0: [[0, 0+height*width//2//2]]} + input_image_sizes = {0: [[0, 0 + height * width // 2 // 2]]} - attn_seq_length = sequence_length + 1 + height*width//2//2 + attn_seq_length = sequence_length + 1 + height * width // 2 // 2 attention_mask = torch.ones((batch_size, attn_seq_length, attn_seq_length)).to(torch_device) - position_ids = torch.LongTensor([list(range(attn_seq_length))]*batch_size).to(torch_device) + position_ids = torch.LongTensor([list(range(attn_seq_length))] * batch_size).to(torch_device) return { "hidden_states": hidden_states, @@ -77,7 +78,7 @@ def prepare_init_args_and_inputs_for_common(self): "vocab_size": 100, "in_channels": 4, "time_step_dim": 4, - "rope_scaling": {"long_factor": list(range(1, 3)), "short_factor": list(range(1, 3))} + "rope_scaling": {"long_factor": list(range(1, 3)), "short_factor": list(range(1, 3))}, } inputs_dict = self.dummy_input return init_dict, inputs_dict @@ -85,4 +86,3 @@ def prepare_init_args_and_inputs_for_common(self): def test_gradient_checkpointing_is_applied(self): expected_set = {"OmniGenTransformer2DModel"} super().test_gradient_checkpointing_is_applied(expected_set=expected_set) - From 565e51c1b777f8ac2ca441ba11d84596b5ca3c12 Mon Sep 17 00:00:00 2001 From: staoxiao <2906698981@qq.com> Date: Tue, 11 Feb 2025 20:54:04 +0800 Subject: [PATCH 54/55] update testpipeline --- .../pipelines/omnigen/pipeline_omnigen.py | 12 +- .../pipelines/omnigen/processor_omnigen.py | 22 ++-- src/diffusers/utils/testing_utils.py | 10 +- .../omnigen/test_pipeline_omnigen.py | 114 ++---------------- 4 files changed, 36 insertions(+), 122 deletions(-) diff --git a/src/diffusers/pipelines/omnigen/pipeline_omnigen.py b/src/diffusers/pipelines/omnigen/pipeline_omnigen.py index faee7eaba691..41bfab5e3e04 100644 --- a/src/diffusers/pipelines/omnigen/pipeline_omnigen.py +++ b/src/diffusers/pipelines/omnigen/pipeline_omnigen.py @@ -139,7 +139,7 @@ class OmniGenPipeline( model_cpu_offload_seq = "transformer->vae" _optional_components = [] - _callback_tensor_inputs = ["latents", "input_images_latents"] + 
_callback_tensor_inputs = ["latents"] def __init__( self, @@ -435,6 +435,7 @@ def __call__( width=width, use_img_cfg=use_img_cfg, use_input_image_size_as_output=use_input_image_size_as_output, + num_images_per_prompt=num_images_per_prompt, ) processed_data["input_ids"] = processed_data["input_ids"].to(device) processed_data["attention_mask"] = processed_data["attention_mask"].to(device) @@ -448,6 +449,7 @@ def __call__( timesteps, num_inference_steps = retrieve_timesteps( self.scheduler, num_inference_steps, device, timesteps, sigmas=sigmas ) + self._num_timesteps = len(timesteps) # 6. Prepare latents. if use_input_image_size_as_output: @@ -496,6 +498,14 @@ def __call__( latents_dtype = latents.dtype latents = self.scheduler.step(noise_pred, t, latents, return_dict=False)[0] + if callback_on_step_end is not None: + callback_kwargs = {} + for k in callback_on_step_end_tensor_inputs: + callback_kwargs[k] = locals()[k] + callback_outputs = callback_on_step_end(self, i, t, callback_kwargs) + + latents = callback_outputs.pop("latents", latents) + if latents.dtype != latents_dtype: if torch.backends.mps.is_available(): # some platforms (eg. apple mps) misbehave due to a pytorch bug: https://github.com/pytorch/pytorch/pull/99272 diff --git a/src/diffusers/pipelines/omnigen/processor_omnigen.py b/src/diffusers/pipelines/omnigen/processor_omnigen.py index ad107f662abc..75d272ac5140 100644 --- a/src/diffusers/pipelines/omnigen/processor_omnigen.py +++ b/src/diffusers/pipelines/omnigen/processor_omnigen.py @@ -135,6 +135,7 @@ def __call__( use_img_cfg: bool = True, separate_cfg_input: bool = False, use_input_image_size_as_output: bool = False, + num_images_per_prompt: int = 1, ) -> Dict: if isinstance(instructions, str): instructions = [instructions] @@ -161,17 +162,18 @@ def __call__( else: img_cfg_mllm_input = neg_mllm_input - if use_input_image_size_as_output: - input_data.append( - ( - mllm_input, - neg_mllm_input, - img_cfg_mllm_input, - [mllm_input["pixel_values"][0].size(-2), mllm_input["pixel_values"][0].size(-1)], + for _ in range(num_images_per_prompt): + if use_input_image_size_as_output: + input_data.append( + ( + mllm_input, + neg_mllm_input, + img_cfg_mllm_input, + [mllm_input["pixel_values"][0].size(-2), mllm_input["pixel_values"][0].size(-1)], + ) ) - ) - else: - input_data.append((mllm_input, neg_mllm_input, img_cfg_mllm_input, [height, width])) + else: + input_data.append((mllm_input, neg_mllm_input, img_cfg_mllm_input, [height, width])) return self.collator(input_data) diff --git a/src/diffusers/utils/testing_utils.py b/src/diffusers/utils/testing_utils.py index 7eda13716025..223fa19f656e 100644 --- a/src/diffusers/utils/testing_utils.py +++ b/src/diffusers/utils/testing_utils.py @@ -1077,28 +1077,28 @@ def _is_torch_fp64_available(device): # Function definitions BACKEND_EMPTY_CACHE = { "cuda": torch.cuda.empty_cache, - "xpu": torch.xpu.empty_cache, + # "xpu": torch.xpu.empty_cache, "cpu": None, "mps": torch.mps.empty_cache, "default": None, } BACKEND_DEVICE_COUNT = { "cuda": torch.cuda.device_count, - "xpu": torch.xpu.device_count, + # "xpu": torch.xpu.device_count, "cpu": lambda: 0, "mps": lambda: 0, "default": 0, } BACKEND_MANUAL_SEED = { "cuda": torch.cuda.manual_seed, - "xpu": torch.xpu.manual_seed, + # "xpu": torch.xpu.manual_seed, "cpu": torch.manual_seed, "mps": torch.mps.manual_seed, "default": torch.manual_seed, } BACKEND_RESET_PEAK_MEMORY_STATS = { "cuda": torch.cuda.reset_peak_memory_stats, - "xpu": getattr(torch.xpu, "reset_peak_memory_stats", None), + # "xpu": 
getattr(torch.xpu, "reset_peak_memory_stats", None), "cpu": None, "mps": None, "default": None, @@ -1112,7 +1112,7 @@ def _is_torch_fp64_available(device): } BACKEND_MAX_MEMORY_ALLOCATED = { "cuda": torch.cuda.max_memory_allocated, - "xpu": getattr(torch.xpu, "max_memory_allocated", None), + # "xpu": getattr(torch.xpu, "max_memory_allocated", None), "cpu": 0, "mps": 0, "default": 0, diff --git a/tests/pipelines/omnigen/test_pipeline_omnigen.py b/tests/pipelines/omnigen/test_pipeline_omnigen.py index 70e9cd0a53b2..dd5e5fcb2918 100644 --- a/tests/pipelines/omnigen/test_pipeline_omnigen.py +++ b/tests/pipelines/omnigen/test_pipeline_omnigen.py @@ -34,112 +34,14 @@ def get_dummy_components(self): torch.manual_seed(0) transformer = OmniGenTransformer2DModel( - rope_scaling={ - "long_factor": [ - 1.0299999713897705, - 1.0499999523162842, - 1.0499999523162842, - 1.0799999237060547, - 1.2299998998641968, - 1.2299998998641968, - 1.2999999523162842, - 1.4499999284744263, - 1.5999999046325684, - 1.6499998569488525, - 1.8999998569488525, - 2.859999895095825, - 3.68999981880188, - 5.419999599456787, - 5.489999771118164, - 5.489999771118164, - 9.09000015258789, - 11.579999923706055, - 15.65999984741211, - 15.769999504089355, - 15.789999961853027, - 18.360000610351562, - 21.989999771118164, - 23.079999923706055, - 30.009998321533203, - 32.35000228881836, - 32.590003967285156, - 35.56000518798828, - 39.95000457763672, - 53.840003967285156, - 56.20000457763672, - 57.95000457763672, - 59.29000473022461, - 59.77000427246094, - 59.920005798339844, - 61.190006256103516, - 61.96000671386719, - 62.50000762939453, - 63.3700065612793, - 63.48000717163086, - 63.48000717163086, - 63.66000747680664, - 63.850006103515625, - 64.08000946044922, - 64.760009765625, - 64.80001068115234, - 64.81001281738281, - 64.81001281738281, - ], - "short_factor": [ - 1.05, - 1.05, - 1.05, - 1.1, - 1.1, - 1.1, - 1.2500000000000002, - 1.2500000000000002, - 1.4000000000000004, - 1.4500000000000004, - 1.5500000000000005, - 1.8500000000000008, - 1.9000000000000008, - 2.000000000000001, - 2.000000000000001, - 2.000000000000001, - 2.000000000000001, - 2.000000000000001, - 2.000000000000001, - 2.000000000000001, - 2.000000000000001, - 2.000000000000001, - 2.000000000000001, - 2.000000000000001, - 2.000000000000001, - 2.000000000000001, - 2.000000000000001, - 2.000000000000001, - 2.000000000000001, - 2.000000000000001, - 2.000000000000001, - 2.000000000000001, - 2.1000000000000005, - 2.1000000000000005, - 2.2, - 2.3499999999999996, - 2.3499999999999996, - 2.3499999999999996, - 2.3499999999999996, - 2.3999999999999995, - 2.3999999999999995, - 2.6499999999999986, - 2.6999999999999984, - 2.8999999999999977, - 2.9499999999999975, - 3.049999999999997, - 3.049999999999997, - 3.049999999999997, - ], - "type": "su", - }, - patch_size=2, + hidden_size=16, + num_attention_heads=4, + num_key_value_heads=4, + intermediate_size=32, + num_layers=1, in_channels=4, - pos_embed_max_size=192, + time_step_dim=4, + rope_scaling={"long_factor": list(range(1, 3)), "short_factor": list(range(1, 3))}, ) torch.manual_seed(0) @@ -174,7 +76,7 @@ def get_dummy_inputs(self, device, seed=0): inputs = { "prompt": "A painting of a squirrel eating a burger", "generator": generator, - "num_inference_steps": 2, + "num_inference_steps": 1, "guidance_scale": 3.0, "output_type": "np", "height": 16, From 29ad6ae64905704e1cbff218d30e71c964c0daa8 Mon Sep 17 00:00:00 2001 From: staoxiao <2906698981@qq.com> Date: Tue, 11 Feb 2025 20:55:49 +0800 Subject: [PATCH 55/55] update testpipeline 
--- src/diffusers/utils/testing_utils.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/src/diffusers/utils/testing_utils.py b/src/diffusers/utils/testing_utils.py index 223fa19f656e..7eda13716025 100644 --- a/src/diffusers/utils/testing_utils.py +++ b/src/diffusers/utils/testing_utils.py @@ -1077,28 +1077,28 @@ def _is_torch_fp64_available(device): # Function definitions BACKEND_EMPTY_CACHE = { "cuda": torch.cuda.empty_cache, - # "xpu": torch.xpu.empty_cache, + "xpu": torch.xpu.empty_cache, "cpu": None, "mps": torch.mps.empty_cache, "default": None, } BACKEND_DEVICE_COUNT = { "cuda": torch.cuda.device_count, - # "xpu": torch.xpu.device_count, + "xpu": torch.xpu.device_count, "cpu": lambda: 0, "mps": lambda: 0, "default": 0, } BACKEND_MANUAL_SEED = { "cuda": torch.cuda.manual_seed, - # "xpu": torch.xpu.manual_seed, + "xpu": torch.xpu.manual_seed, "cpu": torch.manual_seed, "mps": torch.mps.manual_seed, "default": torch.manual_seed, } BACKEND_RESET_PEAK_MEMORY_STATS = { "cuda": torch.cuda.reset_peak_memory_stats, - # "xpu": getattr(torch.xpu, "reset_peak_memory_stats", None), + "xpu": getattr(torch.xpu, "reset_peak_memory_stats", None), "cpu": None, "mps": None, "default": None, @@ -1112,7 +1112,7 @@ def _is_torch_fp64_available(device): } BACKEND_MAX_MEMORY_ALLOCATED = { "cuda": torch.cuda.max_memory_allocated, - # "xpu": getattr(torch.xpu, "max_memory_allocated", None), + "xpu": getattr(torch.xpu, "max_memory_allocated", None), "cpu": 0, "mps": 0, "default": 0,
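A usage note on the step-end callback wired up in PATCH 54: the pipeline calls `callback_on_step_end(self, i, t, callback_kwargs)` with a dict built from `callback_on_step_end_tensor_inputs` (only `"latents"` is registered in `_callback_tensor_inputs`) and then pops `latents` back out of the returned value, so the callback must return a mapping. A minimal sketch, assuming a `pipe` set up as in the earlier sketch after PATCH 51; the callback name and print-based logging are illustrative only:

>>> def log_latent_norm(pipe, step_index, timestep, callback_kwargs):
...     # `callback_kwargs` holds the tensors requested via `callback_on_step_end_tensor_inputs`
...     latents = callback_kwargs["latents"]
...     print(f"step {step_index}, t={timestep}: latent norm = {latents.float().norm().item():.3f}")
...     # the pipeline does `callback_outputs.pop("latents", latents)`, so return the dict
...     return callback_kwargs

>>> image = pipe(
...     prompt="A photo of a corgi wearing sunglasses",
...     num_inference_steps=50,
...     callback_on_step_end=log_latent_norm,
...     callback_on_step_end_tensor_inputs=["latents"],
... ).images[0]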