# Copyright Philip Brown, ppbrown@github
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

###########################################################################
# This pipeline uses a model that combines an SDXL VAE, a T5 text encoder,
# and an SDXL UNet.
# At the present time (2025/06/10) there are no pretrained models that give
# pleasing output, so this pipeline is somewhat of a tech demo proving that
# the pieces can at least be put together.
# Hopefully, it will encourage someone with the hardware available to
# throw enough resources into training one up.


from typing import Optional

import torch.nn as nn
from transformers import (
    CLIPImageProcessor,
    CLIPTokenizer,
    CLIPVisionModelWithProjection,
    T5EncoderModel,
)

from diffusers import DiffusionPipeline, StableDiffusionXLPipeline
from diffusers.image_processor import VaeImageProcessor
from diffusers.models import AutoencoderKL, UNet2DConditionModel
from diffusers.schedulers import KarrasDiffusionSchedulers


# Note: at this time, the intent is to use the T5 encoder named below
# with zero changes.
# Therefore, the model deliberately does not store its own copy of the
# T5 encoder weights (since they are not unique!) but instead takes
# advantage of Hugging Face Hub cache loading.

T5_NAME = "mcmonkey/google_t5-v1_1-xxl_encoderonly"

# The caller is expected to load this, or an equivalent, as the model name
# for now, eg: pipe = StableDiffusionXL_T5Pipeline.from_pretrained(SDXL_NAME)
SDXL_NAME = "stabilityai/stable-diffusion-xl-base-1.0"
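
# A minimal usage sketch (illustrative only; it assumes the standard diffusers
# from_pretrained() loading path and an available CUDA device, neither of which
# is guaranteed by this file):
#
#   import torch
#   pipe = StableDiffusionXL_T5Pipeline.from_pretrained(
#       SDXL_NAME, torch_dtype=torch.float16
#   ).to("cuda")
#   image = pipe("a watercolor painting of a lighthouse").images[0]
#   image.save("lighthouse.png")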


# Parts of the diffusers pipeline machinery read a ``dtype`` attribute from
# registered modules (ModelMixin and transformers models provide one).
# A plain nn.Linear does not, so this thin subclass reports the weight dtype.
class LinearWithDtype(nn.Linear):
    @property
    def dtype(self):
        return self.weight.dtype


class StableDiffusionXL_T5Pipeline(StableDiffusionXLPipeline):
    _expected_modules = [
        "vae",
        "unet",
        "scheduler",
        "tokenizer",
        "image_encoder",
        "feature_extractor",
        "t5_encoder",
        "t5_projection",
        "t5_pooled_projection",
    ]

    _optional_components = [
        "image_encoder",
        "feature_extractor",
        "t5_encoder",
        "t5_projection",
        "t5_pooled_projection",
    ]

    def __init__(
        self,
        vae: AutoencoderKL,
        unet: UNet2DConditionModel,
        scheduler: KarrasDiffusionSchedulers,
        tokenizer: CLIPTokenizer,
        t5_encoder=None,
        t5_projection=None,
        t5_pooled_projection=None,
        image_encoder: CLIPVisionModelWithProjection = None,
        feature_extractor: CLIPImageProcessor = None,
        force_zeros_for_empty_prompt: bool = True,
        add_watermarker: Optional[bool] = None,
    ):
        DiffusionPipeline.__init__(self)

        if t5_encoder is None:
            self.t5_encoder = T5EncoderModel.from_pretrained(T5_NAME, torch_dtype=unet.dtype)
        else:
            self.t5_encoder = t5_encoder

        # ----- build T5 4096 => 2048 dim projection -----
        if t5_projection is None:
            self.t5_projection = LinearWithDtype(4096, 2048)  # trainable
        else:
            self.t5_projection = t5_projection
        self.t5_projection.to(dtype=unet.dtype)

        # ----- build T5 4096 => 1280 dim projection -----
        if t5_pooled_projection is None:
            self.t5_pooled_projection = LinearWithDtype(4096, 1280)  # trainable
        else:
            self.t5_pooled_projection = t5_pooled_projection
        self.t5_pooled_projection.to(dtype=unet.dtype)
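
        # Shape note (added for clarity): T5-XXL hidden states are 4096-dim,
        # while the SDXL UNet expects 2048-dim cross-attention states and a
        # 1280-dim pooled text embedding, hence the two linear projections.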

        print("dtype of Linear is ", self.t5_projection.dtype)

        self.register_modules(
            vae=vae,
            unet=unet,
            scheduler=scheduler,
            tokenizer=tokenizer,
            t5_encoder=self.t5_encoder,
            t5_projection=self.t5_projection,
            t5_pooled_projection=self.t5_pooled_projection,
            image_encoder=image_encoder,
            feature_extractor=feature_extractor,
        )
        self.register_to_config(force_zeros_for_empty_prompt=force_zeros_for_empty_prompt)
        self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) if getattr(self, "vae", None) else 8
        self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor)

        self.default_sample_size = (
            self.unet.config.sample_size
            if hasattr(self, "unet") and self.unet is not None and hasattr(self.unet.config, "sample_size")
            else 128
        )

        self.watermark = None

        # Parts of the original SDXL class complain if these attributes are
        # not at least PRESENT.
        self.text_encoder = self.text_encoder_2 = None

    # ------------------------------------------------------------------
    # Encode a text prompt (T5-XXL + 4096→2048 projection)
    # Returns exactly four tensors in the order SDXL's __call__ expects.
    # ------------------------------------------------------------------
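    #
    # Usage sketch (illustrative; not part of the original code). The inherited
    # StableDiffusionXLPipeline.__call__ unpacks the return value roughly as
    #
    #   prompt_embeds, negative_prompt_embeds, \
    #       pooled_prompt_embeds, negative_pooled_prompt_embeds = \
    #       self.encode_prompt(prompt, ...)
    #
    # and the method can also be called directly to inspect shapes:
    #
    #   pe, npe, ppe, nppe = pipe.encode_prompt("a castle at dusk")
    #   # pe: [1, T, 2048], ppe: [1, 1280]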
    def encode_prompt(
        self,
        prompt,
        num_images_per_prompt: int = 1,
        do_classifier_free_guidance: bool = True,
        negative_prompt: str | None = None,
        **_,
    ):
        """
        Returns
        -------
        prompt_embeds                : Tensor [B, T, 2048]
        negative_prompt_embeds       : Tensor [B, T, 2048] | None
        pooled_prompt_embeds         : Tensor [B, 1280]
        negative_pooled_prompt_embeds: Tensor [B, 1280] | None

        where B = batch * num_images_per_prompt
        """

        # --- helper to tokenize on the pipeline's device ----------------
        def _tok(text: str):
            tok_out = self.tokenizer(
                text,
                return_tensors="pt",
                padding="max_length",
                max_length=self.tokenizer.model_max_length,
                truncation=True,
            ).to(self.device)
            return tok_out.input_ids, tok_out.attention_mask

        # ---------- positive stream -------------------------------------
        ids, mask = _tok(prompt)
        h_pos = self.t5_encoder(ids, attention_mask=mask).last_hidden_state  # [b, T, 4096]
        tok_pos = self.t5_projection(h_pos)                                  # [b, T, 2048]
        pool_pos = self.t5_pooled_projection(h_pos.mean(dim=1))              # [b, 1280]
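        # Note (added for clarity): with no CLIP text encoders in this pipeline,
        # the "pooled" embedding is approximated by mean-pooling the T5 hidden
        # states over the sequence dimension and projecting them to 1280 dims.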

        # expand for multiple images per prompt
        tok_pos = tok_pos.repeat_interleave(num_images_per_prompt, 0)
        pool_pos = pool_pos.repeat_interleave(num_images_per_prompt, 0)

        # ---------- negative / CFG stream --------------------------------
        if do_classifier_free_guidance:
            neg_text = "" if negative_prompt is None else negative_prompt
            ids_n, mask_n = _tok(neg_text)
            h_neg = self.t5_encoder(ids_n, attention_mask=mask_n).last_hidden_state
            tok_neg = self.t5_projection(h_neg)
            pool_neg = self.t5_pooled_projection(h_neg.mean(dim=1))

            tok_neg = tok_neg.repeat_interleave(num_images_per_prompt, 0)
            pool_neg = pool_neg.repeat_interleave(num_images_per_prompt, 0)
        else:
            tok_neg = pool_neg = None

        # ----------------- final ordered return --------------------------
        # 1) positive token embeddings
        # 2) negative token embeddings (or None)
        # 3) positive pooled embeddings
        # 4) negative pooled embeddings (or None)
        return tok_pos, tok_neg, pool_pos, pool_neg