@@ -319,12 +319,16 @@ def get_2d_rotary_pos_embed_from_grid(embed_dim, grid, use_real=False):
319
319
assert embed_dim % 4 == 0
320
320
321
321
# use half of dimensions to encode grid_h
322
- emb_h = get_1d_rotary_pos_embed (embed_dim // 2 , grid [0 ].reshape (- 1 ), use_real = use_real ) # (H*W, D/4)
323
- emb_w = get_1d_rotary_pos_embed (embed_dim // 2 , grid [1 ].reshape (- 1 ), use_real = use_real ) # (H*W, D/4)
322
+ emb_h = get_1d_rotary_pos_embed (
323
+ embed_dim // 2 , grid [0 ].reshape (- 1 ), use_real = use_real
324
+ ) # (H*W, D/2) if use_real else (H*W, D/4)
325
+ emb_w = get_1d_rotary_pos_embed (
326
+ embed_dim // 2 , grid [1 ].reshape (- 1 ), use_real = use_real
327
+ ) # (H*W, D/2) if use_real else (H*W, D/4)
324
328
325
329
if use_real :
326
- cos = torch .cat ([emb_h [0 ], emb_w [0 ]], dim = 1 ) # (H*W, D/2 )
327
- sin = torch .cat ([emb_h [1 ], emb_w [1 ]], dim = 1 ) # (H*W, D/2 )
330
+ cos = torch .cat ([emb_h [0 ], emb_w [0 ]], dim = 1 ) # (H*W, D)
331
+ sin = torch .cat ([emb_h [1 ], emb_w [1 ]], dim = 1 ) # (H*W, D)
328
332
return cos , sin
329
333
else :
330
334
emb = torch .cat ([emb_h , emb_w ], dim = 1 ) # (H*W, D/2)
@@ -371,6 +375,8 @@ def get_1d_rotary_pos_embed(
371
375
Returns:
372
376
`torch.Tensor`: Precomputed frequency tensor with complex exponentials. [S, D/2]
373
377
"""
378
+ assert dim % 2 == 0
379
+
374
380
if isinstance (pos , int ):
375
381
pos = np .arange (pos )
376
382
theta = theta * ntk_factor
0 commit comments