@@ -25,17 +25,16 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-""" PyTorch DeepSeek model."""
+"""PyTorch DeepSeek model."""
+
 import math
 from typing import Optional, Tuple
 
 import torch
 import torch.distributed as dist
-
 import torch.distributed._symmetric_memory as symm_mem
 import torch.nn.functional as F
 import torch.utils.checkpoint
-
 from attn_mask_utils import _prepare_4d_causal_attention_mask
 from model_config import ModelArgs
 from symm_mem_recipes import OnDeviceAllToAllV
@@ -401,9 +400,7 @@ def forward(self, hidden_states): |
         )  # [n, n_group]
         group_idx = torch.topk(
             group_scores, k=self.topk_group, dim=-1, sorted=False
-        )[
-            1
-        ]  # [n, top_k_group]
+        )[1]  # [n, top_k_group]
         group_mask = torch.zeros_like(group_scores)  # [n, n_group]
         group_mask.scatter_(1, group_idx, 1)  # [n, n_group]
         score_mask = (
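This hunk is part of DeepSeek's group-limited routing: gate scores are pooled to one score per expert group, only the best `topk_group` groups survive, and the final per-token top-k then only ever picks experts from surviving groups. A minimal self-contained sketch of that flow with toy sizes (the pooling step before this hunk and the final selection after it are assumptions inferred from context, not copied from the file):

import torch

n_tokens, n_group, experts_per_group = 4, 8, 8
topk_group, top_k = 3, 6

# Per-token, per-expert affinity scores (softmaxed gate logits in the real model).
scores = torch.rand(n_tokens, n_group * experts_per_group)

# Pool to one score per group, then keep only the best `topk_group` groups.
group_scores = scores.view(n_tokens, n_group, -1).max(dim=-1).values  # [n, n_group]
group_idx = torch.topk(group_scores, k=topk_group, dim=-1, sorted=False)[1]
group_mask = torch.zeros_like(group_scores)  # [n, n_group]
group_mask.scatter_(1, group_idx, 1)  # [n, n_group]

# Broadcast the group mask back to expert resolution; experts in dropped
# groups get a zero score and can never win the final top-k.
score_mask = (
    group_mask.unsqueeze(-1)
    .expand(n_tokens, n_group, experts_per_group)
    .reshape(n_tokens, -1)
)
topk_weight, topk_ids = torch.topk(scores * score_mask, k=top_k, dim=-1, sorted=False)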
@@ -629,7 +626,6 @@ def moe_forward(self, x, topk_ids, topk_weight): |
             % self.experts_per_rank
         ) + self.ep_rank * self.experts_per_rank
 
-
         # Prepare buffer for tokens processed by experts
         if self.shuffle_method == "symm_mem":
             # Take necessary space from `token_gather_buf` symm mem because we are
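The arithmetic in this hunk appears to fold an expert index into the contiguous slice owned by the current expert-parallel rank: rank `ep_rank` owns `experts_per_rank` consecutive experts starting at `ep_rank * experts_per_rank`. A toy check of that identity (sizes invented for illustration; the surrounding sort/offset logic from the file is omitted):

# Toy setup: 8 routed experts sharded over 4 expert-parallel ranks.
n_experts, ep_size = 8, 4
experts_per_rank = n_experts // ep_size  # rank r owns experts [2r, 2r + 2)

ep_rank = 2  # pretend this process is expert-parallel rank 2
for expert_id in range(n_experts):
    folded = (expert_id % experts_per_rank) + ep_rank * experts_per_rank
    # Whatever id comes in, the result lands in rank 2's slice: {4, 5}.
    assert ep_rank * experts_per_rank <= folded < (ep_rank + 1) * experts_per_rank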
@@ -1002,9 +998,9 @@ def __init__(self, config: ModelArgs): |
         self.vocab_size = config.vocab_size
 
         # Creating model parts related to my stage
-        assert (
-            config.stage_idx < config.num_stages
-        ), f"Stage {config.stage_idx} is not in the model"
+        assert config.stage_idx < config.num_stages, (
+            f"Stage {config.stage_idx} is not in the model"
+        )
         print(f"Creating model stage {config.stage_idx} of {config.num_stages}")
 
         self.embed_tokens = (
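This assert (reformatted here into ruff/Black's newer parenthesized-message style) guards pipeline-parallel construction: each process builds only the modules belonging to its own stage, and `embed_tokens` (truncated above) is presumably created conditionally. A hedged sketch of that general pattern with invented module names, not the file's actual stage logic:

import torch.nn as nn

class StageModule(nn.Module):
    """Toy pipeline stage: stage 0 owns the embedding, the last stage owns
    the output head, and every stage owns its own slice of layers."""

    def __init__(self, stage_idx: int, num_stages: int, vocab: int, dim: int):
        super().__init__()
        assert stage_idx < num_stages, f"Stage {stage_idx} is not in the model"
        self.embed_tokens = nn.Embedding(vocab, dim) if stage_idx == 0 else None
        self.layers = nn.ModuleList(nn.Linear(dim, dim) for _ in range(2))
        self.lm_head = nn.Linear(dim, vocab) if stage_idx == num_stages - 1 else None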