Skip to content

Commit 69917df

Browse files
aahouzi, cebtenzzre, and compilade
authored
py : fix StableLM conversion after config.json changes (#5703)
* Fix issues during StableLM models conversion

* Fix hard coded layer_norm_eps

* Support layer_norm_eps for LlavaStableLM
  Co-authored-by: Jared Van Bortel <[email protected]>

* Add missing parenthesis
  Co-authored-by: Jared Van Bortel <[email protected]>

* Support rotary_factor for LlavaStableLM
  Co-authored-by: Jared Van Bortel <[email protected]>

* fix typo

* Add StableLMEpochForCausalLM for safety
  Co-authored-by: compilade <[email protected]>

* Add StableLMEpochForCausalLM for safety 2
  Co-authored-by: compilade <[email protected]>

---------

Co-authored-by: Jared Van Bortel <[email protected]>
Co-authored-by: Jared Van Bortel <[email protected]>
Co-authored-by: compilade <[email protected]>
1 parent 9e359a4 commit 69917df

File tree

1 file changed: +5 −4 lines changed

convert-hf-to-gguf.py

+5-4
Original file line numberDiff line numberDiff line change
@@ -192,7 +192,7 @@ def from_model_architecture(model_architecture):
192192
return RefactModel
193193
if model_architecture == "PersimmonForCausalLM":
194194
return PersimmonModel
195-
if model_architecture in ("StableLMEpochForCausalLM", "LlavaStableLMEpochForCausalLM"):
195+
if model_architecture in ("StableLmForCausalLM", "StableLMEpochForCausalLM", "LlavaStableLMEpochForCausalLM"):
196196
return StableLMModel
197197
if model_architecture == "QWenLMHeadModel":
198198
return QwenModel
@@ -253,7 +253,7 @@ def _get_model_architecture(self) -> gguf.MODEL_ARCH:
253253
return gguf.MODEL_ARCH.REFACT
254254
if arch == "PersimmonForCausalLM":
255255
return gguf.MODEL_ARCH.PERSIMMON
256-
if arch in ("StableLMEpochForCausalLM", "LlavaStableLMEpochForCausalLM"):
256+
if arch in ("StableLmForCausalLM", "StableLMEpochForCausalLM", "LlavaStableLMEpochForCausalLM"):
257257
return gguf.MODEL_ARCH.STABLELM
258258
if arch == "QWenLMHeadModel":
259259
return gguf.MODEL_ARCH.QWEN
@@ -1074,10 +1074,11 @@ def set_gguf_parameters(self):
10741074
self.gguf_writer.add_embedding_length(hparams["hidden_size"])
10751075
self.gguf_writer.add_block_count(block_count)
10761076
self.gguf_writer.add_feed_forward_length(hparams["intermediate_size"])
1077-
self.gguf_writer.add_rope_dimension_count(int(hparams["rope_pct"] * (hparams["hidden_size"] // hparams["num_attention_heads"])))
1077+
rotary_factor = self.find_hparam(["partial_rotary_factor", "rope_pct"])
1078+
self.gguf_writer.add_rope_dimension_count(int(rotary_factor * (hparams["hidden_size"] // hparams["num_attention_heads"])))
10781079
self.gguf_writer.add_head_count(hparams["num_attention_heads"])
10791080
self.gguf_writer.add_parallel_residual(hparams["use_parallel_residual"] if "use_parallel_residual" in hparams else True)
1080-
self.gguf_writer.add_layer_norm_eps(1e-5)
1081+
self.gguf_writer.add_layer_norm_eps(self.find_hparam(["layer_norm_eps", "norm_eps"]))
10811082

10821083

10831084
class MixtralModel(Model):

0 commit comments

Comments (0)