
The Modular Diffusers #9672 (Open)

yiyixuxu wants to merge 68 commits into main
Conversation

@yiyixuxu (Collaborator) commented Oct 14, 2024

Getting Started with Modular Diffusers

With Modular Diffusers, we introduce a unified pipeline system that simplifies how you work with diffusion models. Instead of creating separate pipelines for each task, Modular Diffusers lets you:

Write Only What's New: You won't need to rewrite the entire pipeline from scratch. You can create pipeline blocks just for your new workflow's unique aspects and reuse existing blocks for the rest.

Assemble Like LEGO®: You can mix and match blocks in flexible ways. This allows you to write dedicated blocks for specific workflows, and then assemble different blocks into a pipeline that can conveniently serve multiple workflows. Here we will walk you through how to use one such pipeline we built with Modular Diffusers! In later sections, we will also go over how to assemble and build new pipelines!

Quick Start with StableDiffusionXLAutoPipeline

import torch

from diffusers import ModularPipeline, StableDiffusionXLAutoPipeline
from diffusers.pipelines.components_manager import ComponentsManager

# Load models
components = ComponentsManager()
components.add_from_pretrained("stabilityai/stable-diffusion-xl-base-1.0", torch_dtype=torch.float16)
components.enable_auto_cpu_offload(device="cuda:0")

# Create pipeline
auto_pipe = ModularPipeline.from_block(StableDiffusionXLAutoPipeline())
auto_pipe.update_states(**components.components)

Auto Workflow Selection

The pipeline automatically adapts to your inputs (see the sketch after this list):

  • Basic text-to-image: Just provide a prompt
  • Image-to-image: Add an image input
  • Inpainting: Add both image and mask_image
  • ControlNet: Add a control_image
  • And more!
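
For example, the same pipeline object handles different tasks just by changing the call inputs. Here is a minimal sketch (assuming the `auto_pipe` built above; `init_image` and `mask` are placeholder PIL images you load yourself):

# `init_image` and `mask` below are placeholders: load your own images here
# text-to-image: only a prompt is needed
images = auto_pipe(prompt="an astronaut", output="images").images

# image-to-image: adding `image` triggers the img2img workflow
images = auto_pipe(prompt="an astronaut", image=init_image, strength=0.8, output="images").images

# inpainting: adding both `image` and `mask_image` triggers the inpainting workflow
images = auto_pipe(prompt="an astronaut", image=init_image, mask_image=mask, output="images").images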

Auto Documentations

We care a great deal about documentation here at Diffusers, and Modular Diffusers carries this mission forward. All our pipeline blocks come with complete docstrings that automatically compose as you build your pipelines. This means:

  • Every pipeline you build with Modular Diffusers comes with complete documentation automatically
  • Input/output signatures are dynamically generated, and the same goes for components and configurations
  • Parameter descriptions and types are included
  • Block relationships and dependencies are documented as well

Inspect Your Pipeline

# get pipeline info: components, configurations, pipeline blocks, and the docstring
print(auto_pipe)
see an example of output
ModularPipeline:
==============================

Pipeline Block:
--------------
StableDiffusionXLAutoPipeline
 (Class: SequentialPipelineBlocks)
  • text_encoder (StableDiffusionXLTextEncoderStep)
  • ip_adapter (StableDiffusionXLAutoIPAdapterStep)
  • image_encoder (StableDiffusionXLAutoVaeEncoderStep)
  • before_denoise (StableDiffusionXLAutoBeforeDenoiseStep)
  • denoise (StableDiffusionXLAutoDenoiseStep)
  • decode (StableDiffusionXLAutoDecodeStep)

Registered Components:
----------------------
text_encoder: CLIPTextModel (dtype=torch.float16, device=cpu)
text_encoder_2: CLIPTextModelWithProjection (dtype=torch.float16, device=cpu)
tokenizer: CLIPTokenizer
tokenizer_2: CLIPTokenizer
image_encoder: CLIPVisionModelWithProjection (dtype=torch.float16, device=cpu)
feature_extractor: CLIPImageProcessor
unet: UNet2DConditionModel (dtype=torch.float16, device=cpu)
vae: AutoencoderKL (dtype=torch.float16, device=cpu)
scheduler: EulerDiscreteScheduler
controlnet: ControlNetModel (dtype=torch.float16, device=cpu)
guider: CFGGuider
controlnet_guider: CFGGuider

Registered Configs:
------------------
force_zeros_for_empty_prompt: True
requires_aesthetics_score: False

------------------
This pipeline contains blocks that are selected at runtime based on inputs.

Trigger Inputs: {'control_image', 'control_mode', 'image_latents', 'padding_mask_crop', 'mask_image', 'ip_adapter_image', 'image', 'mask'}
  Use `get_execution_blocks()` with input names to see selected blocks (e.g. `get_execution_blocks('control_image')`).
Check `.doc` of returned object for more information.

  Auto Modular pipeline for text-to-image, image-to-image, inpainting, and controlnet tasks using Stable Diffusion XL.
  - for image-to-image generation, you need to provide either `image` or `image_latents`
  - for inpainting, you need to provide `mask_image` and `image`, optionally you can provide `padding_mask_crop` 
  - to run the controlnet workflow, you need to provide `control_image`
  - to run the controlnet_union workflow, you need to provide `control_image` and `control_mode`
  - to run the ip_adapter workflow, you need to provide `ip_adapter_image`
  - for text-to-image generation, all you need to provide is `prompt`

  Args:

      prompt (`Union[str, List]`, *optional*):
          The prompt or prompts to guide the image generation.

      prompt_2 (`Union[str, List]`, *optional*):
          The prompt or prompts to be sent to the `tokenizer_2` and `text_encoder_2`. If not defined, `prompt` is used in
          both text-encoders

      negative_prompt (`Union[str, List]`, *optional*):
          The prompt or prompts not to guide the image generation. Ignored when not using guidance (i.e., ignored if
          `guidance_scale` is less than `1`).

      negative_prompt_2 (`Union[str, List]`, *optional*):
          The prompt or prompts not to guide the image generation to be sent to `tokenizer_2` and `text_encoder_2`. If not
          defined, `negative_prompt` is used in both text-encoders

      cross_attention_kwargs (`Union[dict, NoneType]`, *optional*):
          A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under `self.processor`
          in [diffusers.models.attention_processor]

      guidance_scale (`float`, *optional*, defaults to 5.0):
          Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
          `guidance_scale` is defined as `w` of equation 2. of [Imagen Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance
          scale is enabled by setting `guidance_scale > 1`. Higher guidance scale encourages to generate images that are
          closely linked to the text `prompt`, usually at the expense of lower image quality.

      clip_skip (`Union[int, NoneType]`, *optional*):

      ip_adapter_image (`Union[Image, ndarray, Tensor, List, List, List]`):
          The image(s) to be used as ip adapter

      height (`Union[int, NoneType]`, *optional*):
          The height in pixels of the generated image. This is set to 1024 by default for the best results. Anything below
          512 pixels won't work well for
          [stabilityai/stable-diffusion-xl-base-1.0](https://huggingface.co/stabilityai/stable-diffusion-xl-base-1.0) and
          checkpoints that are not specifically fine-tuned on low resolutions.

      width (`Union[int, NoneType]`, *optional*):
          The width in pixels of the generated image. This is set to 1024 by default for the best results. Anything below 512
          pixels won't work well for
          [stabilityai/stable-diffusion-xl-base-1.0](https://huggingface.co/stabilityai/stable-diffusion-xl-base-1.0) and
          checkpoints that are not specifically fine-tuned on low resolutions.

      generator (`Union[Generator, List, NoneType]`, *optional*):
          One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make
          generation deterministic.

      image (`Union[Image, ndarray, Tensor, List, List, List]`):
          The image(s) to modify with the pipeline, for img2img or inpainting task. When using for inpainting task, parts of
          the image will be masked out with `mask_image` and repainted according to `prompt`.

      mask_image (`Union[Image, ndarray, Tensor, List, List, List]`, *optional*):
          `Image`, or tensor representing an image batch, to mask `image`. White pixels in the mask will be repainted, while
          black pixels will be preserved. If `mask_image` is a PIL image, it will be converted to a single channel
          (luminance) before use. If it's a tensor, it should contain one color channel (L) instead of 3, so the expected
          shape would be `(B, H, W, 1)`.

      padding_mask_crop (`Union[Tuple, NoneType]`, *optional*):
          The size of margin in the crop to be applied to the image and masking. If `None`, no crop is applied to image and
          mask_image. If `padding_mask_crop` is not `None`, it will first find a rectangular region with the same aspect
          ratio of the image and contains all masked area, and then expand that area based on `padding_mask_crop`. The image
          and mask_image will then be cropped based on the expanded area before resizing to the original image size for
          inpainting. This is useful when the masked area is small while the image is large and contain information
          irrelevant for inpainting, such as background.

      num_images_per_prompt (`int`, *optional*, defaults to 1):
          The number of images to generate per prompt.

      num_inference_steps (`int`, *optional*, defaults to 50):
          The number of denoising steps. More denoising steps usually lead to a higher quality image at the expense of slower
          inference.

      timesteps (`Union[Tensor, NoneType]`, *optional*):
          Custom timesteps to use for the denoising process with schedulers which support a `timesteps` argument in their
          `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is passed will be used.
          Must be in descending order.

      sigmas (`Union[Tensor, NoneType]`, *optional*):
          Custom sigmas to use for the denoising process with schedulers which support a `sigmas` argument in their
          `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is passed will be used.

      denoising_end (`Union[float, NoneType]`, *optional*):
          When specified, determines the fraction (between 0.0 and 1.0) of the total denoising process to be completed before
          it is intentionally prematurely terminated. As a result, the returned sample will still retain a substantial amount
          of noise as determined by the discrete timesteps selected by the scheduler. The denoising_end parameter should
          ideally be utilized when this pipeline forms a part of a 'Mixture of Denoisers' multi-pipeline setup.

      strength (`float`, *optional*, defaults to 0.3):
          Conceptually, indicates how much to transform the reference `image` (the masked portion of image for inpainting).
          Must be between 0 and 1. `image` will be used as a starting point, adding more noise to it the larger the
          `strength`. The number of denoising steps depends on the amount of noise initially added. When `strength` is 1,
          added noise will be maximum and the denoising process will run for the full number of iterations specified in
          `num_inference_steps`. A value of 1, therefore, essentially ignores `image`. Note that in the case of
          `denoising_start` being declared as an integer, the value of `strength` will be ignored.

      denoising_start (`Union[float, NoneType]`, *optional*):
          The denoising start value to use for the scheduler. Determines the starting point of the denoising process.

      latents (`Union[Tensor, NoneType]`, *optional*):
          Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image generation. Can
          be used to tweak the same generation with different prompts. If not provided, a latents tensor will be generated by
          sampling using the supplied random `generator`.

      original_size (`Tuple`, *optional*, defaults to (1024, 1024)):
          The original size (height, width) of the image that conditions the generation process. If different from
          target_size, the image will appear to be down- or upsampled. Part of SDXL's micro-conditioning as explained in
          section 2.2 of https://huggingface.co/papers/2307.01952

      target_size (`Tuple`, *optional*, defaults to (1024, 1024)):
          The target size (height, width) of the generated image. For most cases, this should be set to the desired output
          dimensions. Part of SDXL's micro-conditioning as explained in section 2.2 of
          https://huggingface.co/papers/2307.01952

      negative_original_size (`Tuple`, *optional*, defaults to (1024, 1024)):
          The negative original size to condition against during generation. Part of SDXL's micro-conditioning as explained
          in section 2.2 of https://huggingface.co/papers/2307.01952. See:
          https://github.com/huggingface/diffusers/issues/4208

      negative_target_size (`Tuple`, *optional*, defaults to (1024, 1024)):
          The negative target size to condition against during generation. Should typically match target_size. Part of SDXL's
          micro-conditioning as explained in section 2.2 of https://huggingface.co/papers/2307.01952. See:
          https://github.com/huggingface/diffusers/issues/4208

      crops_coords_top_left (`Tuple`, *optional*, defaults to (0, 0)):
          `crops_coords_top_left` can be used to generate an image that appears to be "cropped" from the position
          `crops_coords_top_left` to (0, 0). Part of SDXL's micro-conditioning

      negative_crops_coords_top_left (`Tuple`, *optional*, defaults to (0, 0)):
          To negatively condition the generation process based on a specific crop coordinates. Part of SDXL's
          micro-conditioning

      aesthetic_score (`float`, *optional*, defaults to 6.0):
          Used to simulate an aesthetic score of the generated image by influencing the positive text condition. Part of
          SDXL's micro-conditioning as explained in section 2.2 of https://huggingface.co/papers/2307.01952

      negative_aesthetic_score (`float`, *optional*, defaults to 2.0):
          Part of SDXL's micro-conditioning as explained in section 2.2 of https://huggingface.co/papers/2307.01952. Can be
          used to simulate an aesthetic score of the generated image by influencing the negative text condition.

      control_image (`Union[Image, ndarray, Tensor, List, List, List]`, *optional*):
          The ControlNet input condition to provide guidance to the unet for generation. If passed as torch.Tensor, it is
          used as-is. PIL.Image.Image inputs are accepted and default to image dimensions. For multiple ControlNets, pass
          images as a list for proper batching.

      control_guidance_start (`Union[float, List]`, *optional*, defaults to 0.0):
          The percentage of total steps at which the ControlNet starts applying.

      control_guidance_end (`Union[float, List]`, *optional*, defaults to 1.0):
          The percentage of total steps at which the ControlNet stops applying.

      control_mode (`List`, *optional*):
          The control mode for union controlnet, 0 for openpose, 1 for depth, 2 for hed/pidi/scribble/ted, 3 for
          canny/lineart/anime_lineart/mlsd, 4 for normal and 5 for segment

      controlnet_conditioning_scale (`Union[float, List]`, *optional*, defaults to 1.0):
          Scale factor for ControlNet outputs before adding to unet residual. For multiple ControlNets, can be set as a list
          of scales.

      guess_mode (`bool`, *optional*, defaults to False):
          Enables ControlNet encoder to recognize input image content without prompts. Recommended guidance_scale: 3.0-5.0.

      guidance_rescale (`float`, *optional*, defaults to 0.0):
          Guidance rescale factor (φ) to fix overexposure when using zero terminal SNR, as proposed in 'Common Diffusion
          Noise Schedules and Sample Steps are Flawed'.

      eta (`float`, *optional*, defaults to 0.0):
          Parameter η in the DDIM paper. Only applies to DDIMScheduler, ignored for others.

      guider_kwargs (`Union[Dict, NoneType]`, *optional*):
          Optional kwargs dictionary passed to the Guider.

      output_type (`str`, *optional*, defaults to pil):
          The output format of the generated image. Choose between PIL (PIL.Image.Image), torch.Tensor or np.array.

      return_dict (`bool`, *optional*, defaults to True):
          Whether or not to return a StableDiffusionXLPipelineOutput instead of a plain tuple.

      dtype (`dtype`, *optional*):
          The dtype of the model inputs

      preprocess_kwargs (`Union[dict, NoneType]`, *optional*):
          A kwargs dictionary that if specified is passed along to the `ImageProcessor` as defined under
          `self.image_processor` in [diffusers.image_processor.VaeImageProcessor]

      ip_adapter_embeds (`List`, *optional*):
          Pre-generated image embeddings for IP-Adapter. Can be generated from ip_adapter step.

      negative_ip_adapter_embeds (`List`, *optional*):
          Pre-generated negative image embeddings for IP-Adapter. Can be generated from ip_adapter step.

      image_latents (`Tensor`, *optional*):
          The latents representing the reference image for image-to-image/inpainting generation. Can be generated in
          vae_encode step.

      mask (`Tensor`, *optional*):
          The mask for the inpainting generation. Can be generated in vae_encode step.

      masked_image_latents (`Tensor`, *optional*):
          The masked image latents for the inpainting generation (only for inpainting-specific unet). Can be generated in
          vae_encode step.

      image_latents (`Union[Tensor, NoneType]`, *optional*):
          The image latents to use for the denoising process, for inpainting/image-to-image task only. Can be generated in
          vae_encode or prepare_latent step.

      crops_coords (`Union[Tuple, NoneType]`, *optional*):
          The crop coordinates to use for preprocess/postprocess the image and mask, for inpainting task only. Can be
          generated in vae_encode or prepare_latent step.

      crops_coords (`Tuple`, *optional*):
          The crop coordinates to use for preprocess/postprocess the image and mask, for inpainting task only. Can be
          generated in vae_encode step.

  Returns:

      images (`Union[List, List, List]`):
          The generated images, can be a PIL.Image.Image, torch.Tensor or a numpy array

Use `get_execution_blocks()` to see which blocks will run for your inputs/workflow. For example, if you want to run a text-to-image ControlNet workflow, you can do this:

print(auto_pipe.get_execution_blocks("control_image"))

See the docstring relevant to your inputs/workflow:

print(auto_pipe.get_execution_blocks("control_image").doc)

Advanced Workflows

Once you've created the auto pipeline, you can use it for different features as long as you add the required components and pass the required inputs.

# Add ControlNet
auto_pipe.update_states(controlnet=controlnet)

# Enable IP-Adapter
auto_pipe.update_states(image_encoder=..., feature_extractor=...)
auto_pipe.load_ip_adapter("h94/IP-Adapter")

# Add LoRA
auto_pipe.load_lora_weights(...)

# at inference time, pass all the inputs required for your workflow
images = auto_pipe(
    prompt="..",
    control_image=pose_image,        # this triggers the ControlNet workflow
    ip_adapter_image=style_image,    # this triggers the IP-Adapter workflow
    ...
).images

Here is an example you can run for a more complex workflow combining ControlNet, IP-Adapter, LoRA, and PAG:

import torch

from diffusers import ControlNetModel
from transformers import CLIPVisionModelWithProjection, CLIPImageProcessor
from diffusers.utils import load_image
from diffusers.guider import PAGGuider

dtype = torch.float16

# load controlnet
controlnet = ControlNetModel.from_pretrained("thibaud/controlnet-openpose-sdxl-1.0", torch_dtype=dtype)
components.add("controlnet", controlnet)

# load image_encoder for ip adapter
image_encoder = CLIPVisionModelWithProjection.from_pretrained("h94/IP-Adapter", subfolder="sdxl_models/image_encoder", torch_dtype=dtype)
feature_extractor = CLIPImageProcessor(size=224, crop_size=224)

components.add("image_encoder", image_encoder)
components.add("feature_extractor", feature_extractor)

# load additional components into the pipeline
auto_pipe.update_states(**components.get(["controlnet", "image_encoder", "feature_extractor"]))

# load ip adapter
auto_pipe.load_ip_adapter("h94/IP-Adapter", subfolder="sdxl_models", weight_name="ip-adapter_sdxl.bin")
auto_pipe.set_ip_adapter_scale(0.6)

# let's also load a lora while we're at it
auto_pipe.load_lora_weights("CiroN2022/toy-face", weight_name="toy_face_sdxl.safetensors", adapter_name="toy_face")

# let's also throw PAG in there because why not!
pag_guider = PAGGuider(pag_applied_layers="mid")
controlnet_pag_guider = PAGGuider(pag_applied_layers="mid")
auto_pipe.update_states(guider=pag_guider, controlnet_guider=controlnet_pag_guider)

# prepare inputs
prompt = "an astronaut"
control_image = load_image("https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/blog/controlnet/person_pose.png")
ip_adapter_image = load_image("https://huggingface.co/datasets/YiYiXu/testing-images/resolve/main/style_ziggy/img5.png")

# Run pipeline with everything combined
images = auto_pipe(
    prompt=prompt,
    control_image=control_image,
    ip_adapter_image=ip_adapter_image,
    output="images"
).images
images[0]

(example output image: yiyi_modular_out)

check out more usage examples here

test1: complete testing script for `StableDiffusionXLAutoPipeline`
import os
import torch
import numpy as np
import cv2
from PIL import Image

from diffusers import (
    ControlNetModel,
    ModularPipeline,
    UNet2DConditionModel,
    AutoencoderKL,
    ControlNetUnionModel,
)
from diffusers.utils import load_image
from diffusers.guider import PAGGuider, CFGGuider, APGGuider
from diffusers.pipelines.components_manager import ComponentsManager
from diffusers.pipelines.stable_diffusion_xl.pipeline_stable_diffusion_xl_modular import StableDiffusionXLAutoPipeline, StableDiffusionXLIPAdapterStep

from transformers import CLIPVisionModelWithProjection, CLIPImageProcessor

from controlnet_aux import LineartAnimeDetector

import logging
logging.getLogger().setLevel(logging.INFO)
logging.getLogger("diffusers").setLevel(logging.INFO)


# define device and dtype
device = "cuda:0"
dtype = torch.float16
num_images_per_prompt = 1

test_pag = True
test_lora = False


# define output folder
out_folder = "modular_test_outputs_0131_auto_pipeline"
if os.path.exists(out_folder):
    # Remove all files in the directory
    for file in os.listdir(out_folder):
        file_path = os.path.join(out_folder, file)
        try:
            if os.path.isfile(file_path):
                os.unlink(file_path)
        except Exception as e:
            print(f"Error: {e}")
os.makedirs(out_folder, exist_ok=True)

# functions for memory info
def reset_memory():
    torch.cuda.empty_cache()
    torch.cuda.reset_peak_memory_stats()

def clear_memory():
    torch.cuda.empty_cache()

def print_mem(mem_size, name):
    mem_gb = mem_size / 1024**3
    mem_mb = mem_size / 1024**2
    print(f"- {name}: {mem_gb:.2f} GB ({mem_mb:.2f} MB)")

def print_memory(message=None):
    """
    Print detailed GPU memory statistics for a specific device.
    
    Args:
        device_id (int): GPU device ID
    """
    allocated_mem = torch.cuda.memory_allocated(device)
    reserved_mem = torch.cuda.memory_reserved(device)
    mem_on_device = torch.cuda.mem_get_info(device)[0]
    peak_mem = torch.cuda.max_memory_allocated(device)

    print(f"\nGPU:{device} Memory Status {message}:")
    print_mem(allocated_mem, "allocated memory")
    print_mem(reserved_mem, "reserved memory")
    print_mem(peak_mem, "peak memory")
    print_mem(mem_on_device, "mem on device")

# function to make canny image (for controlnet)
def make_canny(image):
    image = np.array(image)
    image = cv2.Canny(image, 100, 200)
    image = image[:, :, None]
    image = np.concatenate([image, image, image], axis=2)
    return Image.fromarray(image)


# (1)Define inputs
# for text2img/img2img
prompt = "a bear sitting in a chair drinking a milkshake"
negative_prompt = "deformed, ugly, wrong proportion, low res, bad anatomy, worst quality, low quality"

# for img2img
url = "https://huggingface.co/datasets/huggingface/documentation-images/resolve/0052a70beed5bf71b92610a43a52df6d286cd5f3/diffusers/rabbit.jpg"
init_image = load_image(url).convert("RGB")
strength = 0.9 

# for controlnet
control_image = make_canny(init_image)
controlnet_conditioning_scale = 0.5  # recommended for good generalization
# for controlnet_union
processor = LineartAnimeDetector.from_pretrained("lllyasviel/Annotators")
controlnet_union_image = processor(init_image, output_type="pil")

# for inpainting
img_url = "https://raw.githubusercontent.com/CompVis/latent-diffusion/main/data/inpainting_examples/overture-creations-5sI6fQgYIuo.png"
mask_url = "https://raw.githubusercontent.com/CompVis/latent-diffusion/main/data/inpainting_examples/overture-creations-5sI6fQgYIuo_mask.png"

inpaint_image = load_image(img_url).resize((1024, 1024))
inpaint_mask = load_image(mask_url).resize((1024, 1024))
inpaint_control_image = make_canny(inpaint_image)
inpaint_strength = 0.99

# for ip adapter
ip_adapter_image = load_image("https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/ip_adapter_diner.png")


# (2) define blocks and nodes(builder)      

auto_pipeline_block = StableDiffusionXLAutoPipeline()
auto_pipeline = ModularPipeline.from_block(auto_pipeline_block)
refiner_pipeline = ModularPipeline.from_block(auto_pipeline_block)



# (3) add states to nodes
repo = "stabilityai/stable-diffusion-xl-base-1.0"
refiner_repo = "stabilityai/stable-diffusion-xl-refiner-1.0"
controlnet_repo = "diffusers/controlnet-canny-sdxl-1.0"
inpaint_repo = "diffusers/stable-diffusion-xl-1.0-inpainting-0.1"
vae_fix_repo = "madebyollin/sdxl-vae-fp16-fix"
controlnet_union_repo = "brad-twinkl/controlnet-union-sdxl-1.0-promax"
ip_adapter_repo = "h94/IP-Adapter"


components = ComponentsManager()
components.add_from_pretrained(repo, torch_dtype=dtype)


controlnet = ControlNetModel.from_pretrained(controlnet_repo, torch_dtype=dtype)
components.add("controlnet", controlnet)

image_encoder = CLIPVisionModelWithProjection.from_pretrained(ip_adapter_repo, subfolder="sdxl_models/image_encoder", torch_dtype=dtype)
feature_extractor = CLIPImageProcessor(size=224, crop_size=224)

components.add("image_encoder", image_encoder)
components.add("feature_extractor", feature_extractor)


# load components/config into nodes
auto_pipeline.update_states(**components.components)


# load other components to swap in later
refiner_unet = UNet2DConditionModel.from_pretrained(refiner_repo, subfolder="unet", torch_dtype=dtype)
inpaint_unet = UNet2DConditionModel.from_pretrained(inpaint_repo, subfolder="unet", torch_dtype=dtype)
vae_fix = AutoencoderKL.from_pretrained(vae_fix_repo, torch_dtype=dtype)
controlnet_union = ControlNetUnionModel.from_pretrained(controlnet_union_repo, torch_dtype=dtype)

components.add("refiner_unet", refiner_unet)
components.add("inpaint_unet", inpaint_unet)
components.add("controlnet_union", controlnet_union)
components.add("vae_fix", vae_fix)


# you can add guiders to the manager too, but there is no need because they are not serialized
pag_guider = PAGGuider(pag_applied_layers="mid")
controlnet_pag_guider = PAGGuider(pag_applied_layers="mid")
cfg_guider = CFGGuider()
controlnet_cfg_guider = CFGGuider()


# (4) enable auto cpu offload: automatically offload models when available gpu memory goes below a certain threshold
components.enable_auto_cpu_offload(device=device)
print(components)
reset_memory()



# using auto_pipeline to generate images

# to get info about auto_pipeline and how to use it: inputs/outputs/components
# this is an "auto" workflow that works for all use cases: text2img, img2img, inpainting, controlnet, etc.
print(f" ")
print(f" auto_pipeline:")
print(auto_pipeline)
print(" ")


# since we want to use the text2img use case, we can run the following to see components/blocks/inputs for this use case
print(f" ")
print(f" auto_pipeline info (default use case: text2img)")
print(auto_pipeline.get_execution_blocks())
print(" ")

# test1: text2img use case
# when you run the auto workflow, you will get these logs telling you which blocks are actually running
# (should match what the auto_pipeline told you)
# Running block: StableDiffusionXLBeforeDenoiseStep, trigger: None
# Running block: StableDiffusionXLDenoiseStep, trigger: None
# Running block: StableDiffusionXLDecodeStep, trigger: None

generator = torch.Generator(device="cuda").manual_seed(0)
images_output = auto_pipeline(
    prompt=prompt, 
    num_images_per_prompt=num_images_per_prompt,
    generator=generator, 
    output="images"
)
for i, image in enumerate(images_output.images):
    image.save(f"{out_folder}/test1_out_text2img_{i}.png")
print(f" save modular output ({len(images_output.images)} images) to {out_folder}/test1_out_text2img.png")

clear_memory()


# test2: text2img with lora use case
print(f" ")
print(f" running test2: text2img with lora use case")
generator = torch.Generator(device="cuda").manual_seed(0)
auto_pipeline.load_lora_weights("rajkumaralma/dissolve_dust_style", weight_name="ral-dissolve-sdxl.safetensors", adapter_name="ral-dissolve")
images_output = auto_pipeline(
    prompt=prompt, 
    num_images_per_prompt=num_images_per_prompt,
    generator=generator, 
    output="images"
)
for i, image in enumerate(images_output.images):
    image.save(f"{out_folder}/test2_out_text2img_lora_{i}.png")
print(f" save modular output ({len(images_output.images)} images) to {out_folder}/test2_out_text2img_lora.png")


# test3:text2image with pag
print(f" ")
print(f" running test3:text2image with pag")
if not test_lora:
    auto_pipeline.unload_lora_weights()

auto_pipeline.update_states(guider=pag_guider, controlnet_guider=controlnet_pag_guider)
generator = torch.Generator(device="cuda").manual_seed(0)
images_output = auto_pipeline(
    prompt=prompt, 
    num_images_per_prompt=num_images_per_prompt,
    generator=generator, 
    guider_kwargs={"pag_scale": 3.0},
    output="images"
)

for i, image in enumerate(images_output.images):
    image.save(f"{out_folder}/test3_out_text2img_pag_{i}.png")
print(f" save modular output ({len(images_output.images)} images) to {out_folder}/test3_out_text2img_pag.png")

clear_memory()
# checkout the components if you want, the models used are moved to device, some might get offloaded to cpu
# print(components)


# test4: SDXL(text2img) with ip_adapter + pag
print(f" ")
print(f" running test4: SDXL(text2img) with ip_adapter")

auto_pipeline.load_ip_adapter(ip_adapter_repo, subfolder="sdxl_models", weight_name="ip-adapter_sdxl.bin")
auto_pipeline.set_ip_adapter_scale(0.6)

generator = torch.Generator(device="cuda").manual_seed(0)
images_output = auto_pipeline(
    prompt=prompt, 
    num_images_per_prompt=num_images_per_prompt,
    generator=generator, 
    ip_adapter_image=ip_adapter_image,
    output="images"
)

for i, image in enumerate(images_output.images):
    image.save(f"{out_folder}/test  4_out_text2img_ip_adapter_{i}.png")
print(f" save modular output ({len(images_output.images)} images) to {out_folder}/test4_out_text2img_ip_adapter.png")

auto_pipeline.unload_ip_adapter()
clear_memory()

# test5: SDXL(text2img) with controlnet

if not test_pag:
    auto_pipeline.update_states(guider=cfg_guider, controlnet_guider=controlnet_cfg_guider)
    guider_kwargs = {}
else:
    guider_kwargs = {"pag_scale": 3.0}


# we are going to pass a new input now `control_image` so the workflow will be automatically converted to controlnet use case
# let's checkout the info for controlnet use case
print(f" auto_pipeline info (controlnet use case)")
print(auto_pipeline.get_execution_blocks("control_image"))
print(" ")

print(f" ")
print(f" running test5: SDXL(text2img) with controlnet")

generator = torch.Generator(device="cuda").manual_seed(0)
images_output = auto_pipeline(
    prompt=prompt, 
    control_image=control_image, 
    controlnet_conditioning_scale=controlnet_conditioning_scale, 
    guider_kwargs=guider_kwargs, 
    num_images_per_prompt=num_images_per_prompt,
    generator=generator,
    output="images"
)

for i, image in enumerate(images_output.images):
    image.save(f"{out_folder}/test5_out_text2img_control_{i}.png")
print(f" save modular output ({len(images_output.images)} images) to {out_folder}/test5_out_text2img_control.png")

clear_memory()

# test6: SDXL(img2img)

print(f" ")
print(f" running test6: SDXL(img2img)")

generator = torch.Generator(device="cuda").manual_seed(0)

# let's checkout the auto_pipeline info for img2img use case
print(f" auto_pipeline info (img2img use case)")
print(auto_pipeline.get_execution_blocks("image"))
print(" ")

images_output = auto_pipeline(
    prompt=prompt, 
    image=init_image, 
    strength=strength, 
    num_images_per_prompt=num_images_per_prompt,
    guider_kwargs=guider_kwargs, 
    generator=generator, 
    output="images"
)
for i, image in enumerate(images_output.images):
    image.save(f"{out_folder}/test6_out_img2img_{i}.png")
print(f" save modular output ({len(images_output.images)} images) to {out_folder}/test6_out_img2img.png")

clear_memory()


# test7: SDXL(img2img) with controlnet
# let's checkout the auto_pipeline info for img2img controlnet use case
print(f" auto_pipeline info (img2img controlnet use case)")
print(auto_pipeline.get_execution_blocks("image", "control_image"))
print(" ")

print(f" ")
print(f" running test7: SDXL(img2img) with controlnet")

generator = torch.Generator(device="cuda").manual_seed(0)
images_output = auto_pipeline(
    prompt=prompt, 
    image=init_image, 
    strength=strength, 
    num_images_per_prompt=num_images_per_prompt,
    guider_kwargs=guider_kwargs, 
    control_image=control_image, 
    controlnet_conditioning_scale=controlnet_conditioning_scale, 
    generator=generator, 
    output="images"
)

for i, image in enumerate(images_output.images):
    image.save(f"{out_folder}/test7_out_img2img_control_{i}.png")
print(f" save modular output ({len(images_output.images)} images) to {out_folder}/test7_out_img2img_control.png")

clear_memory()

# test8: img2img with refiner

refiner_pipeline.update_states(**components.get(["text_encoder_2","tokenizer_2", "vae", "scheduler"]), unet=components.get("refiner_unet"), force_zeros_for_empty_prompt=True, requires_aesthetics_score=True)
# let's checkout the refiner_node
print(f" refiner_pipeline info")
print(refiner_pipeline)
print(f" ")

print(f" refiner_pipeline: triggered by `image_latents`")
print(refiner_pipeline.get_execution_blocks("image_latents"))
print(" ")

print(f" running test8: img2img with refiner")


generator = torch.Generator(device="cuda").manual_seed(0)
latents = auto_pipeline(
    prompt=prompt, 
    num_images_per_prompt=num_images_per_prompt,
    generator=generator, 
    output="latents"
)
images_output = refiner_pipeline(
    image_latents=latents,  
    prompt=prompt, 
    denoising_start=0.8, 
    generator=generator, 
    num_images_per_prompt=num_images_per_prompt,
    output="images"
)
for i, image in enumerate(images_output.images):
    image.save(f"{out_folder}/test8_out_img2img_refiner_{i}.png")
print(f" save modular output ({len(images_output.images)} images) to {out_folder}/test8_out_img2img_refiner.png")

clear_memory()

# test9: SDXL(inpainting)
# let's checkout the auto_pipeline info for inpainting use case
print(f" auto_pipeline info (inpainting use case)")
print(auto_pipeline.get_execution_blocks("mask_image", "image"))
print(" ")

print(f" ") 
print(f" running test9: SDXL(inpainting)")

generator = torch.Generator(device="cuda").manual_seed(0)
images_output = auto_pipeline(
    prompt=prompt, 
    image=inpaint_image, 
    mask_image=inpaint_mask, 
    height=1024, 
    width=1024, 
    generator=generator, 
    num_images_per_prompt=num_images_per_prompt,
    strength=inpaint_strength,  # make sure to use `strength` below 1.0
    output="images"
)
for i, image in enumerate(images_output.images):
    image.save(f"{out_folder}/test9_out_inpainting_{i}.png")
print(f" save modular output ({len(images_output.images)} images) to {out_folder}/test9_out_inpainting.png")

clear_memory()

# test10: SDXL(inpainting) with controlnet
# let's checkout the auto_pipeline info for inpainting + controlnet use case
print(f" auto_pipeline info (inpainting + controlnet use case)")
print(auto_pipeline.get_execution_blocks("mask_image", "control_image"))
print(" ")

print(f" ") 
print(f" running test10: SDXL(inpainting) with controlnet")

generator = torch.Generator(device="cuda").manual_seed(0)
images_output = auto_pipeline(
    prompt=prompt, 
    control_image=control_image, 
    image=init_image,
    height=1024,
    width=1024,
    mask_image=inpaint_mask, 
    guider_kwargs=guider_kwargs, 
    num_images_per_prompt=num_images_per_prompt,
    controlnet_conditioning_scale=controlnet_conditioning_scale, 
    strength=inpaint_strength,  # make sure to use `strength` below 1.0
    output="images"
)
for i, image in enumerate(images_output.images):
    image.save(f"{out_folder}/test10_out_inpainting_control_{i}.png")
print(f" save modular output ({len(images_output.images)} images) to {out_folder}/test10_out_inpainting_control.png")

clear_memory()

# test11: SDXL(inpainting) with inpaint_unet
print(f" ") 
print(f" running test11: SDXL(inpainting) with inpaint_unet")

auto_pipeline.update_states(unet=components.get("inpaint_unet"))
generator = torch.Generator(device="cuda").manual_seed(0)
images_output = auto_pipeline(
    prompt=prompt, 
    image=inpaint_image, 
    mask_image=inpaint_mask, 
    height=1024, 
    width=1024, 
    generator=generator, 
    num_images_per_prompt=num_images_per_prompt,
    output="images"
)
for i, image in enumerate(images_output.images):
    image.save(f"{out_folder}/test11_out_inpainting_inpaint_unet_{i}.png")
print(f" save modular output ({len(images_output.images)} images) to {out_folder}/test11_out_inpainting_inpaint_unet.png")

clear_memory()


# test12: SDXL(inpainting) with inpaint_unet + padding_mask_crop
print(f" ") 
print(f" running test12: SDXL(inpainting) with inpaint_unet (padding_mask_crop=33)")

generator = torch.Generator(device="cuda").manual_seed(0)

images_output = auto_pipeline(
    prompt=prompt, 
    image=inpaint_image, 
    mask_image=inpaint_mask, 
    height=1024, 
    width=1024, 
    generator=generator, 
    padding_mask_crop=33, 
    num_images_per_prompt=num_images_per_prompt,
    strength=inpaint_strength,  # make sure to use `strength` below 1.0
    output="images"
)

for i, image in enumerate(images_output.images):
    image.save(f"{out_folder}/test12_out_inpainting_inpaint_unet_padding_mask_crop_{i}.png")
print(f" save modular output ({len(images_output.images)} images) to {out_folder}/test12_out_inpainting_inpaint_unet_padding_mask_crop.png")

clear_memory()


# test13: apg

print(f" ")
print(f" running test13: apg")

apg_guider = APGGuider()
auto_pipeline.update_states(guider=apg_guider, unet=components.get("unet"))


generator = torch.Generator().manual_seed(0)
images_output = auto_pipeline(
  prompt=prompt, 
  generator=generator,
  num_inference_steps=20,
  num_images_per_prompt=1, # yiyi: apg does not work with num_images_per_prompt > 1
  guidance_scale=15,
  height=896,
  width=768,
  guider_kwargs={
      "adaptive_projected_guidance_momentum": -0.5,
      "adaptive_projected_guidance_rescale_factor": 15.0,
  },
  output="images"
)

for i, image in enumerate(images_output.images):
    image.save(f"{out_folder}/test13_out_apg_{i}.png")
print(f" save modular output ({len(images_output.images)} images) to {out_folder}/test13_out_apg.png")

clear_memory()


# test14: SDXL(text2img) with controlnet_union

auto_pipeline.update_states(controlnet=components.get("controlnet_union"), unet=components.get("unet"), vae=components.get("vae_fix"), guider=pag_guider, controlnet_guider=controlnet_pag_guider)
# we are going to pass a new input now `control_mode` so the workflow will be automatically converted to controlnet use case
# let's checkout the info for controlnet use case
print(f" auto_pipeline info (controlnet union use case)")
print(auto_pipeline.get_execution_blocks("control_mode"))
print(" ")

print(f" ")
print(f" running test14: SDXL(text2img) with controlnet_union")

generator = torch.Generator(device="cuda").manual_seed(0)

images_output = auto_pipeline(
    prompt=prompt, 
    control_mode=[3],
    control_image=[controlnet_union_image], 
    num_images_per_prompt=num_images_per_prompt,
    height=1024,
    width=1024,
    generator=generator,
    output="images"
)

for i, image in enumerate(images_output.images):
    image.save(f"{out_folder}/test14_out_text2img_control_union_{i}.png")
print(f" save modular output ({len(images_output.images)} images) to {out_folder}/test14_out_text2img_control_union.png")

clear_memory()


# test15: SDXL(img2img) with controlnet_union

print(f" ")
print(f" auto_pipeline info (img2img controlnet union use case)")
print(auto_pipeline.get_execution_blocks("image", "control_mode"))
print(" ")

print(f" ")
print(f" running test15: SDXL(img2img) with controlnet_union")

generator = torch.Generator(device="cuda").manual_seed(0)
images_output = auto_pipeline(
    prompt=prompt, 
    image=init_image, 
    generator=generator, 
    control_mode=[3], 
    control_image=[controlnet_union_image], 
    num_images_per_prompt=num_images_per_prompt, 
    height=1024, 
    width=1024, 
    output="images"
)

for i, image in enumerate(images_output.images):
    image.save(f"{out_folder}/test15_out_img2img_control_union_{i}.png")
print(f" save modular output ({len(images_output.images)} images) to {out_folder}/test15_out_img2img_control_union.png")

clear_memory()

# test16: SDXL(inpainting) with controlnet_union
print(f" ")
print(f" auto_pipeline info (inpainting controlnet union use case)")
print(auto_pipeline.get_execution_blocks("mask", "control_mode"))
print(" ")

print(f" ")
print(f" running test16: SDXL(inpainting) with controlnet_union")

generator = torch.Generator(device="cuda").manual_seed(0)
images_output = auto_pipeline(
    prompt=prompt, 
    image=init_image, 
    mask_image=inpaint_mask, 
    control_image=controlnet_union_image,
    control_mode=[3],
    height=1024, 
    width=1024, 
    generator=generator, 
    output="images"
)

for i, image in enumerate(images_output.images):
    image.save(f"{out_folder}/test16_out_inpainting_control_union_{i}.png")
print(f" save modular output ({len(images_output.images)} images) to {out_folder}/test16_out_inpainting_control_union.png")

clear_memory()

print_memory("the end")

print(f" components info after the end")
print(components)

Modular Setup

StableDiffusionXLAutoPipeline is a very convenient preset; just like a LEGO set, you can break it down, reassemble it, and rearrange the pipeline blocks however you want. A more modular setup would look like this:

# AUTO_BLOCKS is a map of all the blocks we used to assemble `StableDiffusionXLAutoPipeline`
from diffusers import ModularPipeline
from diffusers.pipelines.modular_pipeline import SequentialPipelineBlocks
from diffusers.pipelines.stable_diffusion_xl.pipeline_stable_diffusion_xl_modular import (
    AUTO_BLOCKS,
    StableDiffusionXLLoraStep,
    StableDiffusionXLIPAdapterStep,
)


# step1: create separate nodes to encode text/image/ip-adapter inputs
text_node = ModularPipeline.from_block(AUTO_BLOCKS.pop("text_encoder")()) 
image_node = ModularPipeline.from_block(AUTO_BLOCKS.pop("image_encoder")()) 
decoder_node = ModularPipeline.from_block(AUTO_BLOCKS.pop("decode")()) 

# make a node for "denoising", here we just use the leftover blocks
class SDXLAutoBlocks(SequentialPipelineBlocks):
    block_classes = list(AUTO_BLOCKS.values())
    block_names = list(AUTO_BLOCKS.keys())

sdxl_node = SDXLAutoBlocks()
# we can also use the same blocks to make a refiner node, but you need to load a different unet/config into it later
refiner_node = SDXLAutoBlocks()

# lora_node for lora related things
lora_node = ModularPipeline.from_block(StableDiffusionXLLoraStep())
# IPAdapater nodes for IPAdapter related things
ip_adapter_node = ModularPipeline.from_block(StableDiffusionXLIPAdapterStep())

# step2: load models into the nodes (sdxl_node and refiner_node are made from the same blocks but need different components)
...
sdxl_node.update_states(**components.get(["unet", "scheduler", "vae", "controlnet"]))
refiner_node.update_states(**components.get(["text_encoder_2","tokenizer_2", "vae", "scheduler"]), unet=components.get("refiner_unet"), force_zeros_for_empty_prompt=True, requires_aesthetics_score=True)
...

# step3: generate embeddings so they can be reused
text_state = text_node(prompt=...)
image_state = image_node(image=...)
ip_adapter_state = ip_adapter_node(...)

# step4: reuse embeddings in different workflows, change call parameters, or take the latents to use in a different workflow before decoding
latents_img2img = sdxl_node(**text_state.intermediates, **image_state.intermediates, output="latents")
latents_text2img_28steps = sdxl_node(**text_state.intermediates, num_inference_steps=28, ..., output="latents")
latents_text2img_ipa = sdxl_node(**text_state.intermediates, **ip_adapter_state.intermediates, ..., output="latents")
latents_refined = refiner_node(**text_state.intermediates, image_latents=latents_xx, output="latents")
...

# step5: decode when you are ready
image = decoder_node(latents=latents_refined, output="images").images
image[0]

With this setup, you can precompute embeddings and reuse them across different denoise backends, with different inference parameters such as guidance_scale or num_inference_steps, or with different schedulers. You can modify your workflow by simply adding/removing/swapping blocks without recomputing the entire pipeline over and over again.
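
For instance, here is a minimal sketch of reusing the same text embeddings while swapping the scheduler and changing inference parameters (assuming the nodes and `components` defined above, and assuming a scheduler can be swapped with `update_states` like any other component):

from diffusers import EulerAncestralDiscreteScheduler

# assumption: schedulers are swapped via update_states like other components
new_scheduler = EulerAncestralDiscreteScheduler.from_config(components.get("scheduler").config)
sdxl_node.update_states(scheduler=new_scheduler)

# reuse the precomputed text embeddings with fewer steps and a different guidance scale
latents = sdxl_node(**text_state.intermediates, num_inference_steps=20, guidance_scale=7.0, output="latents")
images = decoder_node(latents=latents, output="images").images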

check out the full example script here

test2: modular setup. This is the full testing script I used for more configurations, including inpainting/refiner/union controlnet/APG
import os
import torch
import numpy as np
import cv2
from PIL import Image

from diffusers import (
    ControlNetModel,
    ModularPipeline,
    UNet2DConditionModel,
    AutoencoderKL,
    ControlNetUnionModel,
)
from diffusers.utils import load_image
from diffusers.guider import PAGGuider, CFGGuider, APGGuider
from diffusers.pipelines.modular_pipeline import SequentialPipelineBlocks
from diffusers.pipelines.components_manager import ComponentsManager
from diffusers.pipelines.stable_diffusion_xl.pipeline_stable_diffusion_xl_modular import AUTO_BLOCKS, IMAGE2IMAGE_BLOCKS, StableDiffusionXLLoraStep

from transformers import CLIPVisionModelWithProjection, CLIPImageProcessor
from controlnet_aux import LineartAnimeDetector

import logging
logging.getLogger().setLevel(logging.INFO)
logging.getLogger("diffusers").setLevel(logging.INFO)


# define device and dtype
device = "cuda:0"
dtype = torch.float16
num_images_per_prompt = 1

test_pag = True
test_lora = False


# define output folder
out_folder = "modular_test_outputs"
if os.path.exists(out_folder):
    # Remove all files in the directory
    for file in os.listdir(out_folder):
        file_path = os.path.join(out_folder, file)
        try:
            if os.path.isfile(file_path):
                os.unlink(file_path)
        except Exception as e:
            print(f"Error: {e}")
os.makedirs(out_folder, exist_ok=True)

# functions for memory info
def reset_memory():
    torch.cuda.empty_cache()
    torch.cuda.reset_peak_memory_stats()

def clear_memory():
    torch.cuda.empty_cache()

def print_mem(mem_size, name):
    mem_gb = mem_size / 1024**3
    mem_mb = mem_size / 1024**2
    print(f"- {name}: {mem_gb:.2f} GB ({mem_mb:.2f} MB)")

def print_memory(message=None):
    """
    Print detailed GPU memory statistics for a specific device.
    
    Args:
        device_id (int): GPU device ID
    """
    allocated_mem = torch.cuda.memory_allocated(device)
    reserved_mem = torch.cuda.memory_reserved(device)
    mem_on_device = torch.cuda.mem_get_info(device)[0]
    peak_mem = torch.cuda.max_memory_allocated(device)

    print(f"\nGPU:{device} Memory Status {message}:")
    print_mem(allocated_mem, "allocated memory")
    print_mem(reserved_mem, "reserved memory")
    print_mem(peak_mem, "peak memory")
    print_mem(mem_on_device, "mem on device")

# function to make canny image (for controlnet)
def make_canny(image):
    image = np.array(image)
    image = cv2.Canny(image, 100, 200)
    image = image[:, :, None]
    image = np.concatenate([image, image, image], axis=2)
    return Image.fromarray(image)


# (1)Define inputs
# for text2img/img2img
prompt = "a photo of an astronaut riding a horse on mars"

# for img2img
url = "https://huggingface.co/datasets/patrickvonplaten/images/resolve/main/aa_xl/000000009.png"
init_image = load_image(url).convert("RGB")
strength = 0.9 

# for controlnet
control_image = make_canny(init_image)
controlnet_conditioning_scale = 0.5  # recommended for good generalization
# for controlnet_union
processor = LineartAnimeDetector.from_pretrained("lllyasviel/Annotators")
controlnet_union_image = processor(init_image, output_type="pil")

# for inpainting
img_url = "https://raw.githubusercontent.com/CompVis/latent-diffusion/main/data/inpainting_examples/overture-creations-5sI6fQgYIuo.png"
mask_url = "https://raw.githubusercontent.com/CompVis/latent-diffusion/main/data/inpainting_examples/overture-creations-5sI6fQgYIuo_mask.png"

inpaint_image = load_image(img_url).resize((1024, 1024))
inpaint_mask = load_image(mask_url).resize((1024, 1024))
inpaint_control_image = make_canny(inpaint_image)
inpaint_strength = 0.99


# (2) define blocks and nodes(builder)   

all_blocks_map = AUTO_BLOCKS.copy()
# text block
text_block = all_blocks_map.pop("text_encoder")()
# image encoder block
image_encoder_block = all_blocks_map.pop("image_encoder")()
# decoder block
decoder_block = all_blocks_map.pop("decode")()

class SDXLAutoBlocks(SequentialPipelineBlocks):
    block_classes = list(all_blocks_map.values())
    block_names = list(all_blocks_map.keys())
# sdxl main block
sdxl_auto_blocks = SDXLAutoBlocks()

# lora step
lora_step = StableDiffusionXLLoraStep()


image2image_blocks_map = IMAGE2IMAGE_BLOCKS.copy()
# we do not need image_encoder for refiner because it takes image_latents (from another pipeline) as input
image_block = image2image_blocks_map.pop("image_encoder")()
# refiner block
class RefinerSteps(SequentialPipelineBlocks):
    block_classes = list(image2image_blocks_map.values())
    block_names = list(image2image_blocks_map.keys())
refiner_block = RefinerSteps()

text_node = ModularPipeline.from_block(text_block)
image_node = ModularPipeline.from_block(image_encoder_block)
sdxl_node = ModularPipeline.from_block(sdxl_auto_blocks)
decoder_node = ModularPipeline.from_block(decoder_block)
refiner_node = ModularPipeline.from_block(refiner_block)
lora_node = ModularPipeline.from_block(lora_step)


# (3) add states to nodes
repo = "stabilityai/stable-diffusion-xl-base-1.0"
refiner_repo = "stabilityai/stable-diffusion-xl-refiner-1.0"
controlnet_repo = "diffusers/controlnet-canny-sdxl-1.0"
inpaint_repo = "diffusers/stable-diffusion-xl-1.0-inpainting-0.1"
vae_fix_repo = "madebyollin/sdxl-vae-fp16-fix"
controlnet_union_repo = "brad-twinkl/controlnet-union-sdxl-1.0-promax"


components = ComponentsManager()
components.add_from_pretrained(repo, torch_dtype=dtype)


controlnet = ControlNetModel.from_pretrained(controlnet_repo, torch_dtype=dtype)
refiner_unet = UNet2DConditionModel.from_pretrained(refiner_repo, subfolder="unet", torch_dtype=dtype)
inpaint_unet = UNet2DConditionModel.from_pretrained(inpaint_repo, subfolder="unet", torch_dtype=dtype)
vae_fix = AutoencoderKL.from_pretrained(vae_fix_repo, torch_dtype=dtype)
controlnet_union = ControlNetUnionModel.from_pretrained(controlnet_union_repo, torch_dtype=dtype)

components.add("controlnet", controlnet)
components.add("refiner_unet", refiner_unet)
components.add("inpaint_unet", inpaint_unet)
components.add("controlnet_union", controlnet_union)
components.add("vae_fix", vae_fix)


# you can add guiders to the manager too, but there is no need because they are not serialized
pag_guider = PAGGuider(pag_applied_layers="mid")
controlnet_pag_guider = PAGGuider(pag_applied_layers="mid")
cfg_guider = CFGGuider()
controlnet_cfg_guider = CFGGuider()

# load components/config into nodes
text_node.update_states(**components.get(["text_encoder", "text_encoder_2", "tokenizer", "tokenizer_2"]))
image_node.update_states(**components.get(["vae"]))
decoder_node.update_states(vae=components.get("vae"))

sdxl_node.update_states(**components.get(["unet", "scheduler", "vae", "controlnet"]))
refiner_node.update_states(**components.get(["text_encoder_2","tokenizer_2", "vae", "scheduler"]), unet=components.get("refiner_unet"), force_zeros_for_empty_prompt=True, requires_aesthetics_score=True)

lora_node.update_states(**components.get(["unet", "text_encoder", "text_encoder_2"]))

# (4) enable auto cpu offload: automatically offload models when available gpu memory goes below a certain threshold
components.enable_auto_cpu_offload(device=device)
print(components)
reset_memory()


# (5) run the workflows
print(f" ")
print(f" text_node:")
print(text_node)
print(f" ")
print(f" generating text embeddings with text_node")
# using text_node to generate text embeddings
text_state = text_node(prompt=prompt)
print(" ")
print(f" components info after run text_node: text_encoder and text_encoder_2 are on device")
print(components)
print(f" ")
print(f" text_state info")
print(text_state)
print(" ")



# using sdxl_node to generate images

# to get info about sdxl_node and how to use it: inputs/outputs/components
# this is an "auto" workflow that works for all use cases: text2img, img2img, inpainting, controlnet, etc.
# so the information might not be super useful for your specific use case; you will find a "trigger inputs" section that says this:

# Trigger Inputs: {'control_mode', 'control_image', 'image_latents', 'mask'}
#  Use `get_execution_blocks()` with input names to see selected blocks (e.g. `get_execution_blocks('control_mode')`).
# Check `.doc` of returned object for more information.

print(f" ")
print(f" sdxl_node:")
print(sdxl_node)
print(" ")

# since we want to use the text2img use case, we can run the following to see components/blocks/inputs for this use case
print(f" ")
print(f" sdxl_node info (default use case: text2img)")
print(sdxl_node.get_execution_blocks())
print(" ")


# test1: text2img use case
# when you run the auto workflow, you will get these logs telling you which blocks are actually running
# (should match what the sdxl_node told you)
# Running block: StableDiffusionXLBeforeDenoiseStep, trigger: None
# Running block: StableDiffusionXLDenoiseStep, trigger: None
# Running block: StableDiffusionXLDecodeStep, trigger: None

generator = torch.Generator(device="cuda").manual_seed(0)
latents = sdxl_node(
    **text_state.intermediates, 
    num_images_per_prompt=num_images_per_prompt,
    generator=generator, 
    output="latents"
)
images_output = decoder_node(latents=latents, output="images")
for i, image in enumerate(images_output.images):
    image.save(f"{out_folder}/test1_out_text2img_{i}.png")
print(f" save modular output ({len(images_output.images)} images) to {out_folder}/test1_out_text2img.png")

clear_memory()

# test2: text2img with lora use case
print(f" ")
print(f" running test2: text2img with lora use case")
generator = torch.Generator(device="cuda").manual_seed(0)
lora_node.load_lora_weights("rajkumaralma/dissolve_dust_style", weight_name="ral-dissolve-sdxl.safetensors", adapter_name="ral-dissolve")
latents = sdxl_node(
    **text_state.intermediates, 
    num_images_per_prompt=num_images_per_prompt,
    generator=generator, 
    output="latents"
)

images_output = decoder_node(latents=latents, output="images")
for i, image in enumerate(images_output.images):
    image.save(f"{out_folder}/test2_out_text2img_lora_{i}.png")
print(f" save modular output ({len(images_output.images)} images) to {out_folder}/test2_out_text2img_lora.png")


# test3:text2image with pag
print(f" ")
print(f" running test3:text2image with pag")
if not test_lora:
    lora_node.unload_lora_weights()

sdxl_node.update_states(guider=pag_guider, controlnet_guider=controlnet_pag_guider)
generator = torch.Generator(device="cuda").manual_seed(0)
latents = sdxl_node(
    **text_state.intermediates, 
    num_images_per_prompt=num_images_per_prompt,
    generator=generator, 
    guider_kwargs={"pag_scale": 3.0},
    output="latents"
)

images_output = decoder_node(latents=latents, output="images")
for i, image in enumerate(images_output.images):
    image.save(f"{out_folder}/test3_out_text2img_pag_{i}.png")
print(f" save modular output ({len(images_output.images)} images) to {out_folder}/test3_out_text2img_pag.png")

clear_memory()
# checkout the components if you want: the models used are moved to device, some might get offloaded to cpu
# print(components)

# test4: SDXL(text2img) with controlnet

if not test_pag:
    sdxl_node.update_states(guider=cfg_guider, controlnet_guider=controlnet_cfg_guider)
    guider_kwargs = {}
else:
    guider_kwargs = {"pag_scale": 3.0}


# we are now going to pass a new input, `control_image`, so the workflow will automatically switch to the controlnet use case
# let's checkout the info for controlnet use case
print(f" sdxl_node info (controlnet use case)")
print(sdxl_node.get_execution_blocks("control_image"))
print(" ")

print(f" ")
print(f" running test4: SDXL(text2img) with controlnet")

generator = torch.Generator(device="cuda").manual_seed(0)
latents = sdxl_node(
    **text_state.intermediates, 
    control_image=control_image, 
    controlnet_conditioning_scale=controlnet_conditioning_scale, 
    guider_kwargs=guider_kwargs, 
    num_images_per_prompt=num_images_per_prompt,
    generator=generator,
    output="latents"
)

images_output = decoder_node(latents=latents, output="images")
for i, image in enumerate(images_output.images):
    image.save(f"{out_folder}/test4_out_text2img_control_{i}.png")
print(f" save modular output ({len(images_output.images)} images) to {out_folder}/test4_out_text2img_control.png")

clear_memory()

# test5: SDXL(img2img)

# for the img2img use case, we encode the image with image_node first; this way we can reuse the same image_latents across different workflows
# let's checkout the image_node
print(f" image_node info")
print(image_node)
print(" ")


print(f" ")
print(f" running test5: SDXL(img2img)")

generator = torch.Generator(device="cuda").manual_seed(0)
image_state = image_node(image=init_image, generator=generator)

# let's checkout what's in image_state
print(f" image_state info")
print(image_state)
print(" ")

# let's checkout the sdxl_node info for img2img use case
print(f" sdxl_node info (img2img use case)")
print(sdxl_node.get_execution_blocks("image_latents"))
print(" ")

latents = sdxl_node(
    **text_state.intermediates, 
    **image_state.intermediates,
    strength=strength, 
    num_images_per_prompt=num_images_per_prompt,
    guider_kwargs=guider_kwargs, 
    generator=generator, 
    output="latents"
)
images_output = decoder_node(latents=latents, output="images")
for i, image in enumerate(images_output.images):
    image.save(f"{out_folder}/test5_out_img2img_{i}.png")
print(f" save modular output ({len(images_output.images)} images) to {out_folder}/test5_out_img2img.png")

clear_memory()

# test6: SDXL(img2img) with controlnet
# let's checkout the sdxl_node info for img2img controlnet use case
print(f" sdxl_node info (img2img controlnet use case)")
print(sdxl_node.get_execution_blocks("image_latents","control_image"))
print(" ")

print(f" ")
print(f" running test6: SDXL(img2img) with controlnet")

generator = torch.Generator(device="cuda").manual_seed(0)
latents = sdxl_node(
    **text_state.intermediates, 
    **image_state.intermediates,
    control_image=control_image, 
    guider_kwargs=guider_kwargs, 
    controlnet_conditioning_scale=controlnet_conditioning_scale, 
    strength=strength, 
    num_images_per_prompt=num_images_per_prompt,
    generator=generator, 
    output="latents"
)

images_output = decoder_node(latents=latents, output="images")
for i, image in enumerate(images_output.images):
    image.save(f"{out_folder}/test6_out_img2img_control_{i}.png")
print(f" save modular output ({len(images_output.images)} images) to {out_folder}/test6_out_img2img_control.png")

clear_memory()

# test7: img2img with refiner

# let's checkout the refiner_node
print(f" refiner_node info")
print(refiner_node)
print(" ")

print(f" ")
print(f" running test7: img2img with refiner")

images_output = refiner_node(
    image_latents=latents, 
    prompt=prompt, 
    denoising_start=0.8, 
    generator=generator, 
    num_images_per_prompt=num_images_per_prompt,
    output="images"
)
for i, image in enumerate(images_output.images):
    image.save(f"{out_folder}/test7_out_img2img_refiner_{i}.png")
print(f" save modular output ({len(images_output.images)} images) to {out_folder}/test7_out_img2img_refiner.png")

clear_memory()

# test8: SDXL(inpainting)
# let's checkout the sdxl_node info for inpainting use case
print(f" sdxl_node info (inpainting use case)")
print(sdxl_node.get_execution_blocks("mask", "image_latents"))
print(" ")

print(f" ") 
print(f" running test8: SDXL(inpainting)")

generator = torch.Generator(device="cuda").manual_seed(0)
image_state = image_node(image=inpaint_image, mask_image=inpaint_mask, height=1024, width=1024, generator=generator)
print(f" image_state info")
print(image_state)
print(" ")
latents = sdxl_node(
    **text_state.intermediates, 
    **image_state.intermediates,
    guider_kwargs=guider_kwargs, 
    num_images_per_prompt=num_images_per_prompt,
    strength=inpaint_strength,  # make sure to use `strength` below 1.0
    output="latents"
)
images_output = decoder_node(latents=latents, output="images")
for i, image in enumerate(images_output.images):
    image.save(f"{out_folder}/test8_out_inpainting_{i}.png")
print(f" save modular output ({len(images_output.images)} images) to {out_folder}/test8_out_inpainting.png")

clear_memory()

# test9: SDXL(inpainting) with controlnet
# let's checkout the sdxl_node info for inpainting + controlnet use case
print(f" sdxl_node info (inpainting + controlnet use case)")
print(sdxl_node.get_execution_blocks("mask", "control_image"))
print(" ")

print(f" ") 
print(f" running test9: SDXL(inpainting) with controlnet")

generator = torch.Generator(device="cuda").manual_seed(0)
latents = sdxl_node(
    **text_state.intermediates, 
    **image_state.intermediates,
    control_image=control_image, 
    guider_kwargs=guider_kwargs, 
    num_images_per_prompt=num_images_per_prompt,
    controlnet_conditioning_scale=controlnet_conditioning_scale, 
    strength=inpaint_strength,  # make sure to use `strength` below 1.0
    output="latents"
)
images_output = decoder_node(latents=latents, output="images")
for i, image in enumerate(images_output.images):
    image.save(f"{out_folder}/test9_out_inpainting_control_{i}.png")
print(f" save modular output ({len(images_output.images)} images) to {out_folder}/test9_out_inpainting_control.png")

clear_memory()

# test10: SDXL(inpainting) with inpaint_unet
print(f" ") 
print(f" running test10: SDXL(inpainting) with inpaint_unet")

sdxl_node.update_states(unet=components.get("inpaint_unet"))
generator = torch.Generator(device="cuda").manual_seed(0)
latents = sdxl_node(
    **text_state.intermediates, 
    **image_state.intermediates,
    guider_kwargs=guider_kwargs, 
    strength=inpaint_strength,  # make sure to use `strength` below 1.0
    num_images_per_prompt=num_images_per_prompt,
    output="latents"
)
images_output = decoder_node(latents=latents, output="images")
for i, image in enumerate(images_output.images):
    image.save(f"{out_folder}/test10_out_inpainting_inpaint_unet_{i}.png")
print(f" save modular output ({len(images_output.images)} images) to {out_folder}/test10_out_inpainting_inpaint_unet.png")

clear_memory()


# test11: SDXL(inpainting) with inpaint_unet + padding_mask_crop
print(f" ") 
print(f" running test11: SDXL(inpainting) with inpaint_unet (padding_mask_crop=33)")

generator = torch.Generator(device="cuda").manual_seed(0)
image_state = image_node(image=inpaint_image, mask_image=inpaint_mask, height=1024, width=1024, generator=generator, padding_mask_crop=33)
print(f" image_state info")
print(image_state)
print(" ")

latents = sdxl_node(
    **text_state.intermediates, 
    **image_state.intermediates,
    num_images_per_prompt=num_images_per_prompt,
    guider_kwargs=guider_kwargs, 
    strength=inpaint_strength,  # make sure to use `strength` below 1.0
    output="latents"
)

# we need a different decoder when using padding_mask_crop
print(f" decoder_node info")
print(decoder_node)
print(" ")
print(f" decoder_node info (inpaint/padding_mask_crop)")
print(decoder_node.pipeline_block.blocks["inpaint"])
print(" ")

images_output = decoder_node(latents=latents, crops_coords=image_state.get_intermediate("crops_coords"), **image_state.inputs, output="images")
for i, image in enumerate(images_output.images):
    image.save(f"{out_folder}/test11_out_inpainting_inpaint_unet_padding_mask_crop_{i}.png")
print(f" save modular output ({len(images_output.images)} images) to {out_folder}/test11_out_inpainting_inpaint_unet_padding_mask_crop.png")

clear_memory()


# test12: apg

print(f" ")
print(f" running test12: apg")

apg_guider = APGGuider()
sdxl_node.update_states(guider=apg_guider, unet=components.get("unet"))


generator = torch.Generator().manual_seed(0)
latents = sdxl_node(
  **text_state.intermediates,
  generator=generator,
  num_inference_steps=20,
  num_images_per_prompt=1, # yiyi: apg does not work with num_images_per_prompt > 1
  guidance_scale=15,
  height=896,
  width=768,
  guider_kwargs={
      "adaptive_projected_guidance_momentum": -0.5,
      "adaptive_projected_guidance_rescale_factor": 15.0,
  },
  output="latents"
)


images_output = decoder_node(latents=latents, output="images")
for i, image in enumerate(images_output.images):
    image.save(f"{out_folder}/test12_out_apg_{i}.png")
print(f" save modular output ({len(images_output.images)} images) to {out_folder}/test12_out_apg.png")

clear_memory()


# test13: SDXL(text2img) with controlnet_union

sdxl_node.update_states(controlnet=components.get("controlnet_union"), unet=components.get("unet"), guider=pag_guider, controlnet_guider=controlnet_pag_guider)
image_node.update_states(vae=components.get("vae_fix"))
decoder_node.update_states(vae=components.get("vae_fix"))

# we are now going to pass a new input, `control_mode`, so the workflow will automatically switch to the controlnet use case
# let's checkout the info for controlnet use case
print(f" sdxl_node info (controlnet union use case)")
print(sdxl_node.get_execution_blocks("control_mode"))
print(" ")

print(f" ")
print(f" running test13: SDXL(text2img) with controlnet_union")

generator = torch.Generator(device="cuda").manual_seed(0)

latents = sdxl_node(
    **text_state.intermediates, 
    control_mode=[3],
    control_image=[controlnet_union_image], 
    num_images_per_prompt=num_images_per_prompt,
    height=1024,
    width=1024,
    generator=generator,
    output="latents"
)

images_output = decoder_node(latents=latents, output="images")
for i, image in enumerate(images_output.images):
    image.save(f"{out_folder}/test13_out_text2img_control_union_{i}.png")
print(f" save modular output ({len(images_output.images)} images) to {out_folder}/test13_out_text2img_control_union.png")

clear_memory()


# test14: SDXL(img2img) with controlnet_union
print(f" image_node info(with vae_fix for controlnet union)")
print(image_node)
print(" ")


print(f" ")
print(f" sdxl_node info (img2img controlnet union use case)")
print(sdxl_node.get_execution_blocks("image_latents", "control_mode"))
print(" ")

print(f" ")
print(f" running test14: SDXL(img2img) with controlnet_union")

generator = torch.Generator(device="cuda").manual_seed(0)
image_state = image_node(image=init_image, generator=generator)

print(f" image_state info")
print(image_state)
print(" ")

latents = sdxl_node(
    **text_state.intermediates, 
    **image_state.intermediates,
    control_mode=[3],
    control_image=[controlnet_union_image], 
    num_images_per_prompt=num_images_per_prompt,
    height=1024,
    width=1024,
    generator=generator,
    output="latents"
)

images_output = decoder_node(latents=latents, output="images")
for i, image in enumerate(images_output.images):
    image.save(f"{out_folder}/test14_out_img2img_control_union_{i}.png")
print(f" save modular output ({len(images_output.images)} images) to {out_folder}/test14_out_img2img_control_union.png")

clear_memory()

# test15: SDXL(inpainting) with controlnet_union
print(f" ")
print(f" sdxl_node info (inpainting controlnet union use case)")
print(sdxl_node.get_execution_blocks("mask", "control_mode"))
print(" ")

print(f" ")
print(f" running test15: SDXL(inpainting) with controlnet_union")

generator = torch.Generator(device="cuda").manual_seed(0)
image_state = image_node(image=inpaint_image, mask_image=inpaint_mask, height=1024, width=1024, generator=generator)

print(f" image_state info")
print(image_state)
print(" ")

latents = sdxl_node(
    **text_state.intermediates, 
    **image_state.intermediates,
    control_mode=[3],
    control_image=[controlnet_union_image], 
    num_images_per_prompt=num_images_per_prompt,
    height=1024,
    width=1024,
    generator=generator,
    output="latents"
)

images_output = decoder_node(latents=latents, output="images")
for i, image in enumerate(images_output.images):
    image.save(f"{out_folder}/test15_out_inpainting_control_union_{i}.png")
print(f" save modular output ({len(images_output.images)} images) to {out_folder}/test15_out_inpainting_control_union.png")

clear_memory()

print_memory("the end")

print(f" components info after the end")
print(components)
test3: modular setup with IPAdapter
import os
import torch
import numpy as np
import cv2
from PIL import Image

from diffusers import (
    ControlNetModel,
    ModularPipeline,
    UNet2DConditionModel,
    AutoencoderKL,
    ControlNetUnionModel,
)
from diffusers.utils import load_image
from diffusers.guider import PAGGuider, CFGGuider, APGGuider
from diffusers.pipelines.modular_pipeline import SequentialPipelineBlocks
from diffusers.pipelines.components_manager import ComponentsManager
from diffusers.pipelines.stable_diffusion_xl.pipeline_stable_diffusion_xl_modular import AUTO_BLOCKS, IMAGE2IMAGE_BLOCKS, StableDiffusionXLLoraStep, StableDiffusionXLIPAdapterStep

from transformers import CLIPVisionModelWithProjection, CLIPImageProcessor

from controlnet_aux import LineartAnimeDetector

import logging
logging.getLogger().setLevel(logging.INFO)
logging.getLogger("diffusers").setLevel(logging.INFO)


# define device and dtype
device = "cuda:0"
dtype = torch.float16
num_images_per_prompt = 1

test_pag = True
test_lora = False


# define output folder
out_folder = "modular_test_outputs_0121_ipa"
if os.path.exists(out_folder):
    # Remove all files in the directory
    for file in os.listdir(out_folder):
        file_path = os.path.join(out_folder, file)
        try:
            if os.path.isfile(file_path):
                os.unlink(file_path)
        except Exception as e:
            print(f"Error: {e}")
os.makedirs(out_folder, exist_ok=True)

# functions for memory info
def reset_memory():
    torch.cuda.empty_cache()
    torch.cuda.reset_peak_memory_stats()

def clear_memory():
    torch.cuda.empty_cache()

def print_mem(mem_size, name):
    mem_gb = mem_size / 1024**3
    mem_mb = mem_size / 1024**2
    print(f"- {name}: {mem_gb:.2f} GB ({mem_mb:.2f} MB)")

def print_memory(message=None):
    """
    Print detailed GPU memory statistics for the device defined above.

    Args:
        message (str, optional): label appended to the printed header
    """
    allocated_mem = torch.cuda.memory_allocated(device)
    reserved_mem = torch.cuda.memory_reserved(device)
    mem_on_device = torch.cuda.mem_get_info(device)[0]
    peak_mem = torch.cuda.max_memory_allocated(device)

    print(f"\nGPU:{device} Memory Status {message}:")
    print_mem(allocated_mem, "allocated memory")
    print_mem(reserved_mem, "reserved memory")
    print_mem(peak_mem, "peak memory")
    print_mem(mem_on_device, "mem on device")

# function to make canny image (for controlnet)
def make_canny(image):
    image = np.array(image)
    image = cv2.Canny(image, 100, 200)
    image = image[:, :, None]
    image = np.concatenate([image, image, image], axis=2)
    return Image.fromarray(image)


# (1) Define inputs
# for text2img/img2img
prompt = "a bear sitting in a chair drinking a milkshake"
negative_prompt = "deformed, ugly, wrong proportion, low res, bad anatomy, worst quality, low quality"

# for img2img
url = "https://huggingface.co/datasets/huggingface/documentation-images/resolve/0052a70beed5bf71b92610a43a52df6d286cd5f3/diffusers/rabbit.jpg"
init_image = load_image(url).convert("RGB")
strength = 0.9 

# for controlnet
control_image = make_canny(init_image)
controlnet_conditioning_scale = 0.5  # recommended for good generalization
# for controlnet_union
processor = LineartAnimeDetector.from_pretrained("lllyasviel/Annotators")
controlnet_union_image = processor(init_image, output_type="pil")

# for inpainting
img_url = "https://raw.githubusercontent.com/CompVis/latent-diffusion/main/data/inpainting_examples/overture-creations-5sI6fQgYIuo.png"
mask_url = "https://raw.githubusercontent.com/CompVis/latent-diffusion/main/data/inpainting_examples/overture-creations-5sI6fQgYIuo_mask.png"

inpaint_image = load_image(img_url).resize((1024, 1024))
inpaint_mask = load_image(mask_url).resize((1024, 1024))
inpaint_control_image = make_canny(inpaint_image)
inpaint_strength = 0.99

# for ip adapter
ip_adapter_image = load_image("https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/ip_adapter_diner.png")


# (2) define blocks and nodes (builders)

all_blocks_map = AUTO_BLOCKS.copy()
# text block
text_block = all_blocks_map.pop("text_encoder")()
# image encoder block
image_encoder_block = all_blocks_map.pop("image_encoder")()
# decoder block
decoder_block = all_blocks_map.pop("decode")()

class SDXLAutoBlocks(SequentialPipelineBlocks):
    block_classes = list(all_blocks_map.values())
    block_names = list(all_blocks_map.keys())
# sdxl main block
sdxl_auto_blocks = SDXLAutoBlocks()

# lora step
lora_step = StableDiffusionXLLoraStep()

# ip adapter step
ip_adapter_step = StableDiffusionXLIPAdapterStep()


image2image_blocks_map = IMAGE2IMAGE_BLOCKS.copy()
# we do not need an image_encoder for the refiner because it takes image_latents (from another pipeline) as input
image_block = image2image_blocks_map.pop("image_encoder")()
# refiner block
class RefinerSteps(SequentialPipelineBlocks):
    block_classes = list(image2image_blocks_map.values())
    block_names = list(image2image_blocks_map.keys())
refiner_block = RefinerSteps()

text_node = ModularPipeline.from_block(text_block)
image_node = ModularPipeline.from_block(image_encoder_block)
sdxl_node = ModularPipeline.from_block(sdxl_auto_blocks)
decoder_node = ModularPipeline.from_block(decoder_block)
refiner_node = ModularPipeline.from_block(refiner_block)
lora_node = ModularPipeline.from_block(lora_step)
ip_adapter_node = ModularPipeline.from_block(ip_adapter_step)


# (3) add states to nodes
repo = "stabilityai/stable-diffusion-xl-base-1.0"
refiner_repo = "stabilityai/stable-diffusion-xl-refiner-1.0"
controlnet_repo = "diffusers/controlnet-canny-sdxl-1.0"
inpaint_repo = "diffusers/stable-diffusion-xl-1.0-inpainting-0.1"
vae_fix_repo = "madebyollin/sdxl-vae-fp16-fix"
controlnet_union_repo = "brad-twinkl/controlnet-union-sdxl-1.0-promax"
ip_adapter_repo = "h94/IP-Adapter"


components = ComponentsManager()
components.add_from_pretrained(repo, torch_dtype=dtype)


controlnet = ControlNetModel.from_pretrained(controlnet_repo, torch_dtype=dtype)
refiner_unet = UNet2DConditionModel.from_pretrained(refiner_repo, subfolder="unet", torch_dtype=dtype)
inpaint_unet = UNet2DConditionModel.from_pretrained(inpaint_repo, subfolder="unet", torch_dtype=dtype)
vae_fix = AutoencoderKL.from_pretrained(vae_fix_repo, torch_dtype=dtype)
controlnet_union = ControlNetUnionModel.from_pretrained(controlnet_union_repo, torch_dtype=dtype)

components.add("controlnet", controlnet)
components.add("refiner_unet", refiner_unet)
components.add("inpaint_unet", inpaint_unet)
components.add("controlnet_union", controlnet_union)
components.add("vae_fix", vae_fix)

image_encoder = CLIPVisionModelWithProjection.from_pretrained(ip_adapter_repo, subfolder="sdxl_models/image_encoder", torch_dtype=dtype)
feature_extractor = CLIPImageProcessor(size=224, crop_size=224)

components.add("image_encoder", image_encoder)
components.add("feature_extractor", feature_extractor)

# you can add guiders to the manager too, but there is no need because they are not serialized
pag_guider = PAGGuider(pag_applied_layers="mid")
controlnet_pag_guider = PAGGuider(pag_applied_layers="mid")
cfg_guider = CFGGuider()
controlnet_cfg_guider = CFGGuider()

# load components/config into nodes
text_node.update_states(**components.get(["text_encoder", "text_encoder_2", "tokenizer", "tokenizer_2"]))
image_node.update_states(**components.get(["vae"]))
decoder_node.update_states(vae=components.get("vae"))

sdxl_node.update_states(**components.get(["unet", "scheduler", "vae", "controlnet"]))
refiner_node.update_states(**components.get(["text_encoder_2","tokenizer_2", "vae", "scheduler"]), unet=components.get("refiner_unet"), force_zeros_for_empty_prompt=True, requires_aesthetics_score=True)

lora_node.update_states(**components.get(["unet", "text_encoder", "text_encoder_2"]))
ip_adapter_node.update_states(**components.get(["unet", "image_encoder", "feature_extractor"]))

# (4) enable auto cpu offload: automatically offload models when available gpu memory goes below a certain threshold
components.enable_auto_cpu_offload(device=device)
print(components)
reset_memory()


# (5) run the workflows
print(f" ")
print(f" text_node:")
print(text_node)
print(f" ")
print(f" generating text embeddings with text_node")
# using text_node to generate text embeddings
text_state = text_node(prompt=prompt, negative_prompt=negative_prompt)
print(" ")
print(f" components info after run text_node: text_encoder and text_encoder_2 are on device")
print(components)
print(f" ")
print(f" text_state info")
print(text_state)
print(" ")


# use ip adapter to get image embeddings
print(f" ")
print(f" ip_adapter_node:")
print(ip_adapter_node)
print(f" ")
print(f" generating ip adapter image embeddings with ip_adapter_node")
ip_adapter_node.load_ip_adapter(ip_adapter_repo, subfolder="sdxl_models", weight_name="ip-adapter_sdxl.bin")
ip_adapter_node.set_ip_adapter_scale(0.6)
ip_adapter_state = ip_adapter_node(ip_adapter_image=ip_adapter_image)
print(f" ")
print(f" ip_adapter_state info")
print(ip_adapter_state)
print(" ")


# using sdxl_node to generate images
print(f" ")
print(f" sdxl_node:")
print(sdxl_node)
print(" ")

# since we want to use text2img use case, we can run the following to see components/blocks/inputs for this use case
print(f" ")
print(f" sdxl_node info (default use case: text2img)")
print(sdxl_node.get_execution_blocks())
print(" ")


# test1: text2img use case
# when you run the auto workflow, you will get these logs telling you which blocks are actually running
# (should match what the sdxl_node told you)
# Running block: StableDiffusionXLBeforeDenoiseStep, trigger: None
# Running block: StableDiffusionXLDenoiseStep, trigger: None
# Running block: StableDiffusionXLDecodeStep, trigger: None

generator = torch.Generator(device="cuda").manual_seed(0)
latents = sdxl_node(
    **text_state.intermediates, 
    num_images_per_prompt=num_images_per_prompt,
    **ip_adapter_state.intermediates,
    generator=generator, 
    output="latents"
)
images_output = decoder_node(latents=latents, output="images")
for i, image in enumerate(images_output.images):
    image.save(f"{out_folder}/test1_out_text2img_{i}.png")
print(f" save modular output ({len(images_output.images)} images) to {out_folder}/test1_out_text2img.png")

clear_memory()

# test2: text2img with lora use case
print(f" ")
print(f" running test2: text2img with lora use case")
generator = torch.Generator(device="cuda").manual_seed(0)
lora_node.load_lora_weights("rajkumaralma/dissolve_dust_style", weight_name="ral-dissolve-sdxl.safetensors", adapter_name="ral-dissolve")
latents = sdxl_node(
    **text_state.intermediates, 
    num_images_per_prompt=num_images_per_prompt,
    **ip_adapter_state.intermediates,
    generator=generator, 
    output="latents"
)

images_output = decoder_node(latents=latents, output="images")
for i, image in enumerate(images_output.images):
    image.save(f"{out_folder}/test2_out_text2img_lora_{i}.png")
print(f" save modular output ({len(images_output.images)} images) to {out_folder}/test2_out_text2img_lora.png")


# test3:text2image with pag
print(f" ")
print(f" running test3:text2image with pag")
if not test_lora:
    lora_node.unload_lora_weights()

sdxl_node.update_states(guider=pag_guider, controlnet_guider=controlnet_pag_guider)
generator = torch.Generator(device="cuda").manual_seed(0)
latents = sdxl_node(
    **text_state.intermediates, 
    num_images_per_prompt=num_images_per_prompt,
    **ip_adapter_state.intermediates,
    generator=generator, 
    guider_kwargs={"pag_scale": 3.0},
    output="latents"
)

images_output = decoder_node(latents=latents, output="images")
for i, image in enumerate(images_output.images):
    image.save(f"{out_folder}/test3_out_text2img_pag_{i}.png")
print(f" save modular output ({len(images_output.images)} images) to {out_folder}/test3_out_text2img_pag.png")

clear_memory()
# checkout the components if you want: the models used are moved to device, some might get offloaded to cpu
# print(components)

# test4: SDXL(text2img) with controlnet

if not test_pag:
    sdxl_node.update_states(guider=cfg_guider, controlnet_guider=controlnet_cfg_guider)
    guider_kwargs = {}
else:
    guider_kwargs = {"pag_scale": 3.0}


# we are now going to pass a new input, `control_image`, so the workflow will automatically switch to the controlnet use case
# let's checkout the info for controlnet use case
print(f" sdxl_node info (controlnet use case)")
print(sdxl_node.get_execution_blocks("control_image"))
print(" ")

print(f" ")
print(f" running test4: SDXL(text2img) with controlnet")

generator = torch.Generator(device="cuda").manual_seed(0)
latents = sdxl_node(
    **text_state.intermediates, 
    control_image=control_image, 
    controlnet_conditioning_scale=controlnet_conditioning_scale, 
    guider_kwargs=guider_kwargs, 
    num_images_per_prompt=num_images_per_prompt,
    **ip_adapter_state.intermediates,
    generator=generator,
    output="latents"
)

images_output = decoder_node(latents=latents, output="images")
for i, image in enumerate(images_output.images):
    image.save(f"{out_folder}/test4_out_text2img_control_{i}.png")
print(f" save modular output ({len(images_output.images)} images) to {out_folder}/test4_out_text2img_control.png")

clear_memory()

# test5: SDXL(img2img)

# for the img2img use case, we encode the image with image_node first; this way we can reuse the same image_latents across different workflows
# let's checkout the image_node
print(f" image_node info")
print(image_node)
print(" ")


print(f" ")
print(f" running test5: SDXL(img2img)")

generator = torch.Generator(device="cuda").manual_seed(0)
image_state = image_node(image=init_image, generator=generator)

# let's checkout what's in image_state
print(f" image_state info")
print(image_state)
print(" ")

# let's checkout the sdxl_node info for img2img use case
print(f" sdxl_node info (img2img use case)")
print(sdxl_node.get_execution_blocks("image_latents"))
print(" ")

latents = sdxl_node(
    **text_state.intermediates, 
    **image_state.intermediates,
    **ip_adapter_state.intermediates,
    strength=strength, 
    num_images_per_prompt=num_images_per_prompt,
    guider_kwargs=guider_kwargs, 
    generator=generator, 
    output="latents"
)
images_output = decoder_node(latents=latents, output="images")
for i, image in enumerate(images_output.images):
    image.save(f"{out_folder}/test5_out_img2img_{i}.png")
print(f" save modular output ({len(images_output.images)} images) to {out_folder}/test5_out_img2img.png")

clear_memory()

# test6: SDXL(img2img) with controlnet
# let's checkout the sdxl_node info for img2img controlnet use case
print(f" sdxl_node info (img2img controlnet use case)")
print(sdxl_node.get_execution_blocks("image_latents","control_image"))
print(" ")

print(f" ")
print(f" running test6: SDXL(img2img) with controlnet")

generator = torch.Generator(device="cuda").manual_seed(0)
latents = sdxl_node(
    **text_state.intermediates, 
    **image_state.intermediates,
    **ip_adapter_state.intermediates,
    control_image=control_image, 
    guider_kwargs=guider_kwargs, 
    controlnet_conditioning_scale=controlnet_conditioning_scale, 
    strength=strength, 
    num_images_per_prompt=num_images_per_prompt,
    generator=generator, 
    output="latents"
)

images_output = decoder_node(latents=latents, output="images")
for i, image in enumerate(images_output.images):
    image.save(f"{out_folder}/test6_out_img2img_control_{i}.png")
print(f" save modular output ({len(images_output.images)} images) to {out_folder}/test6_out_img2img_control.png")

clear_memory()

# test7: img2img with refiner

# let's checkout the refiner_node
print(f" refiner_node info")
print(refiner_node)
print(" ")

print(f" ")
print(f" running test7: img2img with refiner")

images_output = refiner_node(
    image_latents=latents, 
    prompt=prompt, 
    denoising_start=0.8, 
    generator=generator, 
    num_images_per_prompt=num_images_per_prompt,
    output="images"
)
for i, image in enumerate(images_output.images):
    image.save(f"{out_folder}/test7_out_img2img_refiner_{i}.png")
print(f" save modular output ({len(images_output.images)} images) to {out_folder}/test7_out_img2img_refiner.png")

clear_memory()

# test8: SDXL(inpainting)
# let's checkout the sdxl_node info for inpainting use case
print(f" sdxl_node info (inpainting use case)")
print(sdxl_node.get_execution_blocks("mask", "image_latents"))
print(" ")

print(f" ") 
print(f" running test8: SDXL(inpainting)")

generator = torch.Generator(device="cuda").manual_seed(0)
image_state = image_node(image=inpaint_image, mask_image=inpaint_mask, height=1024, width=1024, generator=generator)
print(f" image_state info")
print(image_state)
print(" ")
latents = sdxl_node(
    **text_state.intermediates, 
    **image_state.intermediates,
    **ip_adapter_state.intermediates,
    guider_kwargs=guider_kwargs, 
    num_images_per_prompt=num_images_per_prompt,
    strength=inpaint_strength,  # make sure to use `strength` below 1.0
    output="latents"
)
images_output = decoder_node(latents=latents, output="images")
for i, image in enumerate(images_output.images):
    image.save(f"{out_folder}/test8_out_inpainting_{i}.png")
print(f" save modular output ({len(images_output.images)} images) to {out_folder}/test8_out_inpainting.png")

clear_memory()

# test9: SDXL(inpainting) with controlnet
# let's checkout the sdxl_node info for inpainting + controlnet use case
print(f" sdxl_node info (inpainting + controlnet use case)")
print(sdxl_node.get_execution_blocks("mask", "control_image"))
print(" ")

print(f" ") 
print(f" running test9: SDXL(inpainting) with controlnet")

generator = torch.Generator(device="cuda").manual_seed(0)
latents = sdxl_node(
    **text_state.intermediates, 
    **image_state.intermediates,
    **ip_adapter_state.intermediates,
    control_image=control_image, 
    guider_kwargs=guider_kwargs, 
    num_images_per_prompt=num_images_per_prompt,
    controlnet_conditioning_scale=controlnet_conditioning_scale, 
    strength=inpaint_strength,  # make sure to use `strength` below 1.0
    output="latents"
)
images_output = decoder_node(latents=latents, output="images")
for i, image in enumerate(images_output.images):
    image.save(f"{out_folder}/test9_out_inpainting_control_{i}.png")
print(f" save modular output ({len(images_output.images)} images) to {out_folder}/test9_out_inpainting_control.png")

clear_memory()

# test10: SDXL(inpainting) with inpaint_unet
print(f" ") 
print(f" running test10: SDXL(inpainting) with inpaint_unet")

sdxl_node.update_states(unet=components.get("inpaint_unet"))
generator = torch.Generator(device="cuda").manual_seed(0)
latents = sdxl_node(
    **text_state.intermediates, 
    **image_state.intermediates,
    **ip_adapter_state.intermediates,
    guider_kwargs=guider_kwargs, 
    strength=inpaint_strength,  # make sure to use `strength` below 1.0
    num_images_per_prompt=num_images_per_prompt,
    output="latents"
)
images_output = decoder_node(latents=latents, output="images")
for i, image in enumerate(images_output.images):
    image.save(f"{out_folder}/test10_out_inpainting_inpaint_unet_{i}.png")
print(f" save modular output ({len(images_output.images)} images) to {out_folder}/test10_out_inpainting_inpaint_unet.png")

clear_memory()


# test11: SDXL(inpainting) with inpaint_unet + padding_mask_crop
print(f" ") 
print(f" running test11: SDXL(inpainting) with inpaint_unet (padding_mask_crop=33)")

generator = torch.Generator(device="cuda").manual_seed(0)
image_state = image_node(image=inpaint_image, mask_image=inpaint_mask, height=1024, width=1024, generator=generator, padding_mask_crop=33)
print(f" image_state info")
print(image_state)
print(" ")

latents = sdxl_node(
    **text_state.intermediates, 
    **image_state.intermediates,
    **ip_adapter_state.intermediates,
    num_images_per_prompt=num_images_per_prompt,
    guider_kwargs=guider_kwargs, 
    strength=inpaint_strength,  # make sure to use `strength` below 1.0
    output="latents"
)

# we need a different decoder when using padding_mask_crop
print(f" decoder_node info")
print(decoder_node)
print(" ")
print(f" decoder_node info (inpaint/padding_mask_crop)")
print(decoder_node.pipeline_block.blocks["inpaint"])
print(" ")

images_output = decoder_node(latents=latents, crops_coords=image_state.get_intermediate("crops_coords"), **image_state.inputs, output="images")
for i, image in enumerate(images_output.images):
    image.save(f"{out_folder}/test11_out_inpainting_inpaint_unet_padding_mask_crop_{i}.png")
print(f" save modular output ({len(images_output.images)} images) to {out_folder}/test11_out_inpainting_inpaint_unet_padding_mask_crop.png")

clear_memory()


# test12: apg

print(f" ")
print(f" running test12: apg")

apg_guider = APGGuider()
sdxl_node.update_states(guider=apg_guider, unet=components.get("unet"))


generator = torch.Generator().manual_seed(0)
latents = sdxl_node(
  **text_state.intermediates,
  generator=generator,
  num_inference_steps=20,
  num_images_per_prompt=1, # yiyi: apg does not work with num_images_per_prompt > 1
  guidance_scale=15,
  **ip_adapter_state.intermediates,
  height=896,
  width=768,
  guider_kwargs={
      "adaptive_projected_guidance_momentum": -0.5,
      "adaptive_projected_guidance_rescale_factor": 15.0,
  },
  output="latents"
)


images_output = decoder_node(latents=latents, output="images")
for i, image in enumerate(images_output.images):
    image.save(f"{out_folder}/test12_out_apg_{i}.png")
print(f" save modular output ({len(images_output.images)} images) to {out_folder}/test12_out_apg.png")

clear_memory()


# test13: SDXL(text2img) with controlnet_union

sdxl_node.update_states(controlnet=components.get("controlnet_union"), unet=components.get("unet"), guider=pag_guider, controlnet_guider=controlnet_pag_guider)
image_node.update_states(vae=components.get("vae_fix"))
decoder_node.update_states(vae=components.get("vae_fix"))

# we are now going to pass a new input, `control_mode`, so the workflow will automatically switch to the controlnet use case
# let's checkout the info for controlnet use case
print(f" sdxl_node info (controlnet union use case)")
print(sdxl_node.get_execution_blocks("control_mode"))
print(" ")

print(f" ")
print(f" running test13: SDXL(text2img) with controlnet_union")

generator = torch.Generator(device="cuda").manual_seed(0)

latents = sdxl_node(
    **text_state.intermediates, 
    **ip_adapter_state.intermediates,
    control_mode=[3],
    control_image=[controlnet_union_image], 
    num_images_per_prompt=num_images_per_prompt,
    height=1024,
    width=1024,
    generator=generator,
    output="latents"
)

images_output = decoder_node(latents=latents, output="images")
for i, image in enumerate(images_output.images):
    image.save(f"{out_folder}/test13_out_text2img_control_union_{i}.png")
print(f" save modular output ({len(images_output.images)} images) to {out_folder}/test13_out_text2img_control_union.png")

clear_memory()


# test14: SDXL(img2img) with controlnet_union
print(f" image_node info(with vae_fix for controlnet union)")
print(image_node)
print(" ")


print(f" ")
print(f" sdxl_node info (img2img controlnet union use case)")
print(sdxl_node.get_execution_blocks("image_latents", "control_mode"))
print(" ")

print(f" ")
print(f" running test14: SDXL(img2img) with controlnet_union")

generator = torch.Generator(device="cuda").manual_seed(0)
image_state = image_node(image=init_image, generator=generator)

print(f" image_state info")
print(image_state)
print(" ")

latents = sdxl_node(
    **text_state.intermediates, 
    **image_state.intermediates,
    **ip_adapter_state.intermediates,
    control_mode=[3],
    control_image=[controlnet_union_image], 
    num_images_per_prompt=num_images_per_prompt,
    height=1024,
    width=1024,
    generator=generator,
    output="latents"
)

images_output = decoder_node(latents=latents, output="images")
for i, image in enumerate(images_output.images):
    image.save(f"{out_folder}/test14_out_img2img_control_union_{i}.png")
print(f" save modular output ({len(images_output.images)} images) to {out_folder}/test14_out_img2img_control_union.png")

clear_memory()

# test15: SDXL(inpainting) with controlnet_union
print(f" ")
print(f" sdxl_node info (inpainting controlnet union use case)")
print(sdxl_node.get_execution_blocks("mask", "control_mode"))
print(" ")

print(f" ")
print(f" running test15: SDXL(inpainting) with controlnet_union")

generator = torch.Generator(device="cuda").manual_seed(0)
image_state = image_node(image=inpaint_image, mask_image=inpaint_mask, height=1024, width=1024, generator=generator)

print(f" image_state info")
print(image_state)
print(" ")

latents = sdxl_node(
    **text_state.intermediates, 
    **image_state.intermediates,
    **ip_adapter_state.intermediates,
    control_mode=[3],
    control_image=[controlnet_union_image], 
    num_images_per_prompt=num_images_per_prompt,
    height=1024,
    width=1024,
    generator=generator,
    output="latents"
)

images_output = decoder_node(latents=latents, output="images")
for i, image in enumerate(images_output.images):
    image.save(f"{out_folder}/test15_out_inpainting_control_union_{i}.png")
print(f" save modular output ({len(images_output.images)} images) to {out_folder}/test15_out_inpainting_control_union.png")

clear_memory()

print_memory("the end")

print(f" components info after the end")
print(components)

Developer Guide: Building with Modular Diffusers

Core Components Overview

The Modular Diffusers architecture consists of four main components:

ModularPipeline

The main interface for creating and running modular pipelines. Unlike traditional pipelines, you don't write it from scratch - it builds itself from pipeline blocks! Example usage:

from diffusers import ModularPipeline
pipe = ModularPipeline.from_block(auto_pipeline_block)
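# load model components into the pipeline before running it, e.g. pipe.update_states(**components.components)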
images = pipe(prompt="a cat", num_inference_steps=15, output="images")

PipelineBlock

The fundamental building block, similar to a mellon/comfy node; a minimal example sketch follows the list below. Each block:

  • Defines required components, inputs, and outputs
  • Implements __call__(pipeline, state) -> (pipeline, state)
  • Can be reused across different pipelines
  • Can be combined with other blocks
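
For illustration, here is a minimal sketch of a custom block written against the interfaces used in this PR (the block name, the latents_scale input, and the scaling logic are made up for this example, and the exact API may still change):

from typing import List
import torch
from diffusers.pipelines.modular_pipeline import PipelineBlock, PipelineState, InputParam, OutputParam

class ScaleLatentsStep(PipelineBlock):
    # hypothetical example block: rescales the latents by a user-provided factor
    expected_components = []
    model_name = "stable-diffusion-xl"

    @property
    def description(self) -> str:
        return "Step that rescales the latents by a constant factor (illustrative example)"

    @property
    def inputs(self) -> List[InputParam]:
        return [InputParam("latents_scale", default=1.0, type_hint=float, description="Factor to multiply the latents by")]

    @property
    def intermediates_inputs(self) -> List[InputParam]:
        return [InputParam("latents", required=True, type_hint=torch.Tensor, description="The latents to rescale")]

    @property
    def intermediates_outputs(self) -> List[OutputParam]:
        return [OutputParam("latents", type_hint=torch.Tensor, description="The rescaled latents")]

    @torch.no_grad()
    def __call__(self, pipeline, state: PipelineState) -> PipelineState:
        # read this block's view of the pipeline state, modify it, and write it back
        data = self.get_block_state(state)
        data.latents = data.latents * data.latents_scale
        self.add_block_state(state, data)
        return pipeline, state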

MultiPipelineBlocks

Combines multiple blocks into a bigger one! These combined blocks behave just like single blocks - with their own inputs, outputs, and components, but they are able to handle more complex workflows!

We have two types of MultiPipelineBlocks available; you can use them to combine individual blocks into ready-to-use sets (like LEGO® presets!)

  1. SequentialPipelineBlocks

    • Chains blocks in sequential order
    class StableDiffusionXLMainSteps(SequentialPipelineBlocks):
        block_classes = [InputStep, SetTimestepsStep, ...]
        block_names = ["input", "set_timesteps", ...]
  2. AutoPipelineBlocks

    • Provides conditional block selection. AutoPipelineBlocks makes the complex if/else logic in your code disappear! With it, you can write blocks for specific use cases to keep each code path clean, then combine them into convenient presets that provide a better user experience :)
    • In this example, the ControlNetDenoiseStep will be dispatched when "control_image" is passed by the user; otherwise, the default DenoiseStep will run
    class StableDiffusionXLAutoDenoiseStep(AutoPipelineBlocks):
        block_classes = [ ControlNetDenoiseStep, DenoiseStep]
        block_names = [ "controlnet", "unet"]
        block_trigger_inputs = ["control_image", None]

PipelineState and BlockStates

PipelineState and BlockStates manage the dataflow between and inside blocks, and they make debugging really easy! Feel free to print them out at any time to get an overview of all the shapes/types/values in your pipeline/block states.
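
As a quick sketch of what that looks like in practice (reusing the text_node built in the scripts above; the exact intermediates depend on which blocks you assembled):

# run a node and inspect the returned PipelineState
text_state = text_node(prompt="a bear sitting in a chair drinking a milkshake")
print(text_state)                # overview of all inputs/intermediates with shapes, dtypes and values
print(text_state.inputs)         # the raw inputs this node received
print(text_state.intermediates)  # intermediate outputs, e.g. prompt_embeds, pooled_prompt_embeds
prompt_embeds = text_state.get_intermediate("prompt_embeds")  # fetch a single intermediate by name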

Differential Diffusion Example

Here we'll show you a new way to build with Modular Diffusers. Let's look at implementing a Differential Diffusion pipeline (https://differential-diffusion.github.io/) as an example. It is, in a sense, an image-to-image workflow, so we can start with the preset of pipeline blocks we used to build our current img2img pipeline (IMAGE2IMAGE_BLOCKS) and see how we can build this new pipeline with them!

IMAGE2IMAGE_BLOCKS = OrderedDict([
    ("text_encoder", StableDiffusionXLTextEncoderStep),
    ("ip_adapter", StableDiffusionXLAutoIPAdapterStep),
    ("image_encoder", StableDiffusionXLVaeEncoderStep),
    ("input", StableDiffusionXLInputStep),
    ("set_timesteps", StableDiffusionXLImg2ImgSetTimestepsStep),
    ("prepare_latents", StableDiffusionXLImg2ImgPrepareLatentsStep),
    ("prepare_add_cond", StableDiffusionXLImg2ImgPrepareAdditionalConditioningStep),
    ("denoise", StableDiffusionXLDenoiseStep),
    ("decode", StableDiffusionXLDecodeStep)
])

We can reuse the "text_encoder", "ip_adapter", "image_encoder", "input", "prepare_add_cond" and "decode" steps from the img2img workflow out of the box. The "set_timesteps" step in Differential Diffusion is the same as the one we use for text-to-image (i.e. it does not take a strength parameter), so we just use StableDiffusionXLSetTimestepsStep. Differential Diffusion uses a different denoising method, so we need to write a new "denoise" step, and its "prepare_latents" step is also a little different, so we will write a new one too.

Here are the changes needed to create the Differential Diffusion version of these blocks:

  1. Modified StableDiffusionXLImg2ImgPrepareLatentsStep:
  class SDXLDiffDiffPrepareLatentsStep(PipelineBlock):
      expected_components = ["vae", "scheduler"]
      model_name = "stable-diffusion-xl"
  
      @property
      def description(self) -> str:
          return (
-             "Step that prepares the latents for the image-to-image generation process"
+             "Step that prepares the latents for the differential diffusion generation process"
          )
  
      @property
      def intermediates_inputs(self) -> List[InputParam]:
          return [
-             InputParam("latent_timestep", required=True, type_hint=torch.Tensor, description="The timestep that represents the initial noise level for image-to-image/inpainting generation."),
+             InputParam("timesteps", required=True, type_hint=torch.Tensor, description="The timesteps to use for sampling. Can be generated in set_timesteps step."),
              InputParam("image_latents", required=True, type_hint=torch.Tensor, description="The latents representing the reference image for image-to-image/inpainting generation."),
              InputParam("batch_size", required=True, type_hint=int, description="Number of prompts, the final batch size of model inputs should be batch_size * num_images_per_prompt."),
              InputParam("dtype", required=True, type_hint=torch.dtype, description="The dtype of the model inputs")]
  
      def __call__(self, pipeline, state: PipelineState) -> PipelineState:
          data = self.get_block_state(state)
          data.dtype = data.dtype if data.dtype is not None else pipeline.vae.dtype
          data.device = pipeline._execution_device
          data.add_noise = True if data.denoising_start is None else False
+         pipeline.scheduler.set_begin_index(None)
          if data.latents is None:
              data.latents = pipeline.prepare_latents_img2img(
                  data.image_latents,
-                 data.latent_timestep,
+                 data.timesteps,
                  data.batch_size,
                  data.num_images_per_prompt,
                  data.dtype,
                  data.device,
                  data.generator,
                  data.add_noise,
              )
  2. Modified StableDiffusionXLDenoiseStep: we removed the inpaint-related logic and added diff-diff specific logic
  class SDXLDiffDiffDenoiseStep(PipelineBlock):
      expected_components = ["unet", "scheduler", "guider"]
      model_name = "stable-diffusion-xl"
  
      @property
      def description(self) -> str:
          return (
-             "Step that iteratively denoise the latents for the text-to-image/image-to-image/inpainting generation process"
+             "Step that iteratively denoise the latents for the image generation process using differential diffusion"
          )

      @property
      def inputs(self) -> List[Tuple[str, Any]]:
          return [
              # ... common parameters ...
+             InputParam("diffdiff_map", required=True),
+             InputParam("denoising_start"),
          ]

      def __init__(self):
          super().__init__()
          self.components["guider"] = CFGGuider()
          self.components["scheduler"] = None
          self.components["unet"] = None
+         self.auxiliaries["mask_processor"] = VaeImageProcessor(do_normalize=False, do_convert_grayscale=True)

      @torch.no_grad()
      def __call__(self, pipeline, state: PipelineState) -> PipelineState:
          # ... setup code ...

+         # preparations for diff diff
+         data.latent_height = data.image_latents.shape[-2]
+         data.latent_width = data.image_latents.shape[-1]
+         data.diffdiff_map = pipeline.mask_processor.preprocess(data.diffdiff_map, height=data.latent_height, width=data.latent_width)
+         
+         data.diffdiff_map = data.diffdiff_map.squeeze(0).to(data.device)
+         thresholds = torch.arange(data.num_inference_steps, dtype=data.diffdiff_map.dtype) / data.num_inference_steps
+         data.thresholds = thresholds.unsqueeze(1).unsqueeze(1).to(data.device)
+         data.masks = data.diffdiff_map > (data.thresholds + (data.denoising_start or 0))
+
+         data.original_with_noise = data.latents

          with pipeline.progress_bar(total=data.num_inference_steps) as progress_bar:
              for i, t in enumerate(data.timesteps):
+                 # diff diff
+                 if i == 0 and data.denoising_start is None:
+                     data.latents = data.original_with_noise[:1]
+                 else:
+                     data.mask = data.masks[i].unsqueeze(0)
+                     data.mask = data.mask.to(data.latents.dtype)
+                     data.mask = data.mask.unsqueeze(1)  # fit shape
+                     data.latents = data.original_with_noise[i] * data.mask + data.latents * (1 - data.mask)

                  # ... rest of denoising loop ...
-                 if data.num_channels_unet == 4 and data.mask is not None and data.image_latents is not None:
-                     data.init_latents_proper = data.image_latents
-                     if i < len(data.timesteps) - 1:
-                         data.noise_timestep = data.timesteps[i + 1]
-                         data.init_latents_proper = pipeline.scheduler.add_noise(
-                             data.init_latents_proper, data.noise, torch.tensor([data.noise_timestep])
-                         )
-                     data.latents = (1 - data.mask) * data.init_latents_proper + data.mask * data.latents

That's all there is to it! Once you've made these two diff-diff blocks, you can create a preset (a pre-assembled set of blocks) and then build your pipeline from it.

# create diff-diff preset
DIFFDIFF_BLOCKS = IMAGE2IMAGE_BLOCKS.copy()
DIFFDIFF_BLOCKS["denoise"] = SDXLDiffDiffDenoiseStep
DIFFDIFF_BLOCKS["prepare_latents"] = SDXLDiffDiffPrepareLatentsStep
DIFFDIFF_BLOCKS["set_timesteps"] = TEXT2IMAGE_BLOCKS["set_timesteps"]

class DiffDiffBlocks(SequentialPipelineBlocks):
    block_classes = list(DIFFDIFF_BLOCKS.values())
    block_names = list(DIFFDIFF_BLOCKS.keys())

# create diff-diff pipeline from preset
diffdiff_blocks = DiffDiffBlocks()
dd_node = ModularPipeline.from_block(diffdiff_blocks)

To use it:

dd_node.update_states(**components.components)

prompt = "a green pear"
negative_prompt = "blurry"

image = dd_node(
    prompt=prompt,
    negative_prompt=negative_prompt,
    diffdiff_map=mask,
    image=image,
    output="images"
).images[0]

[output image: diff-diff-out]

Complete Example: Implementing Differential Diffusion Pipeline
from diffusers.pipelines.modular_pipeline import PipelineBlock, SequentialPipelineBlocks, PipelineState, InputParam, OutputParam
from diffusers.guider import CFGGuider
from diffusers.image_processor import VaeImageProcessor
from diffusers.schedulers import DPMSolverMultistepScheduler

import torch
from typing import List, Tuple, Any, Optional, Dict, Union

class SDXLDiffDiffPrepareLatentsStep(PipelineBlock):
    expected_components = ["vae", "scheduler"]
    model_name = "stable-diffusion-xl"

    @property
    def description(self) -> str:
        return (
            "Step that prepares the latents for the differential diffusion generation process"
        )

    @property
    def inputs(self) -> List[Tuple[str, Any]]:
        return [
            InputParam(
                "generator", 
                type_hint=Optional[Union[torch.Generator, List[torch.Generator]]], 
                description="One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html) "
                           "to make generation deterministic."
            ),
            InputParam(
                "latents", 
                type_hint=Optional[torch.Tensor], 
                description="Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image generation. Can be used to tweak the same generation with different prompts. If not provided, a latents tensor will ge generated by sampling using the supplied random `generator`."
            ),
            InputParam(
                "num_images_per_prompt", 
                default=1, 
                type_hint=int, 
                description="The number of images to generate per prompt"
            ),
            InputParam(
                "denoising_start", 
                type_hint=Optional[float], 
                description="When specified, indicates the fraction (between 0.0 and 1.0) of the total denoising process to be bypassed before it is initiated. The initial part of the denoising process is skipped and it is assumed that the passed `image` is a partly denoised image. Note that when this is specified, strength will be ignored. Useful for 'Mixture of Denoisers' multi-pipeline setups."
            ),
        ]

    @property
    def intermediates_inputs(self) -> List[InputParam]:
        return [
            InputParam("timesteps", required=True, type_hint=torch.Tensor, description="The timesteps to use for sampling. Can be generated in set_timesteps step."), 
            InputParam("image_latents", required=True, type_hint=torch.Tensor, description="The latents representing the reference image for image-to-image/inpainting generation. Can be generated in vae_encode step."), 
            InputParam("batch_size", required=True, type_hint=int, description="Number of prompts, the final batch size of model inputs should be batch_size * num_images_per_prompt. Can be generated in input step."), 
            InputParam("dtype", required=True, type_hint=torch.dtype, description="The dtype of the model inputs")]

    @property
    def intermediates_outputs(self) -> List[OutputParam]:
        return [OutputParam("latents", type_hint=torch.Tensor, description="The initial latents to use for the denoising process")]

    def __init__(self):
        super().__init__()
        self.components["scheduler"] = None

    @torch.no_grad()
    def __call__(self, pipeline, state: PipelineState) -> PipelineState:
        data = self.get_block_state(state)

        data.dtype = data.dtype if data.dtype is not None else pipeline.vae.dtype
        data.device = pipeline._execution_device
        data.add_noise = True if data.denoising_start is None else False
        pipeline.scheduler.set_begin_index(None)
        if data.latents is None:
            data.latents = pipeline.prepare_latents_img2img(
                data.image_latents,
                data.timesteps,
                data.batch_size,
                data.num_images_per_prompt,
                data.dtype,
                data.device,
                data.generator,
                data.add_noise,
            )

        self.add_block_state(state, data)

        return pipeline, state


class SDXLDiffDiffDenoiseStep(PipelineBlock):
    expected_components = ["unet", "scheduler", "guider"]
    model_name = "stable-diffusion-xl"

    @property
    def description(self) -> str:
        return (
            "Step that iteratively denoise the latents for the image generation process using differential diffusion"
        )

    @property
    def inputs(self) -> List[Tuple[str, Any]]:
        return [
            InputParam(
                "guidance_scale", 
                type_hint=float,
                default=5.0,
                description="Guidance scale as defined in Classifier-Free Diffusion Guidance. Higher values encourage images closely linked to the text prompt, potentially at the expense of image quality. Enabled when > 1."
            ),
            InputParam(
                "guidance_rescale",
                type_hint=float,
                default=0.0,
                description="Guidance rescale factor (φ) to fix overexposure when using zero terminal SNR, as proposed in 'Common Diffusion Noise Schedules and Sample Steps are Flawed'."
            ),
            InputParam(
                "cross_attention_kwargs",
                type_hint=Optional[Dict[str, Any]],
                default=None,
                description="Optional kwargs dictionary passed to the AttentionProcessor."
            ),
            InputParam(
                "generator",
                type_hint=Optional[Union[torch.Generator, List[torch.Generator]]],
                description="One or a list of torch generator(s) to make generation deterministic."
            ),
            InputParam(
                "eta",
                type_hint=float,
                default=0.0,
                description="Parameter η in the DDIM paper. Only applies to DDIMScheduler, ignored for others."
            ),
            InputParam(
                "guider_kwargs",
                type_hint=Optional[Dict[str, Any]],
                default=None,
                description="Optional kwargs dictionary passed to the Guider."
            ),
            InputParam(
                "num_images_per_prompt",
                type_hint=int,
                default=1,
                description="The number of images to generate per prompt."
            ),
            InputParam("diffdiff_map",required=True),
            InputParam("denoising_start"),
        ]

    @property
    def intermediates_inputs(self) -> List[InputParam]:
        return [
            InputParam(
                "latents", 
                required=True, 
                type_hint=torch.Tensor, 
                description="The initial latents to use for the denoising process. Can be generated in prepare_latent step."
            ),
            InputParam(
                "batch_size", 
                required=True, 
                type_hint=int, 
                description="Number of prompts, the final batch size of model inputs should be batch_size * num_images_per_prompt. Can be generated in input step."
            ),
            InputParam(
                "timesteps", 
                required=True, 
                type_hint=torch.Tensor, 
                description="The timesteps to use for the denoising process. Can be generated in set_timesteps step."
            ),
            InputParam(
                "num_inference_steps", 
                required=True, 
                type_hint=int, 
                description="The number of inference steps to use for the denoising process. Can be generated in set_timesteps step."
            ),
            InputParam(
                "pooled_prompt_embeds", 
                required=True, 
                type_hint=torch.Tensor, 
                description="The pooled prompt embeddings to use to condition the denoising process. Can be generated in text_encoder step."
            ),
            InputParam(
                "negative_pooled_prompt_embeds", 
                type_hint=Optional[torch.Tensor], 
                description="The negative pooled prompt embeddings to use to condition the denoising process. Can be generated in text_encoder step.    "
            ),
            InputParam(
                "add_time_ids", 
                required=True, 
                type_hint=torch.Tensor, 
                description="The time ids to use as additional conditioning for the denoising process. Can be generated in prepare_additional_conditioning step."
            ),
            InputParam(
                "negative_add_time_ids", 
                type_hint=Optional[torch.Tensor], 
                description="The negative time ids to use as additional conditioning for the denoising process. Can be generated in prepare_additional_conditioning step."
            ),
            InputParam(
                "prompt_embeds", 
                required=True, 
                type_hint=torch.Tensor, 
                description="The prompt embeddings to use to condition the denoising process. Can be generated in text_encoder step."
            ),
            InputParam(
                "negative_prompt_embeds", 
                type_hint=Optional[torch.Tensor], 
                description="The negative prompt embeddings to use to condition the denoising process. Can be generated in text_encoder step.   "
            ),
            InputParam(
                "timestep_cond", 
                type_hint=Optional[torch.Tensor], 
                description="The guidance scale embedding to use for Latent Consistency Models(LCMs). Can be generated in prepare_additional_conditioning step."
            ),
            InputParam(
                "image_latents", 
                type_hint=Optional[torch.Tensor], 
                description="The image latents to use for the denoising process, for inpainting/image-to-image task only. Can be generated in vae_encode or prepare_latent step."
            ),
            InputParam(
                "ip_adapter_embeds", 
                type_hint=Optional[torch.Tensor], 
                description="The ip adapter embeddings to use to condition the denoising process, need to have ip adapter model loaded. Can be generated in ip_adapter step."
            ),
            InputParam(
                "negative_ip_adapter_embeds", 
                type_hint=Optional[torch.Tensor], 
                description="The negative ip adapter embeddings to use to condition the denoising process, need to have ip adapter model loaded. Can be generated in ip_adapter step."
            ),
        ]

    @property
    def intermediates_outputs(self) -> List[OutputParam]:
        return [OutputParam("latents", type_hint=torch.Tensor, description="The denoised latents")]

    def __init__(self):
        super().__init__()
        self.components["guider"] = CFGGuider()
        self.components["scheduler"] = None
        self.components["unet"] = None
        self.auxiliaries["mask_processor"] = VaeImageProcessor(do_normalize=False, do_convert_grayscale=True)
    
    @torch.no_grad()
    def __call__(self, pipeline, state: PipelineState) -> PipelineState:

        data = self.get_block_state(state)

        data.num_channels_unet = pipeline.unet.config.in_channels
        data.disable_guidance = True if pipeline.unet.config.time_cond_proj_dim is not None else False
        data.device = pipeline._execution_device

        # adding default guider arguments: do_classifier_free_guidance, guidance_scale, guidance_rescale
        data.guider_kwargs = data.guider_kwargs or {}
        data.guider_kwargs = {
            **data.guider_kwargs,
            "disable_guidance": data.disable_guidance,
            "guidance_scale": data.guidance_scale,
            "guidance_rescale": data.guidance_rescale,
            "batch_size": data.batch_size * data.num_images_per_prompt,
        }

        pipeline.guider.set_guider(pipeline, data.guider_kwargs)
        # Prepare conditional inputs using the guider
        data.prompt_embeds = pipeline.guider.prepare_input(
            data.prompt_embeds,
            data.negative_prompt_embeds,
        )
        data.add_time_ids = pipeline.guider.prepare_input(
            data.add_time_ids,
            data.negative_add_time_ids,
        )
        data.pooled_prompt_embeds = pipeline.guider.prepare_input(
            data.pooled_prompt_embeds,
            data.negative_pooled_prompt_embeds,
        )

        data.added_cond_kwargs = {
            "text_embeds": data.pooled_prompt_embeds,
            "time_ids": data.add_time_ids,
        }

        if data.ip_adapter_embeds is not None:
            data.ip_adapter_embeds = pipeline.guider.prepare_input(data.ip_adapter_embeds, data.negative_ip_adapter_embeds)
            data.added_cond_kwargs["image_embeds"] = data.ip_adapter_embeds

        # Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
        data.extra_step_kwargs = pipeline.prepare_extra_step_kwargs(data.generator, data.eta)
        data.num_warmup_steps = max(len(data.timesteps) - data.num_inference_steps * pipeline.scheduler.order, 0)

        # preparations for diff diff
        data.latent_height = data.image_latents.shape[-2]
        data.latent_width = data.image_latents.shape[-1]
        data.diffdiff_map = pipeline.mask_processor.preprocess(data.diffdiff_map, height=data.latent_height, width=data.latent_width)
        
        data.diffdiff_map = data.diffdiff_map.squeeze(0).to(data.device)
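        # differential diffusion: build one threshold per inference step (rising from 0 toward 1);
        # masks[i] marks the regions of the change map still above the step-i threshold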
        thresholds = torch.arange(data.num_inference_steps, dtype=data.diffdiff_map.dtype) / data.num_inference_steps
        data.thresholds = thresholds.unsqueeze(1).unsqueeze(1).to(data.device)
        data.masks = data.diffdiff_map > (data.thresholds + (data.denoising_start or 0))

        data.original_with_noise = data.latents

        with pipeline.progress_bar(total=data.num_inference_steps) as progress_bar:
            for i, t in enumerate(data.timesteps):
    
                # diff diff
                if i == 0 and data.denoising_start is None:
                    data.latents = data.original_with_noise[:1]
                else:
                    data.mask = data.masks[i].unsqueeze(0)
                    # cast mask to the same type as latents etc
                    data.mask = data.mask.to(data.latents.dtype)
                    data.mask = data.mask.unsqueeze(1)  # fit shape
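                    # regions still above the current threshold are refreshed from the noised
                    # original at this step; the rest keep the running denoised latents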
                    data.latents = data.original_with_noise[i] * data.mask + data.latents * (1 - data.mask)
                # end diff diff
        
                # expand the latents if we are doing classifier free guidance
                data.latent_model_input = pipeline.guider.prepare_input(data.latents, data.latents)
                data.latent_model_input = pipeline.scheduler.scale_model_input(data.latent_model_input, t)

                # predict the noise residual
                data.noise_pred = pipeline.unet(
                    data.latent_model_input,
                    t,
                    encoder_hidden_states=data.prompt_embeds,
                    timestep_cond=data.timestep_cond,
                    cross_attention_kwargs=data.cross_attention_kwargs,
                    added_cond_kwargs=data.added_cond_kwargs,
                    return_dict=False,
                )[0]
                # perform guidance
                data.noise_pred = pipeline.guider.apply_guidance(
                    data.noise_pred,
                    timestep=t,
                    latents=data.latents,
                )
                # compute the previous noisy sample x_t -> x_t-1
                data.latents_dtype = data.latents.dtype
                data.latents = pipeline.scheduler.step(data.noise_pred, t, data.latents, **data.extra_step_kwargs, return_dict=False)[0]
                if data.latents.dtype != data.latents_dtype:
                    if torch.backends.mps.is_available():
                        # some platforms (eg. apple mps) misbehave due to a pytorch bug: https://github.com/pytorch/pytorch/pull/99272
                        data.latents = data.latents.to(data.latents_dtype)

                if i == len(data.timesteps) - 1 or ((i + 1) > data.num_warmup_steps and (i + 1) % pipeline.scheduler.order == 0):
                    progress_bar.update()

        pipeline.guider.reset_guider(pipeline)
        self.add_block_state(state, data)

        return pipeline, state



import torch

from diffusers.pipelines.stable_diffusion_xl.pipeline_stable_diffusion_xl_modular import IMAGE2IMAGE_BLOCKS, TEXT2IMAGE_BLOCKS
from diffusers.pipelines.modular_pipeline import ModularPipeline, SequentialPipelineBlocks
from diffusers.pipelines.components_manager import ComponentsManager
from diffusers.utils import load_image


DIFFDIFF_BLOCKS = IMAGE2IMAGE_BLOCKS.copy()
DIFFDIFF_BLOCKS["denoise"] = SDXLDiffDiffDenoiseStep
DIFFDIFF_BLOCKS["prepare_latents"] = SDXLDiffDiffPrepareLatentsStep
DIFFDIFF_BLOCKS["set_timesteps"] = TEXT2IMAGE_BLOCKS["set_timesteps"]


DIFFDIFF_CORE_BLOCKS = IMAGE2IMAGE_BLOCKS.copy()


class DiffDiffBlocks(SequentialPipelineBlocks):
    block_classes = list(DIFFDIFF_BLOCKS.values())
    block_names = list(DIFFDIFF_BLOCKS.keys())


diffdiff_blocks = DiffDiffBlocks()
dd_node = ModularPipeline.from_block(diffdiff_blocks)

components = ComponentsManager()
components.add_from_pretrained("SG161222/RealVisXL_V4.0", torch_dtype=torch.float16, variant="fp16")

components.enable_auto_cpu_offload()

dd_node.update_states(**components.components)

print(dd_node)


image = load_image(
        "https://huggingface.co/datasets/OzzyGT/testing-resources/resolve/main/differential/20240329211129_4024911930.png?download=true"
    )

mask = load_image(
        "https://huggingface.co/datasets/OzzyGT/testing-resources/resolve/main/differential/gradient_mask.png?download=true"
    )

prompt = "a green pear"
negative_prompt = "blurry"

image = dd_node(
    prompt=prompt,
    negative_prompt=negative_prompt,
    guidance_scale=7.5,
    num_inference_steps=25,
    diffdiff_map=mask,
    image=image,
    output="images"
).images[0]

image.save("diffdiff_out.png")

Diffusers as seen in nodes

coming up soon....

Next Steps

@HuggingFaceDocBuilderDev

The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.

@yoland68
Copy link

Very cool!

@oozzy77
Copy link

oozzy77 commented Oct 30, 2024

hi this is very interesting! I'm making a Python pipeline flow visual scripting tool that can auto-convert functions to visual nodes for fast and modular UI blocks (demo). It's itself a pip package: https://pypi.org/project/nozyio/

I wanted to integrate diffusers with my flow nodes UI project but found it's not very modular. But this PR may change that! Looking forward to seeing how this evolves.

github: https://github.com/oozzy77/nozyio happy to connect!

@yiyixuxu
Copy link
Collaborator Author

@oozzy77 thanks!
do you want to join a slack channel with me? if you want to experiment building something with this PR I'm eager to hear your feedback and iterate based on that

@oozzy77
Copy link

oozzy77 commented Oct 31, 2024 via email

@yiyixuxu
Copy link
Collaborator Author

@oozzy77 I sent an invite!

@yiyixuxu yiyixuxu added the roadmap Add to current release roadmap label Dec 4, 2024
@hlky hlky mentioned this pull request Dec 5, 2024
@yiyixuxu yiyixuxu requested a review from stevhliu February 4, 2025 01:13
@yiyixuxu yiyixuxu changed the title [WIP] The Modular Diffusers The Modular Diffusers Feb 4, 2025
Copy link
Collaborator

@hlky hlky left a comment


This is great, thanks @yiyixuxu!

My first comments are regarding pipeline functions like encode_prompt, encode_image, prepare_ip_adapter_image_embeds and related modules. We can remove everything related to num_images_per_prompt as it's handled by StableDiffusionXLInputStep, and I think we could make these functions work with a single input, then call them separately with the positive and negative prompt/image from the module.

For example, with do_classifier_free_guidance, prepare_ip_adapter_image_embeds returns a list of concatenated embeds that we chunk in StableDiffusionXLIPAdapterStep, but in encode_image we just use zeros_like for the unconditional (zeros_like passed through image_encoder when output_hidden_states). Instead of having code in encode_image and prepare_ip_adapter_image_embeds to handle this, we can pass zeros_like to prepare_ip_adapter_image_embeds from StableDiffusionXLIPAdapterStep, and we can allow experimentation with actual negative ip adapter image embeds; a custom module for that would currently be possible but unintuitive, as we'd need to pass a negative ip adapter image yet take the positive embeds output as the negative embeds.

@yiyixuxu
Copy link
Collaborator Author

yiyixuxu commented Feb 4, 2025

@hlky

We can remove everything related to num_images_per_prompt as it's handled by StableDiffusionXLInputStep, and I think we could make these functions work with a single input, then call them separately with the positive and negative prompt/image from the module

totally agree, I was thinking about that too! do you want to take a stab at that? we need to refactor these functions in the regular pipelines too

@hlky
Copy link
Collaborator

hlky commented Feb 4, 2025

@yiyixuxu Yes I'll work on that

@a-r-r-o-w
Copy link
Member

Super cool @yiyixuxu @asomoza @hlky! Not reviewing the PR yet since I'm getting a feel for how a developer would be interacting with the library, but I personally found it very intuitive to get started from the examples.

Here's my first try at making a modular diffusers workflow for naive latent upscaling with SDXL:

Code
import torch
import torch.nn.functional as F
from diffusers import ModularPipeline, StableDiffusionXLAutoPipeline
from diffusers.pipelines.components_manager import ComponentsManager

# Load models
components = ComponentsManager()
components.add_from_pretrained("stabilityai/stable-diffusion-xl-base-1.0", torch_dtype=torch.float16)
components.enable_auto_cpu_offload(device="cuda:0")

# Create pipeline
pipe = ModularPipeline.from_block(StableDiffusionXLAutoPipeline())
pipe.update_states(**components.components)
pipe.to("cuda")

# Run inference
prompt = "A majestic lion jumping from a big stone at night"
height = 1024
width = 1024
output = pipe(prompt=prompt, height=height, width=width, num_inference_steps=30)
images = output.intermediates.get("images").images
latents = output.intermediates.get("latents")
images[0].save("output.png")

# Latent upscale
# Note that only naive upscaling is done here. Alternatively, a latent upscaler
# model could be used
batch_size, num_channels, latent_height, latent_width = latents.shape
scale_factor = 1.5
upscaled_height, upscaled_width = int(height * scale_factor), int(width * scale_factor)
upscaled_latent_height, upscaled_latent_width = int(latent_height * scale_factor), int(latent_width * scale_factor)
upscaled_latents = F.interpolate(latents, size=(upscaled_latent_height, upscaled_latent_width), mode="nearest-exact")

# Run inference with upscaled latents
strength = 0.5
upscaled_output = pipe(prompt=prompt, image_latents=upscaled_latents, height=upscaled_height, width=upscaled_width, num_inference_steps=40, strength=strength)

images = upscaled_output.intermediates.get("images").images
images[0].save("output_upscaled.png")

On my first try, I passed latents=upscaled_latents instead of image_latents=upscaled_latents, which does not work as expected (does not trigger the SDXL img2img blocks). Since I have the advantage of knowing the library beforehand, I could make an educated guess about the image_latents parameter or quickly find out by looking at the code.

I wonder if things like this may cause some friction in getting started with modular diffusers workflows. In this case, do you think renaming image_latents to latents is a suitable choice to make? Not quite sure why the two are distinguished at the moment, but will take a look at the code soon to better understand.

@a-r-r-o-w
Copy link
Member

Question: Let's say we implemented a Flux/SD3 equivalent of the SDXL modular blocks. Now I want to do the same latent upscale thing in the above comment.

To make it possible to upscale latents with every supported model, I would like to create a general purpose node/block, with different possible init configurations, that takes an ndim=4 latent and upscales it based on the init configs - either naively or using a latent upscaler model. I expect this block to be invoked before the denoiser steps begin. Let's also assume that I have created the auto-pipe instances for both, similar to what's shown in the examples.

How would I go about inserting my custom blocks into the pipeline execution flow? Or, what would the plan of action on the developers' end look like if they want to inject some code before/after each atomic pipeline step that we currently have (vae encode/decode, latent prep, denoise step, ...)?

@@ -46,6 +46,7 @@
"AutoPipelineForInpainting",
"AutoPipelineForText2Image",
]
_import_structure["modular_pipeline"] = ["ModularPipeline"]
Copy link
Collaborator


Need to add components_manager and at parent level because in the example we are using

from diffusers import ComponentsManager
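Presumably that would mean mirroring the existing entry, roughly (an illustrative guess at the change, not a line from this PR):

_import_structure["components_manager"] = ["ComponentsManager"]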

@yiyixuxu
Copy link
Collaborator Author

yiyixuxu commented Feb 5, 2025

@a-r-r-o-w
we sort of have inconsistent parameter names across different pipelines right now; with modular, the same parameters will need to be combined into one, so I guess we will have to pick a name to stick to

latent is one example

there is also image and control_image: it is called image for text-to-image controlnet but control_image in image-to-image when there is already an image variable

in your case for upscaling, I think it should be image_latents (currently image in our pipelines), no? (latents in general should include the noise, at least conceptually, even though in some cases we don't need to add noise to it at all in the prepare_latent process). it is indeed very confusing. and I understand that the output latents has a different meaning from the input latents in our pipelines, that's not ideal,
maybe we can:

  • rename our current latents to init_noise or something
  • latents is the initial latents that used in denoising loop (it may or may not include noise) - it could be same as image_latents or init_noise
  • image_latent is the encoded image

open to suggestions/discussions

@yiyixuxu
Copy link
Collaborator Author

yiyixuxu commented Feb 5, 2025

@a-r-r-o-w

if it is an upscaler that takes latents as input, I think it is most convenient to use it on its own (like in a UI, it would be its own node/pipeline)

maybe make a map like this so it can be used to create different presets?

AUTO_UPSCALE_BLOCKS = OrderedDict([
    ("text_encoder", StableDiffusionXLTextEncoderStep),
    ("ip_adapter", StableDiffusionXLAutoIPAdapterStep),
    ("image_encoder", StableDiffusionXLAutoVaeEncoderStep),
    ("before_denoise", StableDiffusionXLAutoBeforeDenoiseStep),
    ("upscale", AutoUpscaleStep),
    ("denoise", StableDiffusionXLAutoDenoiseStep),
    ("decode", StableDiffusionXLAutoDecodeStep)
])

make a preset for end-to-end pipeline

class SDXLAutoUpscaleBlocks(SequentialPipelineBlocks):
    block_classes = list(AUTO_UPSCALE_BLOCKS.values())
    block_names = list(AUTO_UPSCALE_BLOCKS.keys())

auto_pipe_upscaled = ModularPipeline.from_block(SDXLAutoUpscaleBlocks())

just the upscaler node used in stand-alone

upscaler_block = AUTO_UPSCALE_BLOCKS["upscale"]()
upscaler_node = ModularPipeline.from_block(upscaler_block)
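For illustration, the upscale step itself could be sketched roughly like this. It is only a sketch under assumptions: the block operates on the latents intermediate, it uses naive interpolation like the snippet earlier in the thread, and the class/parameter names (NaiveLatentUpscaleStep, upscale_factor) are made up rather than taken from this PR:

import torch
import torch.nn.functional as F

class NaiveLatentUpscaleStep(PipelineBlock):
    model_name = "stable-diffusion-xl"

    @property
    def inputs(self) -> List[InputParam]:
        return [
            InputParam(
                "upscale_factor",
                type_hint=float,
                default=1.5,
                description="Factor by which to upscale the latents before the denoise step.",
            ),
        ]

    @property
    def intermediates_inputs(self) -> List[InputParam]:
        return [InputParam("latents", required=True, type_hint=torch.Tensor, description="The latents to upscale.")]

    @property
    def intermediates_outputs(self) -> List[OutputParam]:
        return [OutputParam("latents", type_hint=torch.Tensor, description="The upscaled latents.")]

    @torch.no_grad()
    def __call__(self, pipeline, state: PipelineState) -> PipelineState:
        data = self.get_block_state(state)
        height, width = data.latents.shape[-2:]
        # naive nearest-neighbour upscaling; a latent upscaler model could be plugged in here instead
        data.latents = F.interpolate(
            data.latents,
            size=(int(height * data.upscale_factor), int(width * data.upscale_factor)),
            mode="nearest-exact",
        )
        self.add_block_state(state, data)
        return pipeline, state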

@sayakpaul
Copy link
Member

Did a pass on the examples and the info shared instead of looking through the code too much (following @a-r-r-o-w's philosophy).

Some comments first.

auto_pipe.update_states(**components.components) -- should this be called auto_pipe.update_components()? update_states() seems a bit counterintuitive?

The pipeline automatically adapts to your inputs:

What if the user combines the inputs that are supported? How do we infer for such situations? For example, what if I provide a control_image and prompt?

print(auto_pipe)

This is very convenient! However, I wonder if the user could restrict the level of info they want to see. I got a bit lost after the args started appearing. Maybe something to consider in the later iterations.

Misc:

  • Similar to get_execution_blocks(), would it make sense to provide a list_execution_blocks() method?
  • intermediates seems to be a very useful attribute that could benefit from some explicit documentation.

Now, I tried to use the SDXL refiner:

Code
import torch
from diffusers import ModularPipeline, StableDiffusionXLAutoPipeline
from diffusers.pipelines.components_manager import ComponentsManager

# Load models
components = ComponentsManager()
components.add_from_pretrained("stabilityai/stable-diffusion-xl-base-1.0", torch_dtype=torch.float16)

# Create pipeline
pipe = ModularPipeline.from_block(StableDiffusionXLAutoPipeline())
pipe.update_states(**components.components)
pipe.to("cuda")

# Run inference
prompt = "A majestic lion jumping from a big stone at night"
height = 1024
width = 1024
output = pipe(prompt=prompt, height=height, width=width, num_inference_steps=30)
images = output.intermediates.get("images").images
latents = output.intermediates.get("latents")
print(f"{latents.shape=}")
images[0].save("output_modular.png")

# Clear things
del components, pipe
torch.cuda.empty_cache()

# Load refiner
components = ComponentsManager()
components.add_from_pretrained("stabilityai/stable-diffusion-xl-refiner-1.0", torch_dtype=torch.float16)

# Create pipeline
pipe = ModularPipeline.from_block(StableDiffusionXLAutoPipeline())
pipe.update_states(**components.components)
pipe.to("cuda")
pipe.register_to_config(requires_aesthetics_score=False)

# Refine outputs.
output = pipe(prompt=prompt, image_latents=latents, num_inference_steps=30)
images = output.intermediates.get("images").images
images[0].save("output_refiner_modular.png")

It leads to:

ValueError: Model expects an added time embedding vector of length 2560, but a vector of 2816 was created. Please make sure to disable `requires_aesthetics_score` with `pipe.register_to_config(requires_aesthetics_score=False)` to make sure `target_size` (1024, 1024) is correctly used by the model.

Questions:

  • What am I doing wrong?
  • Is there a better way of using the refiner with modular diffusers? SDXL base and refiner share some components but it wasn't clear to me how to make it work with a workflow similar to the Diff-Diff one. Some guidance would be nice.

@yiyixuxu
Copy link
Collaborator Author

yiyixuxu commented Feb 11, 2025

@sayakpaul this is really good feedback! thank you!

for refiner, you have to do

refiner_pipeline.update_states(**components.get(["text_encoder_2","tokenizer_2", "vae", "scheduler"]), unet=components.get("refiner_unet"), force_zeros_for_empty_prompt=True, requires_aesthetics_score=True)

it is a bit verbose as you can see, and that's the case in general for how we load the ModularPipeline, so I'm wide open to suggestions for improvement in that aspect. One idea I want to play around with is to introduce a "collection" on the components manager (not a good name since it means something different on the Hub, but the idea is to allow users to operate on a group of model components at once, with some pipeline config attached to it) - will push a POC soon

auto_pipe.update_states(**components.components) -- should this be called auto_pipe.update_components()

open to better API, but probably not components because we also update config with it

What if the user combines the inputs that are supported? How do we infer for such situations? For example, what if I provide a control_image and prompt?

open to suggestions on how to do better here, currently each pipelineblock has a description attribute and it is up to the developer to document about workflows that are supported and their respective inputs

This is very convenient! However, I wonder if the user could restrict the level of info they want to see. I got a bit lost after the args started appearing. Maybe something to consider in the later iterations.

These are pretty important! We don't have to wait to improve in later iterations. Let's make it better now if it's possible. maybe we don't have to print out the docstring (the args etc), we can direct user to use .doc to get them?

@sayakpaul
Copy link
Member

Thanks Yiyi!

With your suggestion, I could successfully generate my first outputs powered by modular diffusers:

updated code
import torch
from diffusers import ModularPipeline, StableDiffusionXLAutoPipeline
from diffusers.pipelines.components_manager import ComponentsManager

# Load models
components = ComponentsManager()
components.add_from_pretrained("stabilityai/stable-diffusion-xl-base-1.0", torch_dtype=torch.float16)

# Create pipeline
pipe = ModularPipeline.from_block(StableDiffusionXLAutoPipeline())
pipe.update_states(**components.components)
pipe.to("cuda")

# Run inference
prompt = "A majestic lion jumping from a big stone at night"
height = 1024
width = 1024
output = pipe(prompt=prompt, height=height, width=width, num_inference_steps=30)
images = output.intermediates.get("images").images
latents = output.intermediates.get("latents")
print(f"{latents.shape=}")
images[0].save("output_modular.png")

# Clear things
del components, pipe
torch.cuda.empty_cache()

# Load refiner
components = ComponentsManager()
components.add_from_pretrained("stabilityai/stable-diffusion-xl-refiner-1.0", torch_dtype=torch.float16)

# Create pipeline
refiner_pipeline = ModularPipeline.from_block(StableDiffusionXLAutoPipeline())
refiner_pipeline.update_states(
    **components.get(["text_encoder_2", "tokenizer_2", "vae", "scheduler"]), 
    unet=components.get("unet"), 
    force_zeros_for_empty_prompt=True, 
    requires_aesthetics_score=True
)
refiner_pipeline.to("cuda")
# Refine outputs.
output = refiner_pipeline(prompt=prompt, image_latents=latents, num_inference_steps=30)
images = output.intermediates.get("images").images
images[0].save("output_refiner_modular.png")
(refiner vs. base output images)

it is a bit verbose as you can see, and that's the case in general for how we load the ModularPipeline, so I'm wide open to suggestions for improvement in that aspect. One idea I want to play around with is to introduce a "collection" on the components manager (not a good name since it means something different on the Hub, but the idea is to allow users to operate on a group of model components at once, with some pipeline config attached to it) - will push a POC soon

I am fine with verbosity if it teaches the user about how to correctly modify things. Maybe the error message could better reflect how to properly do the update_states() step if that seems feasible at all? Otherwise, it feels like guesswork (or perhaps I am not well-equipped to understand the flow yet).

open to better API, but probably not components because we also update config with it

Oh then. Then probably update_attributes()?

open to suggestions on how to do better here, currently each pipelineblock has a description attribute and it is up to the developer to document about workflows that are supported and their respective inputs

I noticed it after I commented that. I think this is sufficient for now (no strong opinions). Should we maybe enforce some kind of input validator (validate_inputs(), e.g.) so that different similar inputs don't interfere with each other's scopes?

These are pretty important! We don't have to wait to improve in later iterations. Let's make it better now if it's possible. maybe we don't have to print out the docstring (the args etc), we can direct user to use .doc to get them?

Perfect, this sounds very good!

@yiyixuxu
Copy link
Collaborator Author

@sayakpaul
I looked at the code you linked here; I think you don't need to remove the components manager and reload everything again.

# Loading Models
components = ComponentsManager()
components.add_from_pretrained("stabilityai/stable-diffusion-xl-base-1.0", torch_dtype=torch.float16)

# load just the refiner UNet (reuse the text_encoders that's already in components)
+ refiner_unet = UNet2DConditionModel.from_pretrained(
+     "stabilityai/stable-diffusion-xl-refiner-1.0", 
+     subfolder="unet", 
+     torch_dtype=torch.float16
+ )
+ components.add("refiner_unet", refiner_unet)
# this makes sure all models stay on cpu until a forward pass is invoked; they may be put back on cpu when more GPU memory is needed
+ components.enable_auto_cpu_offload()

# I think we don't need to do this:
# 1. pipe's states are managed by `components`; if we want to delete everything, delete components in components manager is enough
# 2. GPU memory is already managed by `components`, i.e. if we need more memory to run the refiner pipeline,
#    the other unet from the base repo will be offloaded to cpu.
#    We can also add methods to unload/delete models if more explicit control is needed, but overall I think we don't need to
#    delete a model unless we are certain we do not need it anymore
# 3. in this particular use case, we still need the text_encoders, so I don't recommend deleting them and reloading them again here
- # Clear components and free CUDA memory before loading refiner
- del components, pipe
- torch.cuda.empty_cache()
- 
- # Load complete refiner pipeline
- components = ComponentsManager()
- components.add_from_pretrained(
-     "stabilityai/stable-diffusion-xl-refiner-1.0", 
-     torch_dtype=torch.float16
- )

# Refiner Pipeline Setup
refiner_pipeline = ModularPipeline.from_block(StableDiffusionXLAutoPipeline())
refiner_pipeline.update_states(
    **components.get(["text_encoder_2", "tokenizer_2", "vae", "scheduler"]),
+   unet=components.get("refiner_unet"),  # Using explicitly loaded UNet
-   unet=components.get("unet"),  # Using UNet from complete pipeline
    force_zeros_for_empty_prompt=True,
    requires_aesthetics_score=True
)
Click to expand the code
import torch
from diffusers import ModularPipeline, StableDiffusionXLAutoPipeline, UNet2DConditionModel
from diffusers.pipelines.components_manager import ComponentsManager

# Load models
components = ComponentsManager()
components.add_from_pretrained("stabilityai/stable-diffusion-xl-base-1.0", torch_dtype=torch.float16)

refiner_unet = UNet2DConditionModel.from_pretrained("stabilityai/stable-diffusion-xl-refiner-1.0", subfolder="unet", torch_dtype=torch.float16)
components.add("refiner_unet", refiner_unet)
components.enable_auto_cpu_offload()

# Create pipeline
pipe = ModularPipeline.from_block(StableDiffusionXLAutoPipeline())
pipe.update_states(**components.components)
pipe.to("cuda")

# Run inference
prompt = "A majestic lion jumping from a big stone at night"
height = 1024
width = 1024
output = pipe(prompt=prompt, height=height, width=width, num_inference_steps=30)
images = output.intermediates.get("images").images
latents = output.intermediates.get("latents")
print(f"{latents.shape=}")
images[0].save("output_modular.png")


# Create pipeline
refiner_pipeline = ModularPipeline.from_block(StableDiffusionXLAutoPipeline())
refiner_pipeline.update_states(
    **components.get(["text_encoder_2", "tokenizer_2", "vae", "scheduler"]), 
    unet=components.get("refiner_unet"), 
    force_zeros_for_empty_prompt=True, 
    requires_aesthetics_score=True
)
refiner_pipeline.to("cuda")
# Refine outputs.
output = refiner_pipeline(prompt=prompt, image_latents=latents, num_inference_steps=30)
images = output.intermediates.get("images").images
images[0].save("output_refiner_modular.png")

can you help me:

  1. look into whether there is any benefit in deleting the models when switching workflows? In general, I think it is more efficient to offload them to cpu when you work with multiple workflows, but I want to see if there is any use case we missed
  2. how can we do better in docs for this?

@yiyixuxu
Copy link
Collaborator Author

yiyixuxu commented Feb 11, 2025

@sayakpaul
for the other comments

Oh then. Then probably update_attributes()?

could be just update too - I will keep this open since it will be very easy to change names later!

Should we maybe enforce some input validator (validate_inputs()

happy to explore this too, if you can share a POC that'd be great!

@sayakpaul
Copy link
Member

look into whether there is any benefit in deleting the models when switching workflows? In general, I think it is more efficient to offload them to cpu when you work with multiple workflows, but I want to see if there is any use case we missed

I think this is a valid assumption except for the situations where we don't have enough CPU RAM (48GBs might be low).

how can we do better in docs for this?

I think we could cover the refiner use case (and the like) under the theme of "reusing components between workflows". We could make it clear that to make the most out of reusing, it's recommended to first load all the components needed for the workflows users want to try out and keep them on CPU. Users will always have the option to load any ad-hoc component they may have forgotten in the beginning. If we can make this clear in the docs with examples, I think that should be enough. WDYT?

could be just update too - I will keep this open since it will be very easy to change names later!

Yeah update() is potentially simpler. SGTM!

happy to explore this too, if you can share a POC that'd be great!

Sure, happy to do that. I will branch off of this PR and try to open a PR. Would that work?

@asomoza
Copy link
Member

asomoza commented Feb 12, 2025

I finished testing and doing a PoC with the callbacks so I can update the step progress inside a UI. So, a question here about the implementation: since we now have the data object, I would love it if we could pass around the whole object instead of the current method (which I found restrictive), where we need to specify which variables we want to expose to the callbacks, but this won't be compatible with the current callbacks.

So I did this for the PoC to match current implementation:

                if data.callback_on_step_end is not None:
                    callback_kwargs = {}
                    for k in data.callback_on_step_end_tensor_inputs:
                        callback_kwargs[k] = getattr(data, k)
                    callback_outputs = data.callback_on_step_end(self, i, t, callback_kwargs)

                    data.latents = callback_outputs.pop("latents", data.latents)
                    data.prompt_embeds = callback_outputs.pop("prompt_embeds", data.prompt_embeds)
                    data.added_cond_kwargs["text_embeds"] = callback_outputs.pop("text_embeds", data.added_cond_kwargs["text_embeds"])
                    data.added_cond_kwargs["time_ids"] = callback_outputs.pop("time_ids", data.added_cond_kwargs["time_ids"])

but it could be something like this which is better to me:

                if data.callback_on_step_end is not None:
                    data.callback_on_step_end(self, i, t, data)

what are your thoughts on this @yiyixuxu?

@yiyixuxu
Copy link
Collaborator Author

@asomoza
the second one for sure! that's the point, callbacks should be super easy now

                if data.callback_on_step_end is not None:
                    data.callback_on_step_end(self, i, t, data)
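A callback under that interface could then just read (or tweak) whatever it needs from the block state. A hypothetical progress callback, assuming the state object exposes attributes the same way data does inside the blocks, might look like:

def on_step_end(block, step_index, timestep, data):
    # e.g. report progress to a UI; data can also be modified in place
    print(f"step {step_index} (t={timestep}): latents shape {tuple(data.latents.shape)}")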

@yiyixuxu
Copy link
Collaborator Author

@sayakpaul
sounds good!

Sure, happy to do that. I will branch off of this PR and try to open a PR. Would that work?

@DN6
Copy link
Collaborator

DN6 commented Feb 13, 2025

It's looking really nice. Obviously there are a lot of intricacies here that I might not have picked up, so in my initial pass I just tried to focus on parts that felt a little unclear to me.

I tried to break it down by the major components in Modular Diffusers.

Components Manager

My understanding here is that Components Manager is responsible for loading all models, schedulers, etc into the Modular Pipeline and performing memory management for the loaded components.

Where it felt a bit unintuitive was trying to determine which model repos can be used with add_from_pretrained and which ones cannot.

For example, this snippet will load all the components of the base SDXL pipeline into the Components Manager

# Load models
components = ComponentsManager()
components.add_from_pretrained(
    "stabilityai/stable-diffusion-xl-base-1.0", torch_dtype=torch.float16
)

But if I want to load a ControlNet model via a model repo, I cannot. I have to create the object and add it to the Components Manager via the add method. For example, this does not work:

components.add_from_pretrained("xinsir/controlnet-union-sdxl-1.0", torch_dtype=torch.float16)

Since I'm familiar with the library, I realise that this is following our existing Pipeline loading logic. But I think it might make sense to support adding individual model components through add_from_pretrained as well. We may need to introduce AutoModel logic or something to make it happen.

PipelineBlock

My understanding here is that a PipelineBlock is not expected to load any models, but instead only runs a computation step using the preloaded models in the ComponentsManager or perhaps some custom code. I also think that, from a user perspective, most people building with Modular will most likely be developing new block types.

The PipelineBlock is also meant to be stateless and all stateful operations are managed through the PipelineState or BlockState?

Let's say I want to add a PipelineBlock that has a model associated with the step. In the example below I want to create a block that automatically extracts a depth map from an image so that I can use it with a ControlNet.

Can I add the depth model to the ComponentManager from the block, in a manner similar to register_to_config? Or should I always add the model to the ComponentManager and then update the Block state? What is the correct way to create a block with an associated model?

class DepthBlock(PipelineBlock):
    @property
    def inputs(self) -> List[InputParam]:
        control_image = InputParam(
            name="control_image",
            required=True,
        )
        return [control_image]

    def __init__(self) -> None:
        super().__init__()
        # If I load a model in a pipeline block, is it possible to move it to the components manager?
        self.depth_preprocessor = DepthPreprocessor.from_pretrained("depth-anything/Depth-Anything-V2-Large-hf")

    def __call__(self, pipeline, state: PipelineState) -> PipelineState:
        data = self.get_block_state(state)
        control_image = data.control_image
        depth_image = self.depth_preprocessor(control_image)
        data.control_image = depth_image
        
        self.add_block_state(data, state)

        return pipeline, state 

When initializing PipelineBlocks we have this line

class StableDiffusionXLDecodeLatentsStep(PipelineBlock):
    expected_components = ["vae"]
    model_name = "stable-diffusion-xl"

And then in the __init__ we have

    def __init__(self):
        super().__init__()
        self.components["vae"] = None
        self.auxiliaries["image_processor"] = VaeImageProcessor(vae_scale_factor=8)

I found it a bit confusing as to why we are setting self.components["vae"] = None during the init of the PipelineBlock, because based on the class attribute it feels like it should be initialized with something? Additionally, self.components doesn't seem to be used anywhere in the __call__ so its application or use feels a bit unclear.

Are the class attributes at the top of the block needed? As far as I can tell from skimming the code, we operate on block instances everywhere? Can we define PipelineBlocks in the following way? IMO it's a bit more Pythonic and makes the Blocks feel a bit more like mini-Pipelines. You can also add type enforcement checks on the components. LMK if I'm missing something here.

class StableDiffusionXLTextEncoderStep(PipelineBlock):
	def __init__(
		self,
		text_encoder=None,
		text_encoder_2=None,
		tokenizer=None,
		tokenizer_2=None,
		force_zeros_for_empty_prompt=True,
	):
		super().__init__()

        # this would set expected_configs
        self.register_to_config(force_zeros_for_empty_prompt=force_zeros_for_empty_prompt)
        
        # this would set expected_components
        self.register_component({
            text_encoder=text_encoder,
            text_encoder_2=text_encoder_2,
            tokenizer=tokenizer,
            tokenizer_2=tokenizer_2
        })

Another thing I wasn't quite able to figure out was the exact scope of a PipelineBlock. Should it operate as a single atomic unit or be aware of the global pipeline methods that are available?

Here let's say we are encoding a prompt. In the example SDXLTextEncoderStep this is done in the following way

        (
            data.prompt_embeds,
            data.negative_prompt_embeds,
            data.pooled_prompt_embeds,
            data.negative_pooled_prompt_embeds,
        ) = pipeline.encode_prompt(
            data.prompt,
            data.prompt_2,
            data.device,
            1,
            data.do_classifier_free_guidance,
            data.negative_prompt,
            data.negative_prompt_2,
            prompt_embeds=None,
            negative_prompt_embeds=None,
            pooled_prompt_embeds=None,
            negative_pooled_prompt_embeds=None,
            lora_scale=data.text_encoder_lora_scale,
            clip_skip=data.clip_skip,
        )

The encode_prompt method is defined at the global pipeline level. So if I'm trying to understand the Block I need to hop back and forth between the block and StableDiffusionXLModularPipeline. What if I want to create a custom prompt encoding method for the pipeline? Should I define it inside the block? Or do I have to rewrite StableDiffusionXLModularPipeline with a new method?

Can encode_prompt be defined and executed inside the block itself. In this case, when you read the StableDiffusionXLTextEncoderStep you get a full understanding of what is happening. If you need to access the encoding method from the ModularPipeline instance, you could do something like

my_modular_pipe.pipeline_block['text_encoder_step'].encode_prompt()

I think Modular actually supports this workflow already.

Is it also considered bad practice to set components as attributes in the blocks and use them that way? Something like this?

	@torch.no_grad()
	def __call__(self, pipeline, state: PipelineState) -> PipelineState:
		# Get inputs and intermediates
		data = self.get_block_state(state)
		self.check_inputs(pipeline, data)
		prompt_embeds = self.text_encoder(data.prompt)

Regarding auxiliaries, is there a strong reason to not have these objects just be considered components as well?

Auto Workflow

I am a little apprehensive about introducing Auto workflows in V1. IMO it's better to let users get
used to the mechanics of using Modular manually before introducing any "magic". But I will leave this to your discretion.

Modular Pipeline, Block State, Pipeline State

I like these a lot and I'm pretty much aligned on how they work.

One small nit that is unrelated to the actual functionality (just putting out here for consideration)
Would prefer that we use block_state instead of data for this variable

	@torch.no_grad()
	def __call__(self, pipeline, state: PipelineState) -> PipelineState:
		# Get inputs and intermediates
		data = self.get_block_state(state)

Obviously the work here is very extensive and I'm still playing around with it. LMK if I've misunderstood some concepts or if I should open PRs to try and clarify any of these points.

@yiyixuxu
Copy link
Collaborator Author

@DN6

Thanks! This is super nice feedback! I'll address all of it, but I want to focus on PipelineBlock first because I think it is where most of the confusion comes from, and it indicates to me that this is where the most work needs to be done to improve it!

I just had enough time to think about these 2 aspects you mentioned: (1) the design choice of making pipeline blocks stateless and (2) the class attribute vs __init__ method confusion, so I will share my thoughts on them first!

1. Stateless Design Choice

Yes, in the current design, pipeline blocks (PipelineBlock, SequentialPipelineBlocks, and AutoPipelineBlocks) are meant to be "stateless". Their role is only to:

  • Define the computation steps
  • Specify the requirement (models/config/inputs/outputs) needed for these computation steps
  • To be used to compose with other pipeline blocks into a larger workflow

I like to think there are two stages in Modular diffusers:

  1. Composing Stage - at this stage, we do not need to worry about states; we do nothing but compose some "definitions". An example would be like this: when we combine blocks, we combine the computation steps/expected_components/configs/inputs/outputs from them, and the composed pipeline blocks remain stateless
# Define the depth block you were working on
class DepthBlock(PipelineBlock):
   ...
# another one for canny images
class CannyBlock(PipelineBlock):
   ...

# Combine these two into one with conditional logic
class AutoControlInputBlock(AutoPipelineBlocks):
    block_classes = [DepthBlock, CannyBlock]
    block_names = ["depth", "canny"]
    block_trigger_inputs = ["depth_image", "canny_image"]

# combine in sequential orders
class CompleteControlNetPipeline(SequentialPipelineBlocks):
    block_classes = [AutoControlInputBlock, PrepareLatentBlock, DenoiseBlock, DecodeBlock]
    block_names = ["control_input", "prepare", "denoise", "decode"]

you can keep composing for as long as you want, but once you're done and want to use it, we enter the "Runtime Stage", and that's when the pipeline blocks become stateful

  2. Runtime Stage - we use ModularPipeline to load models and run inference:
# Create Modularpipeline with the block you just made
controlnet_node = ModularPipeline.from_block(CompleteControlNetPipeline())

# Load models and components
controlnet_node.update_states(**components.components)

# Run inference
image = controlnet_node(control_image=my_image, prompt="a cat", output="images")

I made pipeline blocks stateless since model loading isn't needed during composition - it's only required at runtime.

The design you proposed here would make pipeline blocks stateful. That means each pipeline block will need to manage its model components itself, and you will have to load models into each pipeline block and then compose them somehow. It is a possible alternative design, but I think it might need a different system to support it, and it is more complex.

class StableDiffusionXLTextEncoderStep(PipelineBlock):
    def __init__(
        self,
        text_encoder=None,
        text_encoder_2=None,
        tokenizer=None,
        tokenizer_2=None,
        force_zeros_for_empty_prompt=True,
    ):
        super().__init__()

        # this would set expected_configs
        self.register_to_config(force_zeros_for_empty_prompt=force_zeros_for_empty_prompt)
        
        # this would set expected_components
        self.register_component({
            text_encoder=text_encoder,
            text_encoder_2=text_encoder_2,
            tokenizer=tokenizer,
            tokenizer_2=tokenizer_2
        })

2. Component Initialization and Class Attributes

About your comment on Component Initialization here:

I found it a bit confusing as to why we are setting self.components["vae"] = None during the init of the PipelineBlock, because based on the class attribute it feels like it should be initialized with something? Additionally, self.components doesn't seem to be used anywhere in the __call__ so its application or use feels a bit unclear.

Are the class attributes at the top of the block needed? As far as I can tell from skimming the code, we operate on block instances everywhere?

I totally agree that it is very confusing that we have both expected_components as a class attribute and also this self.components[x] = None thing during __init__.

The class attribute expected_components is currently used; one example is here

for component_name in self.expected_components:

I like to think these expected_* class attributes act like the signature of the __init__ method in our current pipelines: they are just definitions of what's needed, and we need to compose the requirements from all blocks before instantiation.

However, I don't think we need both the class attribute expected_components and self.components (I actually have a note here about this too). I think either one of the two would be sufficient, i.e. we can infer expected_components from self.components or the other way around.

I think it might be better to remove the __init__ method altogether for pipeline blocks and only keep these class attributes, given that the pipeline blocks are stateless. Also, we can expand on expected_components, e.g. adding something like a ComponentSpec that mirrors our model_index.json structure. For the DepthBlock example that you provided, we could have something like this:

class DepthBlock(PipelineBlock):
    expected_components = [
        ComponentSpec(
            name="depth_processor",
            class_name=["depth_anything", "DepthPreprocessor"],
            default_repo="depth-anything/Depth-Anything-V2-Large-hf"
        )
    ]
    
    @property
    def inputs(self) -> List[InputParam]:
        return [InputParam(
            name="control_image",
            required=True,
        )]

    def __call__(self, pipeline, state: PipelineState) -> PipelineState:
        data = self.get_block_state(state)
        depth_image = pipeline.depth_processor(data.control_image)
        data.control_image = depth_image
        self.add_block_state(data, state)
        return pipeline, state

This way, we would also be able to support the use case you described here:

let's say I want to add a PipelineBlock that has a model associated with the step. In the example below I want to create a block that automatically extracts a depth map from an image so that I can use it with a ControlNet.
Can I add the depth model to the ComponentManager from the block, in a manner similar to register_to_config? Or should I always add the model to the ComponentManager and then update the Block state? What is the correct way to create a block with an associated model?

currently, indeed, you would always have to add the models to the ComponentsManager and then update the state of the ModularPipeline after you create it. With the ComponentSpec, we still won't be able to associate a loaded model with a block, but we will be able to associate information related to the model and model loading with the pipeline block, and we can automatically load them when we create the ModularPipeline later in the runtime stage, similar to the behavior of current pipelines.

What do you think?

@yiyixuxu
Copy link
Collaborator Author

yiyixuxu commented Feb 16, 2025

@DN6
continuing to answer your questions!

Scope of Pipeline Block Methods

Regarding your questions about PipelineBlock scope and global pipeline methods, here:

Another thing I wasn't quite able to figure out was the exact scope of a PipelineBlock. Should it operate as a single atomic unit or be aware of the global pipeline methods that are available?

The encode_prompt method is defined at the global pipeline level. So if I'm trying to understand the Block I need to hop back and forth between the block and StableDiffusionXLModularPipeline. What if I want to create a custom prompt encoding method for the pipeline? Should I define it inside the block? Or do I have to rewrite StableDiffusionXLModularPipeline with a new method?

Can encode_prompt be defined and executed inside the block itself. In this case, when you read the StableDiffusionXLTextEncoderStep you get a full understanding of what is happening.

Yes, you can define methods at the pipeline block level. Currently, we have two places where methods can live:

  1. custom methods on Pipeline blocks
  2. methods on Pipeline level
    • an example is the encode_prompt example you mentioned
    • I made them global pipeline methods for one reason only: to be able to use #Copied from directly and to minimize maintenance cost for us 😛 If you prefer to move these methods to the pipeline block level so you don't need to hop back and forth, we can totally look into that!

components as attributes in blocks

regarding this question

Is it also considered bad practice to set components as attributes in the blocks and use them that way?

yes, with the current design, it would be bad practice since pipeline blocks are stateless, and all the model components should be managed at the global pipeline level and passed to each pipeline block at runtime through the pipeline parameter

If you think a stateful pipeline block design is more intuitive, I'd be happy to explore that with you too :)

A few things to keep in mind if we want to explore the alternative stateful design:

  • we should still compose the pipeline first before loading any models. Loading models for each pipeline block is unnecessarily inconvenient for users and pretty complex for us to handle, e.g. the user may load one vae in the encoder block and then a different one in decoder block, and we have to somehow choose one when we compose the blocks.
  • Related to the first point, I think it is okay to make pipeline blocks stateful (i.e., to set components in blocks). However, components need to be managed in the global pipeline scope. I just mentioned the example with the vae: you should not have a different vae in the encoder vs the decoder, and you should not have a different image processor for image preprocessing vs post-processing. And the user should not be able to update only the vae in the encoder and not the one in the decoder....
  • encode_prompt is one of the few examples that can be run standalone, but not all blocks should be. Pipeline blocks are just like Lego bricks: some are meaningful on their own, but mainly they are meant to be used with others to compose something meaningful. It is a bit hard to decompose our workflows into completely atomic units, and logic is sometimes intertwined between blocks, e.g. even though we don't use the unet until much later, during the denoise loop, we may need its config info for image preprocessing in a very early step

@yiyixuxu
Copy link
Collaborator Author

@DN6

now for Auto Workflow

I agree it is probably not that important for our current diffusers users, but I consider it crucial for the UI use case. Since one of our goals is to eliminate the barrier between us and the UI community/professionals, I think it makes sense for us to release with it

let me explain a bit! Auto Workflow fits really, really well with how workflows are developed. Alvaro's guides (for example, https://huggingface.co/blog/OzzyGT/outpainting-differential-diffusion) give a pretty good sense of the process. It is usually an iterative process: the user does not necessarily know exactly what's needed in the beginning, so they start with something basic and gradually add/remove features and modify parts of the workflow until they get satisfactory results. Without auto workflows, they'd have to rebuild their workflow each time they want to try something different. That is not a very nice experience. With Auto Workflows (a node built with an auto workflow), they can pretty much just stick to the same node and just change the input nodes as needed.

also, there is the number of nodes. comfy currently faces the challenge that there are too many nodes and it's a bit overwhelming for users. without auto workflow, we'd have the same issue. with auto workflow, we currently have like 5 nodes (prompt_encode/image_encode/decode/denoise/ip-adapter), so it is very manageable

I think maybe we can have different guides targeting different user groups and only talk about auto workflows in the ones targeting UI/professionals

@a-r-r-o-w
Member

I'm better understanding some of the things here after working with it for a bit. I'll try to provide some general thoughts and introduce some ideas I had when you were initially starting the modular diffusers development:

  • No strong opinions on whether the PipelineBlocks should be stateful or not. We could ideally support both cases similar to what's done in the Diffusers Hooks.

  • Each PipelineBlock IMO should only contain a minimal implementation, i.e. bare-minimum single functionality, and not handle too many overlapping cases. For example, pipeline.encode_prompt and similar methods that operate on both unconditional and conditional branches should probably just support a single prompt. The pipeline block can then invoke this method twice - once for positive, once for negative.

  • If methods like encode_prompt could have a functional equivalent that can be invoked from outside a pipeline/pipeline-blocks, I think it would be super helpful for re-using in trainers instead of rolling our own minified implementation.

  • We should consider batching vs non-batching inference. Currently with existing pipelines, we always batch negative and positive prompt embeds. This increases the memory required for intermediate activations by 2x. For a low VRAM mode, this might be an important consideration. (It's not very important. We can always add a BatchedInferenceHook or something to the model::forward to split the args/kwargs along the batch dimension)

  • Currently, the invocation mode is eager. Something like:

    I_AM_AT_BLOCK_X -> DO_I_HAVE_THE_INPUTS_I_REQUIRE? ---> YES ---> PERFORM_COMPUTATION_AND_PROCEED_TO_NEXT_BLOCK
                                                       |--> NO  ---> RAISE_ERROR
    

    If we're somewhere deep inside the execution stage and then error out (maybe due to a missing input), all computation done till now is lost for a silly error. This is very frustrating (I've personally faced it multiple times during model integrations). IMO we have an opportunity to improve this (perhaps, some time in the near future if not for now). Since we already know that each block requires a set of inputs and outputs, regardless of what the other blocks do, we can topologically traverse the graph of blocks in reverse to determine if all inputs/outputs mapping is correct. If not, we can early-error out and let the user know. If yes, we can proceed with computation.
    Note that this won't help identify issues in cases where we simply forgot to pass an input to a model or something, but it'll be helpful in block-development cases -- we're simply doing a static analysis to make sure that the invocation graph makes sense on a high-level from the pipeline one creates.

  • Regarding _execution_device and dtype on pipeline, I think we should remove it and instead infer device/dtype from the module that is going to do the processing next. For example, if my text encoder is in float16 but transformer is in bfloat16, dtype on pipeline will return float16. So, prompt embeds will be in different dtype leading to error on transformer unless we explicitly write some logic to handle this in the pipeline. Writing it per-model block is prone to errors and can introduce lossy conversions, so it might be nice to push to keep the pipeline as a simple container holding modules and remove any notion of module state from it, and handle these device/dtype-changes more centrally (like pipeline.prepare_inputs_for_model(model, inputs)) (just my thoughts and not really at issue here)

  • Have we thought about how a pipeline created by a user can be shared via an exported file, say on the Hub, for ease of distribution?

return noise_cfg


class CFGGuider:
@a-r-r-o-w
Member

TL;DR: let's try to separate algorithms from the modeling/pipeline implementations as much as possible. If we can decouple CFG nicely, I believe we would have a lot more composability and options for testing. Let's try to write these in a manner that works with existing pipelines too if we invoke __call__ with a guider object.

I like this design. For some time now, I've wanted to add support for different guidance techniques (STG, Perturbed Attention guidance, Energy-based CFG, Skip layer guidance, etc.) to all existing models/pipelines where applicable. As I'm working on something similar, I'll share some thoughts.

These techniques are independent of the model/pipeline, so it makes sense to me that we should not tie that logic too strongly to the pipelines. At the moment, our pipelines only accept parameters like guidance_scale, guidance_rescale, true_cfg_scale, and similar. This is not really scalable if we want composability while supporting the latest research techniques. So, this design of being able to initialize "guiders" is super cool, since we can parameterize them however we want and since it's decoupled from the pipeline's __call__ and the model forward itself.

To provide some more details of what I've been trying, here is some pseudo-code:

import inspect
from typing import Any, List, Optional, Union

import torch

from diffusers.hooks import HookRegistry, PerturbedAttentionGuidanceHook

class GuidanceMixin:
    def register_modules(self, denoiser: torch.nn.Module, ...) -> None:
        ...

    def unregister_modules(self, denoiser: torch.nn.Module, ...) -> None:
        ...

    def prepare_inputs(self, **kwargs) -> Any:
        parameters = inspect.signature(self._prepare_inputs).parameters
        ignored_kwargs = {k for k in kwargs.keys() if k not in parameters}
        input_kwargs = {k: v for k, v in kwargs.items() if k in parameters}
        return self._prepare_inputs(**input_kwargs)

    def __call__(self, **kwargs) -> Any:
        parameters = inspect.signature(self.forward).parameters
        ignored_kwargs = {k for k in kwargs.keys() if k not in parameters}
        input_kwargs = {k: v for k, v in kwargs.items() if k in parameters}
        return self.forward(**input_kwargs)

    def _prepare_inputs(self, **kwargs) -> Any:
        raise NotImplementedError


class ClassifierFreeGuidance(GuidanceMixin):
    def __init__(self, scale: float) -> None:
        self.scale = scale

    def _prepare_inputs(self, latents: torch.Tensor, prompt_embeds: torch.Tensor, negative_prompt_embeds: Optional[torch.Tensor] = None, generator: Optional[torch.Generator] = None) -> torch.Tensor:
        if self.scale > 1.0:
            latents = torch.cat([latents, torch.zeros_like(latents).normal_(generator=generator)])
            prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds])
        return {"latents": latents, "prompt_embeds": prompt_embeds}

    def forward(self, x_uncond: torch.Tensor, x_cond: torch.Tensor) -> torch.Tensor:
        return x_uncond + self.scale * (x_cond - x_uncond)


class PerturbedAttentionGuidance(GuidanceMixin):
    def __init__(self, scale: float, cfg_scale: float, layers: Union[str, List[str]]) -> None:
        self.scale = scale
        self.cfg_scale = cfg_scale
        self.layers = [layers] if isinstance(layers, str) else layers

    def register_modules(self, denoiser: torch.nn.Module, ...) -> None:
        for name, submodule in denoiser.named_modules():
            if any(regex_match(name, layer_name) for layer_name in self.layers):
                registry = HookRegistry.check_if_exists_or_initialize(submodule)
                hook = PerturbedAttentionGuidanceHook()
                registry.register_hook(hook)

    def prepare_inputs(self, latents: torch.Tensor, prompt_embeds: torch.Tensor, negative_prompt_embeds: Optional[torch.Tensor] = None, generator: Optional[torch.Generator] = None) -> torch.Tensor:
        num_additional_latents = (self.scale > 1.0) + (self.cfg_scale > 1.0)
        if num_additional_latents > 0:
            additional_latents = [torch.zeros_like(latents).normal_(generator=generator) for _ in range(num_additional_latents)]
            latents = torch.cat([latents, *additional_latents])
            ... # Similarly handle prompt embeddings
        return ...

    def forward(self, x_uncond: torch.Tensor, x_cond: torch.Tensor) -> torch.Tensor:
        ...

from diffusers import FluxPipeline
from diffusers.guidance import ClassifierFreeGuidance, PerturbedAttentionGuidance

pipe = FluxPipeline.from_pretrained(...)
pipe.to("cuda")

cfg = ClassifierFreeGuidance(scale=7.0)
pag = PerturbedAttentionGuidance(scale=5.0, cfg_scale=7.0, layers=[r"transformer_blocks\.(20|24)"])

cfg_output = pipe(..., guidance=cfg)
pag_output = pipe(..., guidance=pag)

In the existing pipelines, we will invoke the prepare_inputs and __call__ methods in a non-backwards-breaking manner. For the new modular diffusers, we can customize as required. As the guidance objects are lightweight to create, one can modify them on the fly, which would be super useful for UI cases and experimentation.

A pet peeve I have is needing to write additional attention processors for a method like PAG. Per-model processors are hard to maintain for all kinds of techniques available, with all kinds of permutations possible. This introduces limitations. Since we know that most modeling implementations use our Attention class, or at least follow similar naming conventions, one way of making this technique generally applicable is utilizing some sort of pre/post-forward hook that can perform the attention-branch shortcut required in PAG. This would be a single addition that addresses all models at once, because we follow fairly strict layer naming conventions.
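
For concreteness, the registration side of that idea could be as simple as plain PyTorch forward pre-hooks - this is only a sketch with illustrative names, and the actual PAG attention-branch shortcut inside the hook is omitted:

import re

import torch

def register_attention_pre_hooks(denoiser: torch.nn.Module, layer_patterns, pre_hook):
    # Attach a forward pre-hook to every submodule whose name matches one of the
    # given regex patterns (e.g. specific transformer/attention blocks).
    handles = []
    for name, submodule in denoiser.named_modules():
        if any(re.search(pattern, name) for pattern in layer_patterns):
            handles.append(submodule.register_forward_pre_hook(pre_hook))
    return handles  # keep the handles around so the hooks can be removed later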

As guiders can be stateful (for example, disabling guidance after a certain number of steps should remove the unconditional latent/prompt embeddings, or the guidance scale could adapt to the amount of low-frequency/high-frequency noise in the latent), I really like that we can do reset_guider. IMO, we should mark this as stateful/un-stateful using a flag like _is_stateful = True.

@yiyixuxu
Collaborator Author

@a-r-r-o-w feel free to take over the guider and refactor it :)

@yiyixuxu
Collaborator Author

thanks @a-r-r-o-w! insightful as always:)

Each PipelineBlock IMO should only contain a minimal implementation, i.e. bare-minimum single functionality, and not handle too many overlapping cases. For example, pipeline.encode_prompt and similar methods that operate on both unconditional and conditional branches should probably just support a single prompt. The pipeline block can then invoke this method twice - once for positive, once for negative.

If methods like encode_prompt could have a functional equivalent that can be invoked from outside a pipeline/pipeline-blocks, I think it would be super helpful for re-using in trainers instead of rolling our own minified implementation.

not sure I understand exactly what "Each PipelineBlock IMO should only contain minimal implementation" means, but based on the example you provided I think we are aligned. @hlky is working on a refactor of some of the pipeline methods to do just what you described. we are also considering making them class methods so they can be invoked outside of pipelines/pipeline blocks. Please take a look there and share your thoughts! #10726
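
to sketch the kind of functional form being discussed (this is only an illustration, not the actual #10726 refactor; the tokenizer/text_encoder calls follow the standard transformers CLIP API):

import torch

def encode_prompt(text_encoder, tokenizer, prompt: str, device: torch.device) -> torch.Tensor:
    # Handles exactly one branch: a single prompt string in, one embedding tensor out.
    tokens = tokenizer(prompt, padding="max_length", truncation=True, return_tensors="pt")
    with torch.no_grad():
        return text_encoder(tokens.input_ids.to(device))[0]

# A block (or a trainer) can then call it once per branch:
# prompt_embeds = encode_prompt(text_encoder, tokenizer, prompt, device)
# negative_prompt_embeds = encode_prompt(text_encoder, tokenizer, negative_prompt, device)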

We should consider batching vs non-batching inference. Currently with existing pipelines, we always batch negative and positive prompt embeds. This increases the memory required for intermediate activations by 2x. For a low VRAM mode, this might be an important consideration. (It's not very important. We can always add a BatchedInferenceHook or something to the model::forward to split the args/kwargs along the batch dimension)

the hook approach sounds good, or a special optimized denoising block. feel free to explore; it could be part of the offloading strategy we offer on the components manager, e.g. if the user does not have enough memory, we automatically run non-batched inference.
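
to illustrate the non-batched path, here is a sketch of what such a denoise step could do (not an actual block; the call follows the standard UNet2DConditionModel signature):

def cfg_noise_pred_unbatched(unet, latents, t, prompt_embeds, negative_prompt_embeds, guidance_scale):
    # Two smaller forward passes instead of one batched [uncond, cond] pass,
    # so only one branch's intermediate activations are alive at a time.
    noise_uncond = unet(latents, t, encoder_hidden_states=negative_prompt_embeds, return_dict=False)[0]
    noise_text = unet(latents, t, encoder_hidden_states=prompt_embeds, return_dict=False)[0]
    return noise_uncond + guidance_scale * (noise_text - noise_uncond)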

Currently, the invocation mode is eager. Something like:

I_AM_AT_BLOCK_X -> DO_I_HAVE_THE_INPUTS_I_REQUIRE? ---> YES ---> PERFORM_COMPUTATION_AND_PROCEED_TO_NEXT_BLOCK
|--> NO ---> RAISE_ERROR
If we're somewhere deep inside the execution stage and then error out (maybe due to a missing input), all computation done till now is lost for a silly error. This is very frustrating (I've personally faced it multiple times during model integrations). IMO we have an opportunity to improve this (perhaps, some time in the near future if not for now). Since we already know that each block requires a set of inputs and outputs, regardless of what the other blocks do, we can topologically traverse the graph of blocks in reverse to determine if all inputs/outputs mapping is correct. If not, we can early-error out and let the user know. If yes, we can proceed with computation.
Note that this won't help identify issues in cases where we simply forgot to pass an input to a model or something, but it'll be helpful in block-development cases -- we're simply doing a static analysis to make sure that the invocation graph makes sense on a high-level from the pipeline one creates.

actually, we are already doing that. when we combine a few pipeline blocks in sequential order, we loop through the blocks to work out the overall intermediates_inputs with logic similar to what you described. see the code here:

def intermediates_inputs(self) -> List[str]:

basically, say we have 3 blocks we want to combine in sequential order:

combined_block: block1 -> block2 -> block3 

each block has inputs/intermediates_inputs/intermediates_outputs, and we want to find the intermediates_inputs for combined_block

we loop through the blocks:

  • for block1, all of its intermediates_inputs are added to combined_block's intermediates_inputs
  • for block2, we look at all of its intermediates_inputs and add them to combined_block's intermediates_inputs only if they are not already an output of block1
  • for block3, we look at all of its intermediates_inputs and add them to combined_block's intermediates_inputs only if they are not already outputs of block1 or block2

once we have this intermediates_inputs info for the combined_block, we use it to build the argument list and docstring for the combined_block and to help the user figure out what inputs to pass - for example, if block3 requires a latents input and none of the previous blocks outputs latents, latents will end up in the intermediates_inputs field of the combined_block and the user needs to provide it -> the developer should also be able to deduce that there is a bug in their implementation if latents is in fact not an intended user input.
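
the accumulation logic is roughly the following (a simplified sketch with illustrative names, not the actual implementation):

def combined_intermediates_inputs(blocks):
    required = []     # intermediates the combined block still needs from the user
    produced = set()  # intermediates produced by earlier blocks
    for block in blocks:
        for name in block.intermediates_inputs:
            # only surface an input if no earlier block already produced it
            if name not in produced and name not in required:
                required.append(name)
        produced.update(block.intermediates_outputs)
    return required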

I think we can add a check_input function for that and throw an error if any required intermediates_inputs are not passed

happy to work a bit more on this with you! I think it is a very important feature. I can start adding some test cases for the things we already cover - and you can help check whether we missed any use cases. what do you think?

Regarding _execution_device and dtype on pipeline, I think we should remove it and instead infer device/dtype from the module that is going to do the processing next. For example, if my text encoder is in float16 but transformer is in bfloat16, dtype on pipeline will return float16. So, prompt embeds will be in different dtype leading to error on transformer unless we explicitly write some logic to handle this in the pipeline. Writing it per-model block is prone to errors and can introduce lossy conversions, so it might be nice to push to keep the pipeline as a simple container holding modules and remove any notion of module state from it, and handle these device/dtype-changes more centrally (like pipeline.prepare_inputs_for_model(model, inputs)) (just my thoughts and not really at issue here)

agree, feel free to help refactor it later!
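
for illustration, the kind of centrally handled helper being discussed could look something like this (only the name prepare_inputs_for_model comes from the suggestion above; the rest is a sketch):

import torch

def prepare_inputs_for_model(model: torch.nn.Module, inputs: dict) -> dict:
    # Infer the target device/dtype from the module that will consume the inputs
    # next, instead of from a pipeline-level property.
    param = next(model.parameters())
    prepared = {}
    for name, value in inputs.items():
        if torch.is_tensor(value) and value.is_floating_point():
            prepared[name] = value.to(device=param.device, dtype=param.dtype)
        else:
            prepared[name] = value
    return prepared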

Have we thought about how a pipeline created by a user can be shared via an exported file, say on the Hub, for ease of distribution?

I think we should share via the Hub but haven't explored that yet - feel free to take a stab at it!

@DN6
Collaborator

DN6 commented Feb 19, 2025

@yiyixuxu
re: using ComponentSpec in expected_components; yeah, I like this approach 👍🏽

one caveat is that since pipeline blocks are "stateless", you have to follow the (pipeline, data) parameter pattern for these custom methods, the same way you do for the __call__ method on the blocks. data contains the runtime state (inputs/outputs/intermediates) and pipeline contains the models/config etc
methods on Pipeline level

I think this is fine.

I made them global pipeline methods for one reason only: to be able to use #Copied from directly and to minimize the maintenance cost for us 😛 If you prefer to move these methods to the pipeline block level so you don't need to hop back and forth, we can totally look into that!

I do prefer that. One case I can think of: if I replace a step in a pipeline, e.g. the encode prompt step, and then call pipe.encode_prompt, I would end up with a different output than I might expect, since it would be using the "default" encode prompt method of the pipeline and not the one I've attempted to replace it with. Although nothing about the current implementation prevents using methods at the block level, it might be good to define some rule about the relationship between block-level methods and pipeline-level methods. My intuition would be that the global pipeline methods should use the block-level methods?

Related to the first point, I think it is okay to make pipeline blocks stateful (i.e., to set components in blocks). However, components need to be managed in the global pipeline scope.

I'm cool with having the components be managed at the global level. I agree it would get complicated if the components are attached to blocks. I think I was trying to convey that register_component in the block would add the component to the global context:

from diffusers import AutoencoderKL, ModularPipeline

class MyPipelineBlock:

    def __init__(self, vae):
        self.register_component(vae=vae)

vae = AutoencoderKL.from_pretrained("..")
pipe = ModularPipeline.from_block(MyPipelineBlock(vae=vae))

# these would point to the same object
pipe.vae == pipe.blocks("vae_step").vae

But the ComponentSpec solution also works for this case 👍🏽 And the other points, such as not being able to use different VAEs at different steps, make sense.

If you think a stateful pipeline block design is more intuitive, I'd be happy to explore that with you too :)

I think nothing in the current design prevents stateful blocks, though. I think we need a bit more clarity on how to create/manage them correctly, e.g. in the DepthBlock solution you proposed:

class DepthBlock(PipelineBlock):
    expected_components = [
        ComponentSpec(
            name="depth_processor",
            class_name=["depth_anything", "DepthPreprocessor"],
            default_repo="depth-anything/Depth-Anything-V2-Large-hf"
        )
    ]
    
    @property
    def inputs(self) -> List[InputParam]:
        return [InputParam(
            name="control_image",
            required=True,
        )]

    def __call__(self, pipeline, state: PipelineState) -> PipelineState:
        data = self.get_block_state(state)
        depth_image = pipeline.depth_processor(data.control_image)
        data.control_image = depth_image
        self.add_block_state(data, state)
        return pipeline, state

If we're creating the DepthProcessor object, how do we know how to import it if it's from a different library or if the user has defined it in the same file as the DepthBlock? Through the class_name arg?

Another thought I had was: suppose a user has created a custom pipeline block with a model component (config and weights) plus custom code, and is hosting both on the Hub. Would we allow something like PipelineBlock.from_pretrained to load it the way we do for custom pipelines?

@yiyixuxu
Collaborator Author

yiyixuxu commented Feb 21, 2025

@DN6

I do prefer that. One case I can think of: if I replace a step in a pipeline, e.g. the encode prompt step, and then call pipe.encode_prompt, I would end up with a different output than I might expect, since it would be using the "default" encode prompt method of the pipeline and not the one I've attempted to replace it with. Although nothing about the current implementation prevents using methods at the block level, it might be good to define some rule about the relationship between block-level methods and pipeline-level methods. My intuition would be that the global pipeline methods should use the block-level methods?

Let me try to move all the global pipeline methods to the block level first - if I'm able to do that, I think maybe we won't need global pipeline methods at all, and things would be easier

If we're creating the DepthProcessor object, how do we know how to import it if it's from a different library or if the user has defined it in the same file as the DepthBlock? Through the class_name arg?

I was thinking of something similar to model_index, so we would reuse the same approach (and potentially code) to handle it. A different library should not be a problem (like we do for text_encoders from transformers); and if it's defined in the same file, maybe we can do something similar to what we do for the diffusers modules that we cannot import from the top level, like here.

I haven't really thought it through, though; if you have good suggestions, let me know!!
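
purely for illustration, the resolution could look like what we already do for model_index.json entries (a hypothetical helper, nothing here is final):

import importlib

def resolve_component_class(library: str, class_name: str):
    # e.g. resolve ("transformers", "CLIPTextModel") or ("depth_anything", "DepthPreprocessor")
    module = importlib.import_module(library)
    return getattr(module, class_name)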

Another thought I had was: suppose a user has created a custom pipeline block with a model component (config and weights) plus custom code, and is hosting both on the Hub. Would we allow something like PipelineBlock.from_pretrained to load it the way we do for custom pipelines?

I'm not sure how custom code would work for now (i.e. how we would share the code on the Hub and load it), let me know if you have good ideas! but yes, I think we should add a loading method to pipeline blocks!

we should also allow attaching a components manager to pipeline blocks so that anything loaded from a pipeline block gets registered with the components manager

adding a loading method on pipeline blocks would also support the use case you described earlier in #9672 (comment)

But if I want to load a ControlNet Model via a model repo I cannot. I have to create the object and add to Components Manager via the add method.

instead of this (we could still support this in the future after we have the AutoModel class)

components.add_from_pretrained("xinsir/controlnet-union-sdxl-1.0", torch_dtype=torch.float16)

we could already do something like this:

control_block.add_from_pretrained(repo, components_manager=components)

because we will now add class info to expected_components, control_block has all the info it needs to load the model, and we can automatically register it with the components manager for memory optimization later
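
to sketch the idea (hypothetical code: the ComponentSpec fields, the add method on the components manager, and resolve_component_class from the earlier sketch are all assumptions here, not the final API):

def add_from_pretrained(block, repo_id, components_manager=None, **kwargs):
    # Use each ComponentSpec's class info to load the component from repo_id,
    # then register it with the components manager if one is given.
    for spec in block.expected_components:
        component_cls = resolve_component_class(*spec.class_name)
        component = component_cls.from_pretrained(repo_id, **kwargs)
        if components_manager is not None:
            components_manager.add(spec.name, component)
        setattr(block, spec.name, component)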

@exdysa

exdysa commented Feb 21, 2025

Hi. I've been watching this project unfold for some time now, and I've attempted some orthogonal modular diffusers projects in the past. I'm deeply interested in reviewing, researching, responding, reciprocating, etc. during the creation process if possible, especially with regards to integrating Auto Workflow elements, expected_components and ComponentSpec, and the previously described "magic" to close the barrier between 'us' (who I presume to be developers) and the 'ui/professionals' (who I presume to be non-developers/the coding disinclined). I have a lot to say but I want to read and learn more beforehand, so I'm curious how and where to ask.

Labels
roadmap Add to current release roadmap
Projects
Status: Future Release