-
Notifications
You must be signed in to change notification settings - Fork 462
[WIP] Add state_dict converter for DeepSeekv3 in torchtitan #1538
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Draft
wwwjn
wants to merge
6
commits into
main
Choose a base branch
from
dsv3-state-dict
base: main
Could not load branches
Branch not found: {{ refName }}
Loading
Could not load tags
Nothing to show
Loading
Are you sure you want to change the base?
Some commits from the old base branch may be removed from the timeline,
and old review comments may become outdated.
+292
−3
Draft
Changes from all commits
Commits
Show all changes
6 commits
Select commit
Hold shift + click to select a range
File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,62 @@ | ||
import torch | ||
from torchtitan.tools.logging import logger | ||
|
||
# Fixed block size of 128x128 as specified in the algorithm | ||
BLOCK_SIZE = 128 | ||
|
||
|
||
def calculate_scale_shape( | ||
weight: torch.Tensor, BLOCK_SIZE: int = BLOCK_SIZE | ||
) -> torch.Size: | ||
# Calculate the scale tensor shape | ||
orig_shape = weight.shape | ||
|
||
# Calculate number of blocks needed | ||
block_rows = (orig_shape[0] + BLOCK_SIZE - 1) // BLOCK_SIZE | ||
block_cols = (orig_shape[1] + BLOCK_SIZE - 1) // BLOCK_SIZE | ||
|
||
# Verify scale_inv shape matches expected block dimensions | ||
expected_scale_shape = torch.Size((block_rows, block_cols)) | ||
|
||
return expected_scale_shape | ||
|
||
|
||
def dequantize_fp8(weight: torch.Tensor, scale_inv: torch.Tensor, dtype=torch.bfloat16, BLOCK_SIZE: int = BLOCK_SIZE) -> torch.Tensor: | ||
# Convert to float32 for computation | ||
float_weight = weight.to(torch.float32) | ||
# Get original dimensions | ||
orig_shape = weight.shape | ||
|
||
# Verify scale_inv shape matches expected block dimensions | ||
expected_scale_shape = calculate_scale_shape(weight, BLOCK_SIZE) | ||
block_rows, block_cols = expected_scale_shape | ||
if scale_inv.shape != expected_scale_shape: | ||
logger.warning( | ||
f"scale_inv shape {scale_inv.shape} doesn't match expected shape {expected_scale_shape}" | ||
) | ||
|
||
# NOTE: This might cause OOM if the model is too large | ||
# Copy the weight tensor to make it also a DTensor | ||
dequantized = float_weight.detach().clone().to(dtype=dtype) | ||
|
||
# Apply scaling factors to each block | ||
for i in range(block_rows): | ||
row_start = i * BLOCK_SIZE | ||
row_end = min(row_start + BLOCK_SIZE, orig_shape[0]) | ||
|
||
for j in range(block_cols): | ||
col_start = j * BLOCK_SIZE | ||
col_end = min(col_start + BLOCK_SIZE, orig_shape[1]) | ||
|
||
# Get the block | ||
block = float_weight[row_start:row_end, col_start:col_end] | ||
|
||
scale = scale_inv[i, j] | ||
block = block * scale | ||
|
||
# Explicitly convert block to dtype | ||
block_converted = block.to(dtype=torch.float32) | ||
# Store the dequantized block | ||
dequantized[row_start:row_end, col_start:col_end] = block_converted | ||
|
||
return dequantized |
223 changes: 223 additions & 0 deletions
223
torchtitan/models/deepseek_v3/model/state_dict_adapter.py
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,223 @@ | ||
# Copyright (c) Meta Platforms, Inc. and affiliates. | ||
# All rights reserved. | ||
# | ||
# This source code is licensed under the BSD-style license found in the | ||
# LICENSE file in the root directory of this source tree. | ||
|
||
import re | ||
from typing import Any | ||
|
||
from torchtitan.protocols.state_dict_adapter import StateDictAdapter | ||
|
||
from .args import DeepSeekV3ModelArgs | ||
import torch | ||
from .quantization import dequantize_fp8, calculate_scale_shape | ||
|
||
|
||
class DeepSeekV3StateDictAdapter(StateDictAdapter): | ||
def __init__(self, model_args: DeepSeekV3ModelArgs): | ||
""" | ||
StateDictAdapter for DeepSeekV3 model. | ||
NOTE: Now we observed the rotary embedding difference in torchtitan and huggingface. And this need to | ||
be fixed to make the numerical results consistent between torchtitan and huggingface. | ||
""" | ||
self.model_args = model_args | ||
self.from_hf_map = { | ||
"model.embed_tokens.weight": "tok_embeddings.weight", | ||
# Attention Module | ||
"model.layers.{}.self_attn.q_a_proj.weight": "layers.{}.attention.wq_a.weight", | ||
"model.layers.{}.self_attn.q_a_layernorm.weight": "layers.{}.attention.q_norm.weight", | ||
"model.layers.{}.self_attn.q_b_proj.weight": "layers.{}.attention.wq_b.weight", | ||
"model.layers.{}.self_attn.kv_a_proj_with_mqa.weight": "layers.{}.attention.wkv_a.weight", | ||
"model.layers.{}.self_attn.kv_a_layernorm.weight": "layers.{}.attention.kv_norm.weight", | ||
"model.layers.{}.self_attn.kv_b_proj.weight": "layers.{}.attention.wkv_b.weight", | ||
"model.layers.{}.self_attn.o_proj.weight": "layers.{}.attention.wo.weight", | ||
# MLP Module | ||
"model.layers.{}.mlp.gate_proj.weight": "layers.{}.feed_forward.w1.weight", | ||
"model.layers.{}.mlp.up_proj.weight": "layers.{}.feed_forward.w3.weight", | ||
"model.layers.{}.mlp.down_proj.weight": "layers.{}.feed_forward.w2.weight", | ||
# Transfomer Layer | ||
"model.layers.{}.input_layernorm.weight": "layers.{}.attention_norm.weight", | ||
"model.layers.{}.post_attention_layernorm.weight": "layers.{}.ffn_norm.weight", | ||
# MoE Module | ||
"model.layers.{}.mlp.experts.{}.gate_proj.weight": "layers.{}.moe.experts.w1", | ||
"model.layers.{}.mlp.experts.{}.up_proj.weight": "layers.{}.moe.experts.w3", | ||
"model.layers.{}.mlp.experts.{}.down_proj.weight": "layers.{}.moe.experts.w2", | ||
"model.layers.{}.mlp.gate.weight": "layers.{}.moe.router.gate.weight", | ||
"model.layers.{}.mlp.shared_experts.gate_proj.weight": "layers.{}.moe.shared_expert.w1", | ||
"model.layers.{}.mlp.shared_experts.up_proj.weight": "layers.{}.moe.shared_expert.w3", | ||
"model.layers.{}.mlp.shared_experts.down_proj.weight": "layers.{}.moe.shared_expert.w2", | ||
"model.norm.weight": "norm.weight", | ||
"lm_head.weight": "output.weight", | ||
} | ||
|
||
|
||
def _split_experts_weights(self, weight: torch.Tensor, n_experts: int) -> list[torch.Tensor]: | ||
""" | ||
Split the weights of the experts into a list of tensors. | ||
""" | ||
split_weight = torch.split(weight, weight.shape[0] // n_experts, dim=0) | ||
return split_weight | ||
|
||
def _concatenate_expert_weights(self, expert_weights_by_layer: dict[str, Any], n_experts: int) -> torch.Tensor: | ||
""" | ||
Concatenate the weights of seprate experts into GroupedExpert weights. | ||
""" | ||
for layer, abstract_keys in list(expert_weights_by_layer.items()): | ||
for abstract_key, experts in list(abstract_keys.items()): | ||
# If we have all the experts for this abstract_key, concatenate them | ||
if len(experts) == n_experts: | ||
sorted_expert_ids = sorted(experts.keys()) | ||
sorted_experts = [experts[i] for i in sorted_expert_ids] | ||
|
||
# Here we need transpose because the torchtitan used nn.Linear() while HF used nn.Parameter | ||
stacked_tensor = torch.stack(sorted_experts, dim=0).transpose( | ||
1, 2 | ||
) | ||
|
||
# Remove these experts from the tracking dict to free memory | ||
del expert_weights_by_layer[layer][abstract_key] | ||
if not expert_weights_by_layer[layer]: | ||
del expert_weights_by_layer[layer] | ||
|
||
return stacked_tensor | ||
|
||
return None | ||
|
||
def _quantize(self, state_dict: dict[str, Any]) -> dict[str, Any]: | ||
""" | ||
Quantize the weights from float32 to float8. Export to HF f | ||
""" | ||
pass | ||
|
||
def _dequantize(self, state_dict: dict[str, Any]) -> dict[str, Any]: | ||
""" | ||
Dequantize the weights from float8 to float32. | ||
""" | ||
|
||
scale_inv_keys = [] | ||
for key, weight in state_dict.items(): | ||
if key.endswith(".weight") and key + "_scale_inv" in state_dict: | ||
scale_inv = state_dict[key + "_scale_inv"] | ||
dequantized_weight = dequantize_fp8(weight, scale_inv, dtype=torch.float32) | ||
# update the weight and remove the scale_inv tensor | ||
state_dict[key] = dequantized_weight | ||
scale_inv_keys.append(key + "_scale_inv") | ||
|
||
for key in scale_inv_keys: | ||
state_dict.pop(key) | ||
|
||
return state_dict | ||
|
||
def _add_quantization_scale_inv_tensors(self, state_dict: dict[str, Any]) -> dict[str, Any]: | ||
""" | ||
Add quantization scale tensors the state_dict. | ||
""" | ||
non_quantized_keys = ["input_layernorm.weight", "post_attention_layernorm.weight", "norm.weight", "lm_head.weight", "embed_tokens.weight", "mlp.gate.weight"] | ||
|
||
weight_scale_inv_state_dict = {} | ||
for key, value in state_dict.items(): | ||
if key.endswith(".weight") and not any(non_quantized_key in key for non_quantized_key in non_quantized_keys): | ||
expected_scale_shape = calculate_scale_shape(value) | ||
# add weight_scale_inv to the state_dict | ||
weight_scale_inv_state_dict[key + "_scale_inv"] = torch.zeros(expected_scale_shape, dtype=torch.float32) | ||
|
||
state_dict.update(weight_scale_inv_state_dict) | ||
return state_dict | ||
|
||
def to_hf(self, state_dict: dict[str, Any]) -> dict[str, Any]: | ||
""" | ||
1. Quantize the weights from float32 to float8. | ||
2. Convert between the HF shape and the torchtitan shape. | ||
3. Split the GroupedExperts' weight into seprate expert's wegiht. | ||
""" | ||
|
||
to_hf_map = {v: k for k, v in self.from_hf_map.items()} | ||
|
||
hf_state_dict = {} | ||
|
||
for key, value in state_dict.items(): | ||
if "moe.expert_bias" in key or "moe.tokens_per_expert" in key: | ||
continue | ||
|
||
if "moe.experts" in key: | ||
abstract_key = re.sub(r"(\d+)", "{}", key, count=1) | ||
layer_num = re.search(r"\d+", key).group(0) | ||
new_key = to_hf_map[abstract_key] | ||
|
||
# Split expert weights into seperate expert weights | ||
split_values = self._split_experts_weights(value, self.model_args.n_routed_experts) | ||
for expert_num in range(0, self.model_args.n_routed_experts): | ||
new_key = new_key.format(layer_num, expert_num) | ||
# We need to transpose the weight because the torchtitan used nn.Linear() while HF used nn.Parameter() | ||
hf_state_dict[new_key] = split_values[expert_num].squeeze().transpose(0, 1) | ||
|
||
elif "layers" in key: | ||
abstract_key = re.sub(r"(\d+)", "{}", key, count=1) | ||
layer_num = re.search(r"\d+", key).group(0) | ||
new_key = to_hf_map[abstract_key] | ||
new_key = new_key.format(layer_num) | ||
|
||
# Special case for `shared_expert`: torchtitan uses nn.Linear, and HF uses nn.Parameter | ||
# torchtitan shape: (1, s[1], s[2]) -> HF shape: (s[2], s[1]) | ||
if "shared_expert" in key: | ||
value = value.squeeze(0).transpose(0, 1) | ||
|
||
hf_state_dict[new_key] = value | ||
|
||
else: | ||
new_key = to_hf_map[key] | ||
hf_state_dict[new_key] = value | ||
|
||
hf_state_dict_with_scale_inv = self._add_quantization_scale_inv_tensors(hf_state_dict) | ||
return hf_state_dict_with_scale_inv | ||
|
||
def from_hf(self, hf_state_dict: dict[str, Any]) -> dict[str, Any]: | ||
""" | ||
1. Dequantize the weights from float8 to float32. | ||
2. Convert between the HF shape and the torchtitan shape. | ||
3. Concate seprate expert's wegiht into GroupedExperts' weight. | ||
""" | ||
|
||
# dequantize the tensor in state_dict and remove the scale_inv tensor | ||
hf_state_dict = self._dequantize(hf_state_dict) | ||
state_dict = {} | ||
|
||
expert_weights_by_layer = {} # {layer: {abstract_key: {expert_id: tensor}}} | ||
|
||
for key, value in hf_state_dict.items(): | ||
if "mlp.experts" in key: | ||
abstract_key = re.sub(r"(\d+)", "{}", key, count=2) | ||
layer_num, expert_num = re.findall(r"\d+", key) | ||
new_key = self.from_hf_map[abstract_key] | ||
new_key = new_key.format(layer_num) | ||
|
||
# Store the expert's weight in expert_weights_by_layer for concating later. | ||
if layer_num not in expert_weights_by_layer: | ||
expert_weights_by_layer[layer_num] = {} | ||
if abstract_key not in expert_weights_by_layer[layer_num]: | ||
expert_weights_by_layer[layer_num][abstract_key] = {} | ||
expert_weights_by_layer[layer_num][abstract_key][expert_num] = value | ||
|
||
# try to concat the expert's weight into GroupedExperts' weight. | ||
stacked_value = self._concatenate_expert_weights(expert_weights_by_layer, self.model_args.n_routed_experts) | ||
if stacked_value is not None: | ||
state_dict[new_key] = stacked_value | ||
|
||
elif "layers" in key: | ||
abstract_key = re.sub(r"(\d+)", "{}", key, count=1) | ||
layer_num = re.search(r"\d+", key).group(0) | ||
new_key = self.from_hf_map[abstract_key] | ||
new_key = new_key.format(layer_num) | ||
|
||
# Special case for `shared_expert`: torchtitan uses nn.Linear, and HF uses nn.Parameter | ||
# HF shape: (s[1], s[2]) -> torchtitan shape: (1, s[2], s[1]) | ||
if "shared_experts" in key: | ||
value = value.transpose(0, 1).unsqueeze(0) | ||
|
||
state_dict[new_key] = value | ||
else: | ||
new_key = self.from_hf_map[key] | ||
state_dict[new_key] = value | ||
|
||
return state_dict |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
Uh oh!
There was an error while loading. Please reload this page.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Here's the tensor slice operation @ankitageorge