@@ -25,17 +25,16 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-""" PyTorch DeepSeek model."""
+"""PyTorch DeepSeek model."""
+
 import math
 from typing import Optional, Tuple
 
 import torch
 import torch.distributed as dist
-
 import torch.distributed._symmetric_memory as symm_mem
 import torch.nn.functional as F
 import torch.utils.checkpoint
-
 from attn_mask_utils import _prepare_4d_causal_attention_mask
 from model_config import ModelArgs
 from symm_mem_recipes import OnDeviceAllToAllV
@@ -401,9 +400,7 @@ def forward(self, hidden_states):
         )  # [n, n_group]
         group_idx = torch.topk(
             group_scores, k=self.topk_group, dim=-1, sorted=False
-        )[
-            1
-        ]  # [n, top_k_group]
+        )[1]  # [n, top_k_group]
         group_mask = torch.zeros_like(group_scores)  # [n, n_group]
         group_mask.scatter_(1, group_idx, 1)  # [n, n_group]
         score_mask = (
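Aside on the `forward` hunk above: `torch.topk` returns a `(values, indices)` tuple, so the reformatted `)[1]` simply selects the indices of the top-scoring expert groups. A minimal sketch of that group-selection pattern follows; the sizes (`n` tokens, `n_group` groups, `topk_group` picks) are illustrative assumptions, not values from this model.

import torch

n, n_group, topk_group = 4, 8, 3       # assumed sizes for illustration
group_scores = torch.rand(n, n_group)  # [n, n_group]

# torch.topk returns (values, indices); [1] keeps only the indices
group_idx = torch.topk(
    group_scores, k=topk_group, dim=-1, sorted=False
)[1]  # [n, topk_group]

# Scatter 1s at the selected indices to get a 0/1 mask over groups
group_mask = torch.zeros_like(group_scores)  # [n, n_group]
group_mask.scatter_(1, group_idx, 1)         # [n, n_group]
print(group_mask.sum(dim=-1))  # every row selects exactly topk_group groups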
@@ -629,7 +626,6 @@ def moe_forward(self, x, topk_ids, topk_weight):
             % self.experts_per_rank
         ) + self.ep_rank * self.experts_per_rank
 
-
         # Prepare buffer for tokens processed by experts
         if self.shuffle_method == "symm_mem":
             # Take necessary space from `token_gather_buf` symm mem because we are
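The `moe_forward` hunk only drops a blank line, but the surviving expression is the expert-parallel id mapping: taking an index modulo `self.experts_per_rank` and adding `self.ep_rank * self.experts_per_rank` folds it into the contiguous block of global expert ids owned by the current expert-parallel rank. A hedged sketch with made-up counts (4 experts per rank, rank 2):

import torch

experts_per_rank, ep_rank = 4, 2  # assumed values for illustration
idx = torch.arange(10)            # arbitrary indices

# Modulo folds into [0, experts_per_rank); the offset shifts into this
# rank's block of global expert ids, which is [8, 12) for rank 2
owned = (idx % experts_per_rank) + ep_rank * experts_per_rank
print(owned)  # tensor([ 8,  9, 10, 11,  8,  9, 10, 11,  8,  9])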
@@ -1002,9 +998,9 @@ def __init__(self, config: ModelArgs):
         self.vocab_size = config.vocab_size
 
         # Creating model parts related to my stage
-        assert (
-            config.stage_idx < config.num_stages
-        ), f"Stage {config.stage_idx} is not in the model"
+        assert config.stage_idx < config.num_stages, (
+            f"Stage {config.stage_idx} is not in the model"
+        )
         print(f"Creating model stage {config.stage_idx} of {config.num_stages}")
 
         self.embed_tokens = (