
Commit f95a0f0

Assorted fixes and addressing comments
1 parent 85b3037 commit f95a0f0

6 files changed, +28 -32 lines changed


sharktank/sharktank/models/vae/layers.py

Lines changed: 19 additions & 25 deletions
@@ -197,13 +197,12 @@ def __init__(
         attn_groups: Optional[int] = None,
     ):
         super().__init__(theta)
-        attentions = []
 
         resnet_groups = resnet_groups if resnet_time_scale_shift == "default" else None
 
         # there is always at least one resnet
         if resnet_time_scale_shift == "spatial":
-            # TODO
+            # TODO Implement ResnetBlockCondNorm2d block for spatial time scale shift
             raise AssertionError(f"ResnetBlockCondNorm2d not yet implemented")
         else:
             resnets = [
@@ -218,24 +217,22 @@ def __init__(
                     dropout=dropout,
                 )
             ]
-        for _ in range(num_layers):
-            if add_attention:
-                attentions.append(
-                    AttentionLayer(
-                        theta("attentions")(0),
-                        heads=1,
-                        dim_head=attention_head_dim,
-                        rescale_output_factor=1.0,
-                        eps=resnet_eps,
-                        norm_num_groups=attn_groups,
-                        residual_connection=True,
-                    )
-                )
-            else:
-                attentions.append(None)
+        # TODO: loop through num_layers properly. Works for sdxl vae specifically but removed for export reasons
+        if add_attention:
+            self.attention = AttentionLayer(
+                theta("attentions")(0),
+                heads=1,
+                dim_head=attention_head_dim,
+                rescale_output_factor=1.0,
+                eps=resnet_eps,
+                norm_num_groups=attn_groups,
+                residual_connection=True,
+            )
+        else:
+            self.attention = None
 
         if resnet_time_scale_shift == "spatial":
-            # TODO
+            # TODO Implement ResnetBlock2D for spatial time scale shift support
            raise AssertionError(
                 f"ResnetBlock2D spatial time scale shift not yet implemented"
             )
@@ -252,16 +249,13 @@ def __init__(
                     dropout=dropout,
                 )
             )
-        self.attentions = nn.ModuleList(attentions)
-        self.resnets = resnets
+        self.resnets = nn.ModuleList(resnets)
 
     def forward(
         self, hidden_states: torch.Tensor, temb: Optional[torch.Tensor] = None
     ) -> torch.Tensor:
         hidden_states = self.resnets[0](hidden_states, temb)
-        for attn, resnet in zip(self.attentions, self.resnets[1:]):
-            if attn is not None:
-                hidden_states = attn(hidden_states)
-            hidden_states = resnet(hidden_states, temb)
-
+        if self.attention is not None:
+            hidden_states = self.attention(hidden_states)
+        hidden_states = self.resnets[1](hidden_states, temb)
         return hidden_states

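One likely motivation for the `self.resnets = nn.ModuleList(resnets)` change is standard torch submodule registration. A minimal plain-torch sketch (illustrative only, not sharktank-specific) of the difference between a Python list and nn.ModuleList:

# Illustrative sketch: submodules held in a plain Python list are not
# registered with the parent Module, so their parameters are invisible to
# .parameters(), .to(), state_dict(), etc. nn.ModuleList registers them.
import torch.nn as nn


class PlainList(nn.Module):
    def __init__(self):
        super().__init__()
        self.blocks = [nn.Linear(4, 4)]  # not registered as a submodule


class WithModuleList(nn.Module):
    def __init__(self):
        super().__init__()
        self.blocks = nn.ModuleList([nn.Linear(4, 4)])  # registered


print(len(list(PlainList().parameters())))       # 0
print(len(list(WithModuleList().parameters())))  # 2 (weight and bias)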
sharktank/sharktank/models/vae/tools/diffuser_ref.py

Lines changed: 1 addition & 0 deletions
@@ -44,6 +44,7 @@ def __init__(
         )
 
     def decode(self, inp):
+        # The reference vae decode does not do scaling and leaves it for the sdxl pipeline. We integrate it into vae for pipeline performance so using the hardcoded values from the config.json here
         img = 1 / 0.13025 * inp
         x = self.vae.decode(img, return_dict=False)[0]
         return (x / 2 + 0.5).clamp(0, 1)

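For reference, 0.13025 is the SDXL VAE scaling factor from the model's config.json. A hedged sketch of the same decode written against the diffusers config attribute rather than the hardcoded literal, assuming self.vae is a diffusers AutoencoderKL:

# Illustrative sketch only: equivalent scaling pulled from the diffusers config
# (config.scaling_factor == 0.13025 for the SDXL VAE) instead of a literal.
def decode(self, inp):
    img = inp / self.vae.config.scaling_factor  # same as 1 / 0.13025 * inp
    x = self.vae.decode(img, return_dict=False)[0]
    return (x / 2 + 0.5).clamp(0, 1)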
sharktank/sharktank/models/vae/tools/run_vae.py

Lines changed: 2 additions & 0 deletions
@@ -103,8 +103,10 @@ def main(argv):
     inputs = get_random_inputs(dtype=dtype, device=device, bs=args.bs)
 
     if args.export:
+        # TODO move export from a run_vae file
         output = export_vae(mdl, inputs, args.decomp_attn)
         output.save_mlir(args.export)
+        print("exported VAE model. Skipping eager execution")
     else:
         # Save intermediates.
         intermediates_saver = None

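A hedged sketch of the export path exercised above, reusing the calls shown in this file; `mdl` stands in for the loaded VAE model and the output path is only illustrative:

# Illustrative sketch, assuming `mdl` is the instantiated VAE model used by
# run_vae.py; export_vae and get_random_inputs are the helpers shown above.
import torch

inputs = get_random_inputs(dtype=torch.float32, device="cpu", bs=1)
output = export_vae(mdl, inputs, False)  # third argument: decomp_attn
output.save_mlir("/tmp/vae.mlir")
print("exported VAE model. Skipping eager execution")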
sharktank/sharktank/models/vae/tools/sample_data.py

Lines changed: 0 additions & 1 deletion
@@ -12,7 +12,6 @@
 
 
 def get_random_inputs(dtype, device, bs: int = 2):
-    torch.random.manual_seed(42)
     height = 1024
     width = 1024
     return torch.rand(bs, 4, width // 8, height // 8, dtype=dtype).to(device)

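With the manual_seed call removed, get_random_inputs no longer fixes the RNG state itself. A short sketch of how a caller can still get reproducible latents by seeding before the call:

# Sketch: seed externally now that get_random_inputs no longer seeds internally.
import torch

torch.random.manual_seed(42)
inputs = get_random_inputs(dtype=torch.float32, device="cpu", bs=2)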
sharktank/sharktank/ops/signatures.py

Lines changed: 1 addition & 1 deletion
@@ -55,13 +55,13 @@
     "sharded_cat",
     "sharded_sum",
     "softmax",
+    "squeeze",
     "to",
     "transfer_to_logical_device",
     "transpose",
     "unflatten",
     "unshard",
     "unsqueeze",
-    "squeeze",
     "view",
     "view_as_complex",
     "view_as_real",

sharktank/sharktank/types/tensors.py

Lines changed: 5 additions & 5 deletions
@@ -382,6 +382,11 @@ def size(self, dim: Optional[int] = None) -> tuple[int]:
         if dim is None:
             return tuple(self.shape)
         return self.shape[dim]
+
+    def squeeze(self, dim: Optional[int] = None) -> "AnyTensor":
+        from ..ops import squeeze
+
+        return squeeze(self, dim)
 
     def transpose(self, dim0: int, dim1: int) -> "AnyTensor":
         from ..ops import transpose
@@ -398,11 +403,6 @@ def unsqueeze(self, dim: int) -> "AnyTensor":
 
         return unsqueeze(self, dim)
 
-    def squeeze(self, dim: Optional[int] = None) -> "AnyTensor":
-        from ..ops import squeeze
-
-        return squeeze(self, dim)
-
     def view(self, *args: Union[List[List[int]], List[int]]) -> "AnyTensor":
         from ..ops import view
 

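Together with the squeeze signature registered in signatures.py, this gives AnyTensor a torch-style squeeze method. A hedged usage sketch, using DefaultPrimitiveTensor only as a convenient AnyTensor instance (constructor details are an assumption):

# Illustrative sketch: AnyTensor.squeeze delegates to the registered sharktank
# squeeze op, mirroring torch.Tensor.squeeze.
import torch
from sharktank.types import DefaultPrimitiveTensor

t = DefaultPrimitiveTensor(data=torch.zeros(1, 4, 1, 8))
print(t.squeeze(0).shape)  # leading singleton dim removed -> (4, 1, 8)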