Skip to content

Commit 461efc5

Browse files
[fix code annotation] Adjust the dimensions of the rotary positional embedding. (huggingface#8890)
* 2d rotary pos emb dim * make style --------- Co-authored-by: haofanwang <[email protected]>
1 parent 3b04cdc commit 461efc5

File tree

1 file changed

+10
-4
lines changed

1 file changed

+10
-4
lines changed

src/diffusers/models/embeddings.py

Lines changed: 10 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -319,12 +319,16 @@ def get_2d_rotary_pos_embed_from_grid(embed_dim, grid, use_real=False):
319319
assert embed_dim % 4 == 0
320320

321321
# use half of dimensions to encode grid_h
322-
emb_h = get_1d_rotary_pos_embed(embed_dim // 2, grid[0].reshape(-1), use_real=use_real) # (H*W, D/4)
323-
emb_w = get_1d_rotary_pos_embed(embed_dim // 2, grid[1].reshape(-1), use_real=use_real) # (H*W, D/4)
322+
emb_h = get_1d_rotary_pos_embed(
323+
embed_dim // 2, grid[0].reshape(-1), use_real=use_real
324+
) # (H*W, D/2) if use_real else (H*W, D/4)
325+
emb_w = get_1d_rotary_pos_embed(
326+
embed_dim // 2, grid[1].reshape(-1), use_real=use_real
327+
) # (H*W, D/2) if use_real else (H*W, D/4)
324328

325329
if use_real:
326-
cos = torch.cat([emb_h[0], emb_w[0]], dim=1) # (H*W, D/2)
327-
sin = torch.cat([emb_h[1], emb_w[1]], dim=1) # (H*W, D/2)
330+
cos = torch.cat([emb_h[0], emb_w[0]], dim=1) # (H*W, D)
331+
sin = torch.cat([emb_h[1], emb_w[1]], dim=1) # (H*W, D)
328332
return cos, sin
329333
else:
330334
emb = torch.cat([emb_h, emb_w], dim=1) # (H*W, D/2)
@@ -371,6 +375,8 @@ def get_1d_rotary_pos_embed(
371375
Returns:
372376
`torch.Tensor`: Precomputed frequency tensor with complex exponentials. [S, D/2]
373377
"""
378+
assert dim % 2 == 0
379+
374380
if isinstance(pos, int):
375381
pos = np.arange(pos)
376382
theta = theta * ntk_factor

0 commit comments

Comments
 (0)