
Commit 1f1b16d

remove numpy, avoid tolist in gatherd_idxs
1 parent 1daca4c commit 1f1b16d

File tree

1 file changed: +9 -8 lines changed

  • torchtitan/experiments/deepseek_v3


torchtitan/experiments/deepseek_v3/model.py

Lines changed: 9 additions & 8 deletions

@@ -29,8 +29,6 @@
 import math
 from typing import Optional, Tuple
 
-import numpy as np
-
 import torch
 import torch.distributed as dist
 
@@ -622,12 +620,15 @@ def moe_forward(self, x, topk_ids, topk_weight):
         # the tokens in `gathered_tokens` are headed for. This part doesn't need
         # gradient.
         with torch.no_grad():
-            gatherd_idxs = np.zeros(shape=(gathered_tokens.shape[0],), dtype=np.int32)
-            s = 0
-            # TODO: remove `tolist()`
-            for i, k in enumerate(tokens_per_expert_group.tolist()):
-                gatherd_idxs[s : s + k] = i % self.experts_per_rank
-                s += k
+            gatherd_idxs = torch.arange(
+                tokens_per_expert_group.numel(),
+                device=tokens_per_expert_group.device,
+            )
+            gatherd_idxs = (
+                gatherd_idxs.repeat_interleave(tokens_per_expert_group)
+                % self.experts_per_rank
+            ) + self.ep_rank * self.experts_per_rank
+
 
         # Prepare buffer for tokens processed by experts
         if self.shuffle_method == "symm_mem":
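
For context, a minimal standalone sketch (not part of the commit) of what the new vectorized construction computes: each expert slot index is repeated by its token count, reduced modulo the local expert count, and offset by the rank's first expert id. The values of tokens_per_expert_group, experts_per_rank, and ep_rank below are hypothetical.

    # Minimal sketch of the vectorized index construction (hypothetical values).
    import torch

    experts_per_rank = 2   # assumed number of local experts on this EP rank
    ep_rank = 0            # assumed expert-parallel rank
    # Assumed per-slot token counts after the all-to-all exchange.
    tokens_per_expert_group = torch.tensor([3, 1, 2, 2])

    # One slot index per entry, repeated by its token count, then folded into
    # the local expert range and offset by this rank's first expert id.
    gatherd_idxs = torch.arange(
        tokens_per_expert_group.numel(),
        device=tokens_per_expert_group.device,
    )
    gatherd_idxs = (
        gatherd_idxs.repeat_interleave(tokens_per_expert_group)
        % experts_per_rank
    ) + ep_rank * experts_per_rank

    print(gatherd_idxs)  # tensor([0, 0, 0, 1, 0, 0, 1, 1])

Unlike the removed numpy loop, this stays on the device of tokens_per_expert_group and avoids the host synchronization forced by tolist().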
