Commit 4a97edc

Add note about RoPE implementation (#46)
1 parent 4c3377d commit 4a97edc

1 file changed: +52 -0 lines changed


reasoning_from_scratch/qwen3.py

Lines changed: 52 additions & 0 deletions
@@ -212,6 +212,58 @@ def forward(self, x, mask, cos, sin, start_pos=0, cache=None):
        return self.out_proj(context), next_cache


# ==============================================================================
# RoPE implementation summary
#
#
# There are two common styles to implement RoPE, which are mathematically
# equivalent; they mainly differ in how the rotation matrix pairs dimensions.
#
# 1) Split-halves style (this repo, Hugging Face Transformers):
#
#    For hidden dim d = 8 (example):
#
#    [ x0 x1 x2 x3 x4 x5 x6 x7 ]
#      │  │  │  │  │  │  │  │
#      ▼  ▼  ▼  ▼  ▼  ▼  ▼  ▼
#     cos cos cos cos sin sin sin sin
#
# Rotation matrix (pairs dimension i with dimension i + d/2):
#
#   [ cosθ0    0      0      0   -sinθ0    0      0      0   ]
#   [   0    cosθ1    0      0      0   -sinθ1    0      0   ]
#   [   0      0    cosθ2    0      0      0   -sinθ2    0   ]
#   [   0      0      0    cosθ3    0      0      0   -sinθ3 ]
#   [ sinθ0    0      0      0    cosθ0    0      0      0   ]
#   [   0    sinθ1    0      0      0    cosθ1    0      0   ]
#   [   0      0    sinθ2    0      0      0    cosθ2    0   ]
#   [   0      0      0    sinθ3    0      0      0    cosθ3 ]
#
# Here, the embedding dims are split into two halves, and each dimension in
# the first half is rotated together with its counterpart in the second half
# (x0 with x4, x1 with x5, and so on), using one angle θi per pair.
#
#
# 2) Interleaved (even/odd) style (original paper, Llama repo):
#
#    For hidden dim d = 8 (example):
#
#    [ x0 x1 x2 x3 x4 x5 x6 x7 ]
#      │  │  │  │  │  │  │  │
#      ▼  ▼  ▼  ▼  ▼  ▼  ▼  ▼
#     cos sin cos sin cos sin cos sin
#
# Rotation matrix (pairs dimension 2i with dimension 2i + 1):
#
#   [ cosθ0  -sinθ0    0       0    ... ]
#   [ sinθ0   cosθ0    0       0    ... ]
#   [   0       0    cosθ1  -sinθ1  ... ]
#   [   0       0    sinθ1   cosθ1  ... ]
#   ...
#
# Here, embedding dims are interleaved as even/odd cosine/sine pairs.
#
# Both layouts encode the same relative positions; the only difference is how
# dimensions are paired.
# ==============================================================================

def compute_rope_params(head_dim, theta_base=10_000, context_length=4096, dtype=torch.float32):
    assert head_dim % 2 == 0, "Embedding dimension must be even"
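
For readers comparing the two layouts, the following is a minimal, self-contained sketch (not part of this commit) of both styles applied to a single head. The helper names rope_split_halves, rope_interleaved, and the permutation perm are illustrative assumptions, not functions from qwen3.py; the sketch checks that the two styles agree once the dimensions are reordered between the interleaved and split-halves pairings.

import torch

def rope_split_halves(x, cos, sin):
    # Split-halves style: dimension i is rotated together with dimension i + d/2.
    # cos/sin have shape (seq_len, head_dim): the d/2 per-pair angles repeated twice.
    d = x.shape[-1]
    x1, x2 = x[..., : d // 2], x[..., d // 2:]
    rotated = torch.cat((-x2, x1), dim=-1)
    return x * cos + rotated * sin

def rope_interleaved(x, cos_half, sin_half):
    # Interleaved style: dimension 2i is rotated together with dimension 2i + 1.
    # cos_half/sin_half have shape (seq_len, head_dim // 2): one angle per pair.
    x_even, x_odd = x[..., 0::2], x[..., 1::2]
    out = torch.empty_like(x)
    out[..., 0::2] = x_even * cos_half - x_odd * sin_half
    out[..., 1::2] = x_even * sin_half + x_odd * cos_half
    return out

# Toy setup (hypothetical sizes, standard RoPE frequencies)
head_dim, seq_len = 8, 4
inv_freq = 1.0 / (10_000 ** (torch.arange(0, head_dim, 2).float() / head_dim))
angles = torch.arange(seq_len).float()[:, None] * inv_freq[None, :]  # (seq_len, head_dim // 2)
cos, sin = torch.cat([angles.cos()] * 2, dim=-1), torch.cat([angles.sin()] * 2, dim=-1)

x = torch.randn(seq_len, head_dim)

# Permutation mapping the interleaved pairing (x0, x1), (x2, x3), ... onto the
# split-halves pairing (x0, x4), (x1, x5), ...: [0, 2, 4, 6, 1, 3, 5, 7] for d = 8.
perm = torch.cat([torch.arange(0, head_dim, 2), torch.arange(1, head_dim, 2)])

out_split = rope_split_halves(x[..., perm], cos, sin)
out_inter = rope_interleaved(x, angles.cos(), angles.sin())
print(torch.allclose(out_split, out_inter[..., perm], atol=1e-6))  # True

Because the two styles differ only by this fixed reordering of dimensions, and a dot product is unchanged when both query and key are reordered the same way, they yield identical attention scores and thus the same relative-position behavior.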

0 commit comments
