Skip to content

Commit f1fa123

Browse files
authored
Merge branch 'main' into layerwise-upcasting
2 parents 9b411e5 + 3e46043 commit f1fa123

File tree

2 files changed

+31
-4
lines changed

2 files changed

+31
-4
lines changed

src/diffusers/models/transformers/auraflow_transformer_2d.py

Lines changed: 17 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -68,6 +68,21 @@ def __init__(
6868
self.height, self.width = height // patch_size, width // patch_size
6969
self.base_size = height // patch_size
7070

71+
def pe_selection_index_based_on_dim(self, h, w):
72+
# select subset of positional embedding based on H, W, where H, W is size of latent
73+
# PE will be viewed as 2d-grid, and H/p x W/p of the PE will be selected
74+
# because original input are in flattened format, we have to flatten this 2d grid as well.
75+
h_p, w_p = h // self.patch_size, w // self.patch_size
76+
original_pe_indexes = torch.arange(self.pos_embed.shape[1])
77+
h_max, w_max = int(self.pos_embed_max_size**0.5), int(self.pos_embed_max_size**0.5)
78+
original_pe_indexes = original_pe_indexes.view(h_max, w_max)
79+
starth = h_max // 2 - h_p // 2
80+
endh = starth + h_p
81+
startw = w_max // 2 - w_p // 2
82+
endw = startw + w_p
83+
original_pe_indexes = original_pe_indexes[starth:endh, startw:endw]
84+
return original_pe_indexes.flatten()
85+
7186
def forward(self, latent):
7287
batch_size, num_channels, height, width = latent.size()
7388
latent = latent.view(
@@ -80,7 +95,8 @@ def forward(self, latent):
8095
)
8196
latent = latent.permute(0, 2, 4, 1, 3, 5).flatten(-3).flatten(1, 2)
8297
latent = self.proj(latent)
83-
return latent + self.pos_embed
98+
pe_index = self.pe_selection_index_based_on_dim(height, width)
99+
return latent + self.pos_embed[:, pe_index]
84100

85101

86102
# Taken from the original Aura flow inference code.

src/diffusers/utils/loading_utils.py

Lines changed: 14 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
import os
22
import tempfile
33
from typing import Callable, List, Optional, Union
4+
from urllib.parse import unquote, urlparse
45

56
import PIL.Image
67
import PIL.ImageOps
@@ -80,12 +81,22 @@ def load_video(
8081
)
8182

8283
if is_url:
83-
video_data = requests.get(video, stream=True).raw
84-
suffix = os.path.splitext(video)[1] or ".mp4"
84+
response = requests.get(video, stream=True)
85+
if response.status_code != 200:
86+
raise ValueError(f"Failed to download video. Status code: {response.status_code}")
87+
88+
parsed_url = urlparse(video)
89+
file_name = os.path.basename(unquote(parsed_url.path))
90+
91+
suffix = os.path.splitext(file_name)[1] or ".mp4"
8592
video_path = tempfile.NamedTemporaryFile(suffix=suffix, delete=False).name
93+
8694
was_tempfile_created = True
95+
96+
video_data = response.iter_content(chunk_size=8192)
8797
with open(video_path, "wb") as f:
88-
f.write(video_data.read())
98+
for chunk in video_data:
99+
f.write(chunk)
89100

90101
video = video_path
91102

0 commit comments

Comments
 (0)