@@ -212,6 +212,58 @@ def forward(self, x, mask, cos, sin, start_pos=0, cache=None):
        return self.out_proj(context), next_cache


# ==============================================================================
# RoPE implementation summary
#
# There are two common styles of implementing RoPE. They are mathematically
# equivalent; they differ only in how the rotation matrix pairs up the
# embedding dimensions.
#
# 1) Split-halves style (this repo, Hugging Face Transformers):
#
#    For hidden dim d = 8 (example), dimension i is paired with dimension
#    i + d/2 (x0 with x4, x1 with x5, ...). Within each pair, the first
#    dimension takes the cos role and the second the sin role:
#
#    [ x0   x1   x2   x3   x4   x5   x6   x7 ]
#      │    │    │    │    │    │    │    │
#      ▼    ▼    ▼    ▼    ▼    ▼    ▼    ▼
#     cos  cos  cos  cos  sin  sin  sin  sin
#
#    Rotation matrix (block form, one angle θi per pair):
#
#    [ C  -S ]      C = diag(cosθ0, ..., cosθ3)
#    [ S   C ]      S = diag(sinθ0, ..., sinθ3)
#
#    Note that this matrix is NOT block-diagonal over adjacent dimensions:
#    the embedding dims are split into two halves, and each rotation couples
#    dimension i of the first half with dimension i + d/2 of the second half.
#
# 2) Interleaved (even/odd) style (original paper, Llama repo):
#
#    For hidden dim d = 8 (example), dimension 2i is paired with
#    dimension 2i+1 (x0 with x1, x2 with x3, ...):
#
#    [ x0   x1   x2   x3   x4   x5   x6   x7 ]
#      │    │    │    │    │    │    │    │
#      ▼    ▼    ▼    ▼    ▼    ▼    ▼    ▼
#     cos  sin  cos  sin  cos  sin  cos  sin
#
#    Rotation matrix (block-diagonal, one angle θi per 2x2 block):
#
#    [ cosθ0 -sinθ0    0      0    ... ]
#    [ sinθ0  cosθ0    0      0    ... ]
#    [   0      0    cosθ1 -sinθ1  ... ]
#    [   0      0    sinθ1  cosθ1  ... ]
#      ...
#
#    Here the embedding dims are interleaved: consecutive even/odd pairs
#    form the rotation blocks.
#
# Both layouts encode the same relative positions; the only difference is how
# the dimensions are paired up (see the demonstration sketch below).
# ==============================================================================
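

# ------------------------------------------------------------------------------
# Demonstration sketch (an illustrative addition, not part of this repo's API):
# a minimal check that the two pairing styles differ only by a fixed
# permutation of dimensions. The helper name below is hypothetical; it assumes
# the usual RoPE angles theta_i = theta_base^(-2i/d), a single
# (seq_len, head_dim) input tensor, and that `torch` is already imported at
# the top of this module.
def _demo_rope_pairing_equivalence(head_dim=8, seq_len=4, theta_base=10_000.0):
    # One angle per rotation pair: theta_i = theta_base^(-2i / head_dim)
    inv_freq = 1.0 / theta_base ** (torch.arange(0, head_dim, 2).float() / head_dim)
    angles = torch.arange(seq_len).float()[:, None] * inv_freq[None, :]  # (seq_len, head_dim/2)
    cos, sin = angles.cos(), angles.sin()

    x = torch.randn(seq_len, head_dim)

    # 1) Split-halves style: dimension i pairs with dimension i + head_dim/2
    x1, x2 = x[:, : head_dim // 2], x[:, head_dim // 2 :]
    out_split = torch.cat((x1 * cos - x2 * sin, x1 * sin + x2 * cos), dim=-1)

    # 2) Interleaved style: dimension 2i pairs with dimension 2i+1
    xe, xo = x[:, 0::2], x[:, 1::2]
    out_inter = torch.stack((xe * cos - xo * sin, xe * sin + xo * cos), dim=-1)
    out_inter = out_inter.reshape(seq_len, head_dim)

    # For the *same* raw x the two outputs differ (position 0 coincides since
    # its angle is zero; later positions generically do not) ...
    assert not torch.allclose(out_split, out_inter)

    # ... but interleaving the two halves first, rotating in even/odd style,
    # and de-interleaving afterwards reproduces the split-halves result
    # exactly. The styles are the same rotation up to a fixed permutation,
    # which is why both encode the same relative positions.
    y = torch.stack((x1, x2), dim=-1).reshape(seq_len, head_dim)  # x0,x4,x1,x5,...
    ye, yo = y[:, 0::2], y[:, 1::2]
    out_y = torch.stack((ye * cos - yo * sin, ye * sin + yo * cos), dim=-1)
    out_y = out_y.reshape(seq_len, head_dim)
    recovered = torch.cat((out_y[:, 0::2], out_y[:, 1::2]), dim=-1)  # undo permutation
    assert torch.allclose(recovered, out_split, atol=1e-6)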


def compute_rope_params(head_dim, theta_base=10_000, context_length=4096, dtype=torch.float32):
    assert head_dim % 2 == 0, "Embedding dimension must be even"
