@@ -136,7 +136,7 @@ def _add_quantization_scale_inv_tensors(

     def to_hf(self, state_dict: dict[str, Any]) -> dict[str, Any]:
         """
-        1. Quantize the weights from float32 to float8.
+        1. When saving HF checkpoints, quantize the weights from float32 to float8.
         2. Convert between the HF shape and the torchtitan shape.
         3. Split the GroupedExperts' weight into separate experts' weights.
         """
@@ -149,7 +149,6 @@ def to_hf(self, state_dict: dict[str, Any]) -> dict[str, Any]:
                 continue

             if "moe.experts" in key:
-                # model.layers.3.mlp.experts.0.down_proj.weight
                 abstract_key = re.sub(r"(\d+)", "{}", key, count=1)
                 layer_num = re.search(r"\d+", key).group(0)
                 new_abstract_key = to_hf_map[abstract_key]
@@ -188,7 +187,7 @@ def to_hf(self, state_dict: dict[str, Any]) -> dict[str, Any]:

     def from_hf(self, hf_state_dict: dict[str, Any]) -> dict[str, Any]:
         """
-        1. Dequantize the weights from float8 to float32.
+        1. When loading from an HF checkpoint, dequantize the weights from float8 to float32.
         2. Convert between the HF shape and the torchtitan shape.
         3. Concatenate separate experts' weights into GroupedExperts' weight.
         """