Skip to content

Commit 6c7fad7

Browse files
authored
Add community class StableDiffusionXL_T5Pipeline (#11626)
* Add community class StableDiffusionXL_T5Pipeline Will be used with base model opendiffusionai/stablediffusionxl_t5 * Changed pooled_embeds to use projection instead of slice * "make style" tweaks * Added comments to top of code * Apply style fixes
1 parent 5b0dab1 commit 6c7fad7

File tree

1 file changed

+205
-0
lines changed

1 file changed

+205
-0
lines changed
Lines changed: 205 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,205 @@
1+
# Copyright Philip Brown, ppbrown@github
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
###########################################################################
16+
# This pipeline attempts to use a model that has SDXL vae, T5 text encoder,
17+
# and SDXL unet.
18+
# At the present time, there are no pretrained models that give pleasing
19+
# output. So, as of 2025/06/10, this pipeline is somewhat of a tech
20+
# demo proving that the pieces can at least be put together.
21+
# Hopefully, it will encourage someone with the hardware available to
22+
# throw enough resources into training one up.
23+
24+
25+
from typing import Optional
26+
27+
import torch.nn as nn
28+
from transformers import (
29+
CLIPImageProcessor,
30+
CLIPTokenizer,
31+
CLIPVisionModelWithProjection,
32+
T5EncoderModel,
33+
)
34+
35+
from diffusers import DiffusionPipeline, StableDiffusionXLPipeline
36+
from diffusers.image_processor import VaeImageProcessor
37+
from diffusers.models import AutoencoderKL, UNet2DConditionModel
38+
from diffusers.schedulers import KarrasDiffusionSchedulers
39+
40+
41+
# Note: At this time, the intent is to use the T5 encoder mentioned
# below, with zero changes.
# Therefore, the model deliberately does not store the T5 encoder model bytes
# (since they are not unique!) but instead takes advantage of Hugging Face
# hub cache loading.

# Hub id of the T5 encoder; used as the default when no ``t5_encoder`` is
# passed to the pipeline constructor.
T5_NAME = "mcmonkey/google_t5-v1_1-xxl_encoderonly"

# Caller is expected to load this, or equivalent, as model name for now
# eg: pipe = StableDiffusionXL_T5Pipeline(SDXL_NAME)
# NOTE(review): the constructor takes component objects (vae, unet, ...),
# not a model name, so the example above presumably means
# ``StableDiffusionXL_T5Pipeline.from_pretrained(SDXL_NAME)`` -- confirm.
SDXL_NAME = "stabilityai/stable-diffusion-xl-base-1.0"
52+
53+
54+
class LinearWithDtype(nn.Linear):
    """``nn.Linear`` that additionally exposes a read-only ``dtype`` attribute.

    The reported dtype is always that of the layer's ``weight`` parameter, so
    it follows any later ``.to(dtype=...)`` conversion of the module.
    """

    @property
    def dtype(self):
        """Return the dtype of this layer's weight tensor."""
        weight = self.weight
        return weight.dtype
58+
59+
60+
class StableDiffusionXL_T5Pipeline(StableDiffusionXLPipeline):
    """SDXL pipeline whose text conditioning comes from a T5 encoder.

    Prompts are tokenized with the registered ``tokenizer``, encoded by
    ``t5_encoder`` and mapped by two linear layers:

    * ``t5_projection``        (4096 -> 2048): per-token prompt embeddings
    * ``t5_pooled_projection`` (4096 -> 1280): pooled prompt embedding

    Only ``encode_prompt`` is overridden; everything else is inherited from
    ``StableDiffusionXLPipeline``.
    """

    _expected_modules = [
        "vae",
        "unet",
        "scheduler",
        "tokenizer",
        "image_encoder",
        "feature_extractor",
        "t5_encoder",
        "t5_projection",
        "t5_pooled_projection",
    ]

    _optional_components = [
        "image_encoder",
        "feature_extractor",
        "t5_encoder",
        "t5_projection",
        "t5_pooled_projection",
    ]

    def __init__(
        self,
        vae: AutoencoderKL,
        unet: UNet2DConditionModel,
        scheduler: KarrasDiffusionSchedulers,
        tokenizer: CLIPTokenizer,
        t5_encoder=None,
        t5_projection=None,
        t5_pooled_projection=None,
        image_encoder: CLIPVisionModelWithProjection = None,
        feature_extractor: CLIPImageProcessor = None,
        force_zeros_for_empty_prompt: bool = True,
        add_watermarker: Optional[bool] = None,
    ):
        """Assemble the pipeline from its components.

        Parameters
        ----------
        vae, unet, scheduler, tokenizer
            Core components, registered on the pipeline as-is.
        t5_encoder
            Text encoder. When ``None`` it is loaded from ``T5_NAME`` using
            the UNet's dtype.
        t5_projection
            4096 -> 2048 projection for token embeddings. When ``None`` a
            freshly initialised (trainable) ``LinearWithDtype`` is created.
            Either way it is cast to the UNet's dtype.
        t5_pooled_projection
            4096 -> 1280 projection for the pooled embedding; same handling
            as ``t5_projection``.
        image_encoder, feature_extractor
            Optional components, registered as-is.
        force_zeros_for_empty_prompt
            Stored in the pipeline config. ``encode_prompt`` in this class
            does not consult it.
        add_watermarker
            Accepted but unused here; ``self.watermark`` is always ``None``.
        """
        # The parent SDXL initializer is intentionally not called; only the
        # generic pipeline base is initialised (presumably because the parent
        # expects CLIP text encoders this pipeline does not have).
        DiffusionPipeline.__init__(self)

        if t5_encoder is None:
            t5_encoder = T5EncoderModel.from_pretrained(T5_NAME, torch_dtype=unet.dtype)

        # ----- build T5 4096 => 2048 dim projection -----
        if t5_projection is None:
            t5_projection = LinearWithDtype(4096, 2048)  # trainable
        t5_projection.to(dtype=unet.dtype)

        # ----- build T5 4096 => 1280 dim projection -----
        if t5_pooled_projection is None:
            t5_pooled_projection = LinearWithDtype(4096, 1280)  # trainable
        t5_pooled_projection.to(dtype=unet.dtype)

        # NOTE: a leftover debug ``print`` of the projection dtype was removed
        # here; it also raised AttributeError for a caller-supplied plain
        # ``nn.Linear``, which has no ``dtype`` attribute.

        self.register_modules(
            vae=vae,
            unet=unet,
            scheduler=scheduler,
            tokenizer=tokenizer,
            t5_encoder=t5_encoder,
            t5_projection=t5_projection,
            t5_pooled_projection=t5_pooled_projection,
            image_encoder=image_encoder,
            feature_extractor=feature_extractor,
        )
        self.register_to_config(force_zeros_for_empty_prompt=force_zeros_for_empty_prompt)
        self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) if getattr(self, "vae", None) else 8
        self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor)

        self.default_sample_size = (
            self.unet.config.sample_size
            if hasattr(self, "unet") and self.unet is not None and hasattr(self.unet.config, "sample_size")
            else 128
        )

        self.watermark = None

        # Parts of original SDXL class complain if these attributes are not
        # at least PRESENT
        self.text_encoder = self.text_encoder_2 = None

    def _encode_text(self, texts, repeats: int):
        """Tokenize, T5-encode and project ``texts``.

        Each resulting row is repeated ``repeats`` times along the batch
        dimension. Returns ``(token_embeds, pooled_embeds)``.
        """
        tokens = self.tokenizer(
            texts,
            return_tensors="pt",
            padding="max_length",
            max_length=self.tokenizer.model_max_length,
            truncation=True,
        ).to(self.device)

        hidden = self.t5_encoder(tokens.input_ids, attention_mask=tokens.attention_mask).last_hidden_state  # [b, T, 4096]
        token_embeds = self.t5_projection(hidden)  # [b, T, 2048]
        # Pooled vector = mean over every token position (padding positions
        # included), then projected.
        pooled_embeds = self.t5_pooled_projection(hidden.mean(dim=1))  # [b, 1280]

        return (
            token_embeds.repeat_interleave(repeats, 0),
            pooled_embeds.repeat_interleave(repeats, 0),
        )

    # ------------------------------------------------------------------
    # Encode a text prompt (T5-XXL + 4096->2048 projection)
    # Returns exactly four tensors in the order SDXL's __call__ expects.
    # ------------------------------------------------------------------
    def encode_prompt(
        self,
        prompt,
        num_images_per_prompt: int = 1,
        do_classifier_free_guidance: bool = True,
        negative_prompt: Optional[str] = None,
        **_,
    ):
        """Encode ``prompt`` (and optionally a negative prompt) with T5.

        Parameters
        ----------
        prompt
            A single string or a list of strings.
        num_images_per_prompt
            Number of times each prompt's embeddings are repeated.
        do_classifier_free_guidance
            When true, negative embeddings are produced as well.
        negative_prompt
            ``None`` (treated as the empty string), a single string shared by
            every prompt, or a list with one entry per prompt.
        **_
            Any further keyword arguments are accepted and ignored.

        Returns
        -------
        prompt_embeds                : Tensor [B, T, 2048]
        negative_prompt_embeds       : Tensor [B, T, 2048] | None
        pooled_prompt_embeds         : Tensor [B, 1280]
        negative_pooled_prompt_embeds: Tensor [B, 1280]    | None
        where B = batch * num_images_per_prompt

        Raises
        ------
        ValueError
            If ``prompt`` is ``None``, or ``negative_prompt`` is a list whose
            length differs from the number of prompts.
        """
        if prompt is None:
            raise ValueError("`prompt` is required: this pipeline always encodes text with the T5 encoder.")

        prompts = [prompt] if isinstance(prompt, str) else list(prompt)
        batch_size = len(prompts)

        # Validate the negative prompt before running the (expensive) encoder.
        shared_negative = negative_prompt is None or isinstance(negative_prompt, str)
        if do_classifier_free_guidance and not shared_negative:
            negatives = list(negative_prompt)
            if len(negatives) != batch_size:
                raise ValueError(
                    f"`negative_prompt` has batch size {len(negatives)}, but `prompt` has batch size "
                    f"{batch_size}. Please make sure the two batch sizes match."
                )

        # ---------- positive stream -------------------------------------
        tok_pos, pool_pos = self._encode_text(prompts, num_images_per_prompt)

        # ---------- negative / CFG stream --------------------------------
        if not do_classifier_free_guidance:
            return tok_pos, None, pool_pos, None

        if shared_negative:
            # One negative text for the whole batch: encode it once and repeat
            # it so its batch dimension matches the positive embeddings.
            neg_text = "" if negative_prompt is None else negative_prompt
            tok_neg, pool_neg = self._encode_text([neg_text], batch_size * num_images_per_prompt)
        else:
            tok_neg, pool_neg = self._encode_text(negatives, num_images_per_prompt)

        # ----------------- final ordered return --------------------------
        # 1) positive token embeddings
        # 2) negative token embeddings (or None)
        # 3) positive pooled embeddings
        # 4) negative pooled embeddings (or None)
        return tok_pos, tok_neg, pool_pos, pool_neg

0 commit comments

Comments
 (0)