Skip to content

Commit a453229

Browse files
committed
fmt
1 parent 1f1b16d commit a453229

File tree

1 file changed

+6
-10
lines changed
  • torchtitan/experiments/deepseek_v3

1 file changed

+6
-10
lines changed

torchtitan/experiments/deepseek_v3/model.py

Lines changed: 6 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -25,17 +25,16 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-""" PyTorch DeepSeek model."""
+"""PyTorch DeepSeek model."""
+
 import math
 from typing import Optional, Tuple

 import torch
 import torch.distributed as dist
-
 import torch.distributed._symmetric_memory as symm_mem
 import torch.nn.functional as F
 import torch.utils.checkpoint
-
 from attn_mask_utils import _prepare_4d_causal_attention_mask
 from model_config import ModelArgs
 from symm_mem_recipes import OnDeviceAllToAllV
@@ -401,9 +400,7 @@ def forward(self, hidden_states):
         )  # [n, n_group]
         group_idx = torch.topk(
             group_scores, k=self.topk_group, dim=-1, sorted=False
-        )[
-            1
-        ]  # [n, top_k_group]
+        )[1]  # [n, top_k_group]
         group_mask = torch.zeros_like(group_scores)  # [n, n_group]
         group_mask.scatter_(1, group_idx, 1)  # [n, n_group]
         score_mask = (
@@ -629,7 +626,6 @@ def moe_forward(self, x, topk_ids, topk_weight):
             % self.experts_per_rank
         ) + self.ep_rank * self.experts_per_rank
 
-
         # Prepare buffer for tokens processed by experts
         if self.shuffle_method == "symm_mem":
             # Take necessary space from `token_gather_buf` symm mem because we are
@@ -1002,9 +998,9 @@ def __init__(self, config: ModelArgs):
         self.vocab_size = config.vocab_size
 
         # Creating model parts related to my stage
-        assert (
-            config.stage_idx < config.num_stages
-        ), f"Stage {config.stage_idx} is not in the model"
+        assert config.stage_idx < config.num_stages, (
+            f"Stage {config.stage_idx} is not in the model"
+        )
         print(f"Creating model stage {config.stage_idx} of {config.num_stages}")
 
         self.embed_tokens = (

0 commit comments

Comments
 (0)