Commit 85b3037

Adds sdxl's VAE decoder implementation
1 parent c9cb226 commit 85b3037

13 files changed: +935 -1 lines changed

sharktank/requirements-tests.txt

Lines changed: 1 addition & 0 deletions
@@ -2,3 +2,4 @@ datasets==3.0.0
parameterized
pytest==8.0.0
pytest-html
+diffusers

sharktank/sharktank/models/punet/layers.py

Lines changed: 0 additions & 1 deletion
@@ -571,7 +571,6 @@ def forward(self, input_tensor: torch.Tensor, temb: torch.Tensor) -> torch.Tensor:
        hidden_states = ops.elementwise(self.nonlinearity, hidden_states)
        hidden_states = self.conv1(hidden_states)

-       assert self.time_emb_proj is not None
        if self.time_emb_proj is not None:
            temb = ops.elementwise(self.nonlinearity, temb)
            temb = self.time_emb_proj(temb)[:, :, None, None]
sharktank/sharktank/models/vae/README.md

Lines changed: 28 additions & 0 deletions
@@ -0,0 +1,28 @@
# VAE decoder

This is the VAE decoder implemented in the style used for SDXL, referenced from the diffusers implementation.

## Preparing the dataset

If not sharding or quantizing, the official model can be imported directly from Hugging Face:

```
model_dir=$(huggingface-cli download \
    stabilityai/stable-diffusion-xl-base-1.0 \
    vae/config.json vae/diffusion_pytorch_model.safetensors)

python -m sharktank.models.punet.tools.import_hf_dataset \
    --params $model_dir/vae/diffusion_pytorch_model.safetensors \
    --config-json $model_dir/vae/config.json --output-irpa-file ~/models/vae.irpa
```

## Running the VAE decoder in eager mode

```
python -m sharktank.models.vae.tools.run_vae --irpa-file ~/models/vae.irpa --device cpu
```

## License

Significant portions of this implementation were derived from diffusers,
licensed under Apache2: https://github.com/huggingface/diffusers
While much was a simple reverse engineering of the config.json and parameters,
code was taken where appropriate.
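
As a sanity check on the eager run above, the same decode can be reproduced with the upstream diffusers `AutoencoderKL` and compared against `run_vae`'s output. This is a minimal sketch, not part of the commit; the latent shape and the division by `scaling_factor` follow the usual SDXL convention and are assumptions here.

```
import torch
from diffusers import AutoencoderKL

# Reference decode with the upstream diffusers VAE (baseline only; assumed usage).
vae = AutoencoderKL.from_pretrained(
    "stabilityai/stable-diffusion-xl-base-1.0", subfolder="vae"
)
vae.eval()

# Example SDXL latent: 4 channels at 1/8 of the image resolution (assumed shape).
latents = torch.randn(1, 4, 128, 128)

with torch.no_grad():
    # SDXL convention: un-scale latents by scaling_factor (0.13025) before decoding.
    image = vae.decode(latents / vae.config.scaling_factor).sample

print(image.shape)  # torch.Size([1, 3, 1024, 1024])
```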
sharktank/sharktank/models/vae/config.py

Lines changed: 61 additions & 0 deletions
@@ -0,0 +1,61 @@
# Copyright 2024 Advanced Micro Devices, Inc.
#
# Licensed under the Apache License v2.0 with LLVM Exceptions.
# See https://llvm.org/LICENSE.txt for license information.
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
#
# Significant portions of this implementation were derived from diffusers,
# licensed under Apache2: https://github.com/huggingface/diffusers
# While much was a simple reverse engineering of the config.json and parameters,
# code was taken where appropriate.

from typing import List, Optional, Sequence, Tuple, Union

from dataclasses import dataclass
import inspect
import warnings

__all__ = [
    "HParams",
]


@dataclass
class HParams:
    # Per block sequences. These are normalized from either an int (duplicated
    # to the number of up_blocks) or a list.
    layers_per_block: int = 2

    act_fn: str = "silu"
    block_out_channels: Sequence[int] = (128, 256, 512, 512)
    in_channels: int = 3
    up_block_types: Sequence[str] = (
        "UpDecoderBlock2D",
        "UpDecoderBlock2D",
        "UpDecoderBlock2D",
        "UpDecoderBlock2D",
    )
    norm_num_groups: int = 32
    scaling_factor: float = 0.13025

    def assert_default_values(self, attr_names: Sequence[str]):
        for name in attr_names:
            actual = getattr(self, name)
            required = getattr(HParams, name)
            if actual != required:
                raise ValueError(
                    f"NYI: HParams.{name} != {required!r} (got {actual!r})"
                )

    @classmethod
    def from_dict(cls, d: dict):
        if "layers_per_block" not in d:
            d["layers_per_block"] = 2

        allowed = inspect.signature(cls).parameters
        declared_kwargs = {k: v for k, v in d.items() if k in allowed}
        extra_kwargs = [k for k in d.keys() if k not in allowed]
        if extra_kwargs:
            warnings.warn(f"Unhandled vae.HParams: {extra_kwargs}")
        return cls(**declared_kwargs)
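
To make the `from_dict` filtering above concrete, here is a small usage sketch that builds `HParams` from the `vae/config.json` downloaded in the README step. The module path `sharktank.models.vae.config` and the local file location are assumptions for illustration.

```
import json

# Assumed module path for the config shown above.
from sharktank.models.vae.config import HParams

# Config downloaded via huggingface-cli in the README (assumed local path).
with open("vae/config.json") as f:
    config = json.load(f)

hp = HParams.from_dict(config)
# Keys not declared on HParams (such as "sample_size") warn and are dropped.
print(hp.block_out_channels, hp.layers_per_block, hp.scaling_factor)
```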
sharktank/sharktank/models/vae/layers.py

Lines changed: 267 additions & 0 deletions
@@ -0,0 +1,267 @@
# Copyright 2024 Advanced Micro Devices, Inc.
#
# Licensed under the Apache License v2.0 with LLVM Exceptions.
# See https://llvm.org/LICENSE.txt for license information.
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
from typing import Optional, Sequence, Tuple

import math

import torch
import torch.nn as nn

from sharktank import ops
from sharktank.layers import *
from sharktank.types import *
from sharktank.models.punet.layers import (
    ResnetBlock2D,
    Upsample2D,
    GroupNormLayer,
    AttentionLayer,
)
from .config import *


__all__ = ["UNetMidBlock2D", "UpDecoderBlock2D", "AttentionLayer"]


# TODO: Remove and integrate with punet AttentionLayer
class AttentionLayer(ThetaLayer):
    def __init__(
        self,
        theta: Theta,
        heads: int,  # in_channels // attention_head_dim
        dim_head,
        rescale_output_factor: float,
        eps: float,
        norm_num_groups: int,
        residual_connection: bool,
    ):
        super().__init__(theta)
        self.heads = heads
        self.rescale_output_factor = rescale_output_factor
        self.residual_connection = residual_connection

        if norm_num_groups is not None:
            self.group_norm = GroupNormLayer(
                theta("group_norm"), num_groups=norm_num_groups, eps=eps
            )
        else:
            self.group_norm = None

        self.norm_q = None
        self.norm_k = None

        self.norm_cross = None
        self.to_q = LinearLayer(theta("to_q"))
        self.to_k = LinearLayer(theta("to_k"))
        self.to_v = LinearLayer(theta("to_v"))

        self.added_proj_bias = True
        self.to_out = LinearLayer(theta("to_out")(0))

    def forward(
        self,
        hidden_states: torch.Tensor,
        encoder_hidden_states: Optional[torch.Tensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
    ):
        residual = hidden_states

        input_ndim = hidden_states.ndim
        if input_ndim == 4:
            batch_size, channel, height, width = hidden_states.shape
            hidden_states = hidden_states.view(
                batch_size, channel, height * width
            ).transpose(1, 2)

        batch_size, sequence_length, _ = (
            hidden_states.shape
            if encoder_hidden_states is None
            else encoder_hidden_states.shape
        )

        if self.group_norm is not None:
            hidden_states = self.group_norm(hidden_states.transpose(1, 2)).transpose(
                1, 2
            )

        query = self.to_q(hidden_states)

        if encoder_hidden_states is None:
            encoder_hidden_states = hidden_states

        key = self.to_k(encoder_hidden_states)
        value = self.to_v(encoder_hidden_states)
        inner_dim = key.shape[-1]
        head_dim = inner_dim // self.heads

        query = query.view(batch_size, -1, self.heads, head_dim).transpose(1, 2)
        key = key.view(batch_size, -1, self.heads, head_dim).transpose(1, 2)
        value = value.view(batch_size, -1, self.heads, head_dim).transpose(1, 2)

        if self.norm_q is not None:
            query = self.norm_q(query)
        if self.norm_k is not None:
            key = self.norm_k(key)

        hidden_states = ops.scaled_dot_product_attention(
            query, key, value, a=attention_mask
        )
        hidden_states = hidden_states.transpose(1, 2).reshape(
            batch_size, -1, self.heads * head_dim
        )

        # linear proj
        hidden_states = self.to_out(hidden_states)

        if input_ndim == 4:
            hidden_states = hidden_states.transpose(-1, -2).reshape(
                batch_size, channel, height, width
            )

        if self.residual_connection:
            hidden_states = hidden_states + residual

        hidden_states = hidden_states / self.rescale_output_factor
        return hidden_states


class UpDecoderBlock2D(ThetaLayer):
    def __init__(
        self,
        theta: Theta,
        *,
        num_layers: int,
        resnet_eps: float,
        resnet_act_fn: str,
        resnet_groups: int,
        resnet_out_scale_factor: Optional[float],
        resnet_time_scale_shift: str,
        temb_channels: int,
        dropout: float,
        add_upsample: bool,
    ):
        super().__init__(theta)
        resnets = []

        for i in range(num_layers):
            resnets.append(
                ResnetBlock2D(
                    theta("resnets")(i),
                    groups=resnet_groups,
                    eps=resnet_eps,
                    non_linearity=resnet_act_fn,
                    output_scale_factor=resnet_out_scale_factor,
                    time_embedding_norm=resnet_time_scale_shift,
                    temb_channels=temb_channels,
                    dropout=dropout,
                )
            )
        self.resnets = nn.ModuleList(resnets)
        if add_upsample:
            self.upsamplers = nn.ModuleList(
                [Upsample2D(theta("upsamplers")("0"), padding=1)]
            )
        else:
            self.upsamplers = None

    def forward(
        self,
        hidden_states: torch.Tensor,
        temb: Optional[torch.Tensor] = None,
        upsample_size: Optional[int] = None,
    ) -> torch.Tensor:
        for resnet in self.resnets:
            hidden_states = resnet(hidden_states, temb=temb)
        if self.upsamplers is not None:
            for upsampler in self.upsamplers:
                hidden_states = upsampler(hidden_states, upsample_size)
        return hidden_states


class UNetMidBlock2D(ThetaLayer):
    def __init__(
        self,
        theta: Theta,
        temb_channels: int,
        dropout: float,
        num_layers: int,
        resnet_eps: float,
        resnet_time_scale_shift: str,
        resnet_act_fn: str,
        resnet_groups: int,
        resnet_pre_norm: bool,
        add_attention: bool,
        attention_head_dim: int,
        output_scale_factor: float,
        attn_groups: Optional[int] = None,
    ):
        super().__init__(theta)
        attentions = []

        resnet_groups = resnet_groups if resnet_time_scale_shift == "default" else None

        # there is always at least one resnet
        if resnet_time_scale_shift == "spatial":
            # TODO
            raise AssertionError("ResnetBlockCondNorm2d not yet implemented")
        else:
            resnets = [
                ResnetBlock2D(
                    theta("resnets")(0),
                    groups=resnet_groups,
                    eps=resnet_eps,
                    non_linearity=resnet_act_fn,
                    output_scale_factor=output_scale_factor,
                    time_embedding_norm=resnet_time_scale_shift,
                    temb_channels=temb_channels,
                    dropout=dropout,
                )
            ]
        for _ in range(num_layers):
            if add_attention:
                attentions.append(
                    AttentionLayer(
                        theta("attentions")(0),
                        heads=1,
                        dim_head=attention_head_dim,
                        rescale_output_factor=1.0,
                        eps=resnet_eps,
                        norm_num_groups=attn_groups,
                        residual_connection=True,
                    )
                )
            else:
                attentions.append(None)

            if resnet_time_scale_shift == "spatial":
                # TODO
                raise AssertionError(
                    "ResnetBlock2D spatial time scale shift not yet implemented"
                )
            else:
                resnets.append(
                    ResnetBlock2D(
                        theta("resnets")(1),
                        groups=resnet_groups,
                        eps=resnet_eps,
                        non_linearity=resnet_act_fn,
                        output_scale_factor=output_scale_factor,
                        time_embedding_norm=resnet_time_scale_shift,
                        temb_channels=temb_channels,
                        dropout=dropout,
                    )
                )
        self.attentions = nn.ModuleList(attentions)
        self.resnets = resnets

    def forward(
        self, hidden_states: torch.Tensor, temb: Optional[torch.Tensor] = None
    ) -> torch.Tensor:
        hidden_states = self.resnets[0](hidden_states, temb)
        for attn, resnet in zip(self.attentions, self.resnets[1:]):
            if attn is not None:
                hidden_states = attn(hidden_states)
            hidden_states = resnet(hidden_states, temb)

        return hidden_states
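
To make the tensor bookkeeping in `AttentionLayer.forward` easier to follow, the standalone sketch below reproduces the same flatten/attend/restore round-trip with plain `torch`, substituting `torch.nn.functional.scaled_dot_product_attention` for `ops.scaled_dot_product_attention` and random matrices for the `to_q`/`to_k`/`to_v`/`to_out` projections; the shapes are illustrative assumptions, not taken from this commit.

```
import torch
import torch.nn.functional as F

# Example shapes (assumed): 512-channel feature map, single attention head.
batch, channels, height, width = 1, 512, 32, 32
heads = 1
head_dim = channels // heads

hidden = torch.randn(batch, channels, height, width)
residual = hidden

# (B, C, H, W) -> (B, H*W, C): spatial positions become the token sequence.
tokens = hidden.view(batch, channels, height * width).transpose(1, 2)

# Random stand-ins for the to_q / to_k / to_v / to_out projections.
wq, wk, wv, wo = (torch.randn(channels, channels) * 0.02 for _ in range(4))
q = (tokens @ wq).view(batch, -1, heads, head_dim).transpose(1, 2)
k = (tokens @ wk).view(batch, -1, heads, head_dim).transpose(1, 2)
v = (tokens @ wv).view(batch, -1, heads, head_dim).transpose(1, 2)

attn = F.scaled_dot_product_attention(q, k, v)
attn = attn.transpose(1, 2).reshape(batch, -1, heads * head_dim) @ wo

# Restore (B, C, H, W) and apply the residual connection, as in the layer above.
out = attn.transpose(-1, -2).reshape(batch, channels, height, width) + residual
print(out.shape)  # torch.Size([1, 512, 32, 32])
```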
