Commit f95a0f0

Assorted fixes and addressing comments
IanNod committed Dec 9, 2024
1 parent 85b3037 commit f95a0f0
Showing 6 changed files with 28 additions and 32 deletions.
44 changes: 19 additions & 25 deletions sharktank/sharktank/models/vae/layers.py
@@ -197,13 +197,12 @@ def __init__(
         attn_groups: Optional[int] = None,
     ):
         super().__init__(theta)
-        attentions = []
 
         resnet_groups = resnet_groups if resnet_time_scale_shift == "default" else None
 
         # there is always at least one resnet
         if resnet_time_scale_shift == "spatial":
-            # TODO
+            # TODO Implement ResnetBlockCondNorm2d block for spatial time scale shift
             raise AssertionError(f"ResnetBlockCondNorm2d not yet implemented")
         else:
             resnets = [
@@ -218,24 +217,22 @@
                     dropout=dropout,
                 )
             ]
-        for _ in range(num_layers):
-            if add_attention:
-                attentions.append(
-                    AttentionLayer(
-                        theta("attentions")(0),
-                        heads=1,
-                        dim_head=attention_head_dim,
-                        rescale_output_factor=1.0,
-                        eps=resnet_eps,
-                        norm_num_groups=attn_groups,
-                        residual_connection=True,
-                    )
-                )
-            else:
-                attentions.append(None)
+        # TODO: loop through num_layers properly. Works for sdxl vae specifically but removed for export reasons
+        if add_attention:
+            self.attention = AttentionLayer(
+                theta("attentions")(0),
+                heads=1,
+                dim_head=attention_head_dim,
+                rescale_output_factor=1.0,
+                eps=resnet_eps,
+                norm_num_groups=attn_groups,
+                residual_connection=True,
+            )
+        else:
+            self.attention = None
 
         if resnet_time_scale_shift == "spatial":
-            # TODO
+            # TODO Implement ResnetBlock2D for spatial time scale shift support
             raise AssertionError(
                 f"ResnetBlock2D spatial time scale shift not yet implemented"
             )
@@ -252,16 +249,13 @@ def __init__(
                     dropout=dropout,
                 )
             )
-        self.attentions = nn.ModuleList(attentions)
-        self.resnets = resnets
+        self.resnets = nn.ModuleList(resnets)
 
     def forward(
         self, hidden_states: torch.Tensor, temb: Optional[torch.Tensor] = None
     ) -> torch.Tensor:
         hidden_states = self.resnets[0](hidden_states, temb)
-        for attn, resnet in zip(self.attentions, self.resnets[1:]):
-            if attn is not None:
-                hidden_states = attn(hidden_states)
-            hidden_states = resnet(hidden_states, temb)
+        if self.attention is not None:
+            hidden_states = self.attention(hidden_states)
+        hidden_states = self.resnets[1](hidden_states, temb)
         return hidden_states
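
Note: with the per-layer attentions list replaced by a single optional self.attention, the mid-block is hardcoded to the SDXL VAE shape of two resnets around at most one attention (see the TODO above about looping through num_layers properly). A self-contained toy check that the old zip-based loop and the new form agree in that case; ToyResnet and ToyAttention are hypothetical stand-ins, not repo code:

import torch
import torch.nn as nn

class ToyResnet(nn.Module):
    def forward(self, x, temb=None):
        return x + 1.0  # stand-in for a real resnet block

class ToyAttention(nn.Module):
    def forward(self, x):
        return x * 2.0  # stand-in for a real attention layer

resnets = [ToyResnet(), ToyResnet()]
attention = ToyAttention()
x = torch.zeros(1, 4, 8, 8)

# Old form: walk parallel attentions/resnets lists.
old = resnets[0](x)
for attn, resnet in zip([attention], resnets[1:]):
    if attn is not None:
        old = attn(old)
    old = resnet(old)

# New form: one optional attention between exactly two resnets.
new = resnets[0](x)
if attention is not None:
    new = attention(new)
new = resnets[1](new)

assert torch.equal(old, new)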
1 change: 1 addition & 0 deletions sharktank/sharktank/models/vae/tools/diffuser_ref.py
@@ -44,6 +44,7 @@ def __init__(
         )
 
     def decode(self, inp):
+        # The reference vae decode does not do scaling and leaves it for the sdxl pipeline. We integrate it into vae for pipeline performance so using the hardcoded values from the config.json here
         img = 1 / 0.13025 * inp
         x = self.vae.decode(img, return_dict=False)[0]
         return (x / 2 + 0.5).clamp(0, 1)
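
Note: 0.13025 is the SDXL VAE scaling_factor from the model's config.json. A minimal sketch restating the arithmetic that decode folds in (the constant and function names here are ours, not the repo's):

import torch

SDXL_VAE_SCALING_FACTOR = 0.13025  # hardcoded from the SDXL VAE config.json

def unscale_latents(latents: torch.Tensor) -> torch.Tensor:
    # The pipeline hands over latents multiplied by the scaling factor;
    # divide it back out before the VAE decoder sees them.
    return latents / SDXL_VAE_SCALING_FACTOR

def to_image_range(decoded: torch.Tensor) -> torch.Tensor:
    # The decoder output is roughly in [-1, 1]; map it to [0, 1] for images.
    return (decoded / 2 + 0.5).clamp(0, 1)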
2 changes: 2 additions & 0 deletions sharktank/sharktank/models/vae/tools/run_vae.py
@@ -103,8 +103,10 @@ def main(argv):
     inputs = get_random_inputs(dtype=dtype, device=device, bs=args.bs)
 
     if args.export:
+        # TODO move export from a run_vae file
         output = export_vae(mdl, inputs, args.decomp_attn)
         output.save_mlir(args.export)
+        print("exported VAE model. Skipping eager execution")
     else:
         # Save intermediates.
         intermediates_saver = None
1 change: 0 additions & 1 deletion sharktank/sharktank/models/vae/tools/sample_data.py
@@ -12,7 +12,6 @@
 
 
 def get_random_inputs(dtype, device, bs: int = 2):
-    torch.random.manual_seed(42)
     height = 1024
     width = 1024
     return torch.rand(bs, 4, width // 8, height // 8, dtype=dtype).to(device)
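
Note: removing the fixed seed makes get_random_inputs nondeterministic across calls. A usage sketch (not repo code) for callers that still want reproducible inputs, seeding once at the call site instead:

import torch

def get_random_inputs(dtype, device, bs: int = 2):  # as defined above
    height = 1024
    width = 1024
    return torch.rand(bs, 4, width // 8, height // 8, dtype=dtype).to(device)

torch.random.manual_seed(42)  # seed here if a run must be reproducible
inputs = get_random_inputs(dtype=torch.float32, device="cpu", bs=2)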
2 changes: 1 addition & 1 deletion sharktank/sharktank/ops/signatures.py
@@ -55,13 +55,13 @@
     "sharded_cat",
     "sharded_sum",
     "softmax",
+    "squeeze",
     "to",
     "transfer_to_logical_device",
     "transpose",
     "unflatten",
     "unshard",
     "unsqueeze",
-    "squeeze",
     "view",
     "view_as_complex",
     "view_as_real",
10 changes: 5 additions & 5 deletions sharktank/sharktank/types/tensors.py
@@ -382,6 +382,11 @@ def size(self, dim: Optional[int] = None) -> tuple[int]:
         if dim is None:
             return tuple(self.shape)
         return self.shape[dim]
 
+    def squeeze(self, dim: Optional[int] = None) -> "AnyTensor":
+        from ..ops import squeeze
+
+        return squeeze(self, dim)
+
     def transpose(self, dim0: int, dim1: int) -> "AnyTensor":
         from ..ops import transpose
@@ -398,11 +403,6 @@ def unsqueeze(self, dim: int) -> "AnyTensor":
 
         return unsqueeze(self, dim)
 
-    def squeeze(self, dim: Optional[int] = None) -> "AnyTensor":
-        from ..ops import squeeze
-
-        return squeeze(self, dim)
-
     def view(self, *args: Union[List[List[int]], List[int]]) -> "AnyTensor":
         from ..ops import view
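
Note: the squeeze method itself is unchanged; it only moves so the AnyTensor methods stay alphabetized, matching the signatures.py reorder above. It delegates to the dispatchable ops.squeeze, which is expected to mirror torch.Tensor.squeeze semantics, illustrated here with plain torch:

import torch

x = torch.zeros(1, 4, 1, 8)
assert x.squeeze(0).shape == (4, 1, 8)  # drop only the given size-1 dim
assert x.squeeze().shape == (4, 8)      # dim=None drops every size-1 dim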
