
Commit b0de162

Cover more model architectures (#107)
Authored by guangy10 and Github Executorch
Co-authored-by: Github Executorch <[email protected]>
1 parent e6d1d6f commit b0de162

12 files changed: +790 −16 lines changed

README.md

Lines changed: 11 additions & 1 deletion
@@ -151,16 +151,26 @@ We currently support a wide range of popular transformer models, including encod
  - [Eurobert](https://huggingface.co/EuroBERT/EuroBERT-210m): `EuroBERT-210m` and its variants
  - [Roberta](https://huggingface.co/FacebookAI/xlm-roberta-base): FacebookAI's `xlm-roberta-base` and its variants
  #### Decoder-only models
+ - [Codegen](https://huggingface.co/Salesforce/codegen-350M-mono): Salesforce's `codegen-350M-mono` and its variants
  - [Gemma](https://huggingface.co/google/gemma-2b): `Gemma-2b` and its variants
  - [Gemma2](https://huggingface.co/google/gemma-2-2b): `Gemma-2-2b` and its variants
- - [Gemma3](https://huggingface.co/google/gemma-3-1b-it): `Gemma-3-1b` and its variants *(requires `transformers >= 4.52.0`)*
+ - [Gemma3](https://huggingface.co/google/gemma-3-1b-it): `Gemma-3-1b` and its variants
+ - [Glm](https://huggingface.co/THUDM/glm-edge-1.5b-chat): `glm-edge-1.5b` and its variants
+ - [Gpt2](https://huggingface.co/AI-Sweden-Models/gpt-sw3-126m): `gpt-sw3-126m` and its variants
+ - [GptJ](https://huggingface.co/Milos/slovak-gpt-j-405M): `gpt-j-405M` and its variants
+ - [GptNeoX](https://huggingface.co/EleutherAI/pythia-14m): EleutherAI's `pythia-14m` and its variants
+ - [GptNeoXJapanese](https://huggingface.co/abeja/gpt-neox-japanese-2.7b): `gpt-neox-japanese-2.7b` and its variants
+ - [Granite](https://huggingface.co/ibm-granite/granite-3.3-2b-instruct): `granite-3.3-2b-instruct` and its variants
  - [Llama](https://huggingface.co/meta-llama/Llama-3.2-1B): `Llama-3.2-1B` and its variants
+ - [Mistral](https://huggingface.co/ministral/Ministral-3b-instruct): `Ministral-3b-instruct` and its variants
  - [Qwen2](https://huggingface.co/Qwen/Qwen2.5-0.5B): `Qwen2.5-0.5B` and its variants
  - [Qwen3](https://huggingface.co/Qwen/Qwen3-0.6B): `Qwen3-0.6B`, `Qwen3-Embedding-0.6B` and other variants
  - [Olmo](https://huggingface.co/allenai/OLMo-1B-hf): `OLMo-1B-hf` and its variants
+ - [Phi](https://huggingface.co/johnsnowlabs/JSL-MedPhi2-2.7B): `JSL-MedPhi2-2.7B` and its variants
  - [Phi4](https://huggingface.co/microsoft/Phi-4-mini-instruct): `Phi-4-mini-instruct` and its variants
  - [Smollm](https://huggingface.co/HuggingFaceTB/SmolLM2-135M): 🤗 `SmolLM2-135M` and its variants
  - [Smollm3](https://huggingface.co/HuggingFaceTB/SmolLM3-3B): 🤗 `SmolLM3-3B` and its variants
+ - [Starcoder2](https://huggingface.co/bigcode/starcoder2-3b): `starcoder2-3b` and its variants
  #### Encoder-decoder models
  - [T5](https://huggingface.co/google-t5/t5-small): Google's `T5` and its variants
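The newly added architectures follow the same export-and-generate flow as the existing ones. A minimal sketch using one of the new entries (model id, prompt, recipe, and the `text_generation` call are taken from the tests added in this commit; quantization kwargs are omitted here):

```python
from transformers import AutoTokenizer

from optimum.executorch import ExecuTorchModelForCausalLM

model_id = "Salesforce/codegen-350M-mono"  # any newly supported decoder-only model
tokenizer = AutoTokenizer.from_pretrained(model_id)

# Export the eager checkpoint to ExecuTorch with the XNNPACK recipe, then generate.
model = ExecuTorchModelForCausalLM.from_pretrained(model_id, recipe="xnnpack")
generated_text = model.text_generation(tokenizer=tokenizer, prompt="def hello_world():", max_seq_len=64)
print(generated_text)
```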

optimum/exporters/executorch/tasks/causal_lm.py

Lines changed: 52 additions & 15 deletions
@@ -71,22 +71,59 @@ def load_causal_lm_model(model_name_or_path: str, **kwargs) -> CausalLMExportabl
      if hasattr(config, "use_cache") and config.use_cache is False:
          config.use_cache = True

-     eager_model = AutoModelForCausalLM.from_pretrained(
-         model_name_or_path,
-         device_map=device,
-         torch_dtype=dtype,
-         config=config,
-         attn_implementation=attn_implementation,
-         generation_config=GenerationConfig(
-             use_cache=True,
-             cache_implementation=cache_implementation,
-             max_length=max_length,
-             cache_config={
-                 "batch_size": batch_size,
-                 "max_cache_len": max_length,
-             },
-         ),
-     )
+     def _load_eager_pretrained(
+         model_name_or_path,
+         device,
+         dtype,
+         config,
+         attn_implementation,
+         cache_implementation,
+         batch_size,
+         max_length,
+     ):
+         eager_model = AutoModelForCausalLM.from_pretrained(
+             model_name_or_path,
+             device_map=device,
+             torch_dtype=dtype,
+             config=config,
+             attn_implementation=attn_implementation,
+             generation_config=GenerationConfig(
+                 use_cache=True,
+                 cache_implementation=cache_implementation,
+                 max_length=max_length,
+                 cache_config={
+                     "batch_size": batch_size,
+                     "max_cache_len": max_length,
+                 },
+             ),
+         )
+         return eager_model
+
+     try:
+         eager_model = _load_eager_pretrained(
+             model_name_or_path,
+             device,
+             dtype,
+             config,
+             attn_implementation,
+             cache_implementation,
+             batch_size,
+             max_length,
+         )
+     except ValueError as e:
+         if "torch.nn.functional.scaled_dot_product_attention" in str(e):
+             logging.info("⚠ SDPA attention not supported, falling back to eager implementation")
+             attn_implementation = "eager"
+             eager_model = _load_eager_pretrained(
+                 model_name_or_path,
+                 device,
+                 dtype,
+                 config,
+                 attn_implementation,
+                 cache_implementation,
+                 batch_size,
+                 max_length,
+             )

      for param in eager_model.parameters():
          # Must disable gradient for quantized checkpoint
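The new `_load_eager_pretrained` wrapper exists so the load can be retried: when a model class has no SDPA support, `transformers` raises a `ValueError` whose message mentions `torch.nn.functional.scaled_dot_product_attention`, and the loader falls back to the eager attention implementation. A minimal standalone sketch of the same pattern (the model id is illustrative):

```python
from transformers import AutoModelForCausalLM


def load_with_sdpa_fallback(model_id: str):
    # Try the requested SDPA attention first; fall back to eager if unsupported.
    try:
        return AutoModelForCausalLM.from_pretrained(model_id, attn_implementation="sdpa")
    except ValueError as e:
        if "torch.nn.functional.scaled_dot_product_attention" not in str(e):
            raise
        return AutoModelForCausalLM.from_pretrained(model_id, attn_implementation="eager")


model = load_with_sdpa_fallback("Salesforce/codegen-350M-mono")
```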
tests/models/test_modeling_codegen.py

Lines changed: 74 additions & 0 deletions
@@ -0,0 +1,74 @@
# coding=utf-8
# Copyright 2024 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import gc
import logging
import os
import unittest

import pytest
import torchao
from executorch.extension.pybindings.portable_lib import ExecuTorchModule
from packaging.version import parse
from transformers import AutoConfig, AutoTokenizer
from transformers.testing_utils import slow

from optimum.executorch import ExecuTorchModelForCausalLM

from ..utils import check_causal_lm_output_quality


os.environ["TOKENIZERS_PARALLELISM"] = "false"


class ExecuTorchModelIntegrationTest(unittest.TestCase):
    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)

    @slow
    @pytest.mark.run_slow
    @pytest.mark.skipif(
        parse(torchao.__version__) < parse("0.11.0"),
        reason="Quantization is only available on torchao >= 0.11.0.",
    )
    def test_codegen_text_generation_with_8da4w_8we(self):
        model_id = "Salesforce/codegen-350M-mono"
        prompt = "def hello_world():"
        tokenizer = AutoTokenizer.from_pretrained(model_id)
        config = AutoConfig.from_pretrained(model_id)
        config.bos_token_id = tokenizer.bos_token_id
        config.eos_token_id = tokenizer.eos_token_id
        model = ExecuTorchModelForCausalLM.from_pretrained(
            model_id,
            config=config,
            recipe="xnnpack",
            **{"qlinear": True, "qembedding": True},
        )
        self.assertIsInstance(model, ExecuTorchModelForCausalLM)
        self.assertIsInstance(model.model, ExecuTorchModule)
        generated_text = model.text_generation(
            tokenizer=tokenizer,
            prompt=prompt,
            max_seq_len=64,
        )
        logging.info(f"\nGenerated text:\n\t{generated_text}")
        generated_tokens = tokenizer(generated_text, return_tensors="pt").input_ids

        # Free memory before loading eager for quality check
        del model
        del tokenizer
        gc.collect()

        self.assertTrue(check_causal_lm_output_quality(model_id, generated_tokens))
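The `check_causal_lm_output_quality` helper comes from the shared test utilities and is not part of this diff. As a rough mental model only (an assumption, not the repo's actual implementation), such a check can be approximated by scoring the exported model's generation with the eager reference model and thresholding the resulting perplexity:

```python
# Illustrative sketch (assumption): a perplexity-style quality check similar in
# spirit to check_causal_lm_output_quality, not the repo's actual helper.
import torch
from transformers import AutoModelForCausalLM


def rough_output_quality(model_id: str, generated_tokens: torch.Tensor, max_perplexity: float = 100.0) -> bool:
    # Score the generated tokens with the eager reference model.
    reference = AutoModelForCausalLM.from_pretrained(model_id)
    reference.eval()
    with torch.no_grad():
        out = reference(generated_tokens, labels=generated_tokens)
    # Low perplexity under the reference model suggests coherent output.
    return torch.exp(out.loss).item() < max_perplexity
```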

tests/models/test_modeling_glm.py

Lines changed: 72 additions & 0 deletions
@@ -0,0 +1,72 @@
# coding=utf-8
# Copyright 2024 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import gc
import logging
import os
import unittest

import pytest
import torchao
from executorch.extension.pybindings.portable_lib import ExecuTorchModule
from packaging.version import parse
from transformers import AutoTokenizer
from transformers.testing_utils import slow

from optimum.executorch import ExecuTorchModelForCausalLM

from ..utils import check_causal_lm_output_quality


os.environ["TOKENIZERS_PARALLELISM"] = "false"


class ExecuTorchModelIntegrationTest(unittest.TestCase):
    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)

    @slow
    @pytest.mark.run_slow
    @pytest.mark.skipif(
        parse(torchao.__version__) < parse("0.11.0"),
        reason="Quantization is only available on torchao >= 0.11.0.",
    )
    def test_glm_text_generation_with_custom_sdpa_and_kv_cache_8da4w_8we(self):
        model_id = "THUDM/glm-edge-1.5b-chat"
        prompt = "hello!"
        tokenizer = AutoTokenizer.from_pretrained(model_id)
        model = ExecuTorchModelForCausalLM.from_pretrained(
            model_id,
            recipe="xnnpack",
            attn_implementation="custom_sdpa",
            use_custom_kv_cache=True,
            **{"qlinear": True, "qembedding": True},
        )
        self.assertIsInstance(model, ExecuTorchModelForCausalLM)
        self.assertIsInstance(model.model, ExecuTorchModule)
        generated_text = model.text_generation(
            tokenizer=tokenizer,
            prompt=prompt,
            max_seq_len=64,
        )
        logging.info(f"\nGenerated text:\n\t{generated_text}")
        generated_tokens = tokenizer(generated_text, return_tensors="pt").input_ids

        # Free memory before loading eager for quality check
        del model
        del tokenizer
        gc.collect()

        self.assertTrue(check_causal_lm_output_quality(model_id, generated_tokens))

tests/models/test_modeling_gpt2.py

Lines changed: 72 additions & 0 deletions
@@ -0,0 +1,72 @@
# coding=utf-8
# Copyright 2024 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import gc
import logging
import os
import unittest

import pytest
import torchao
from executorch.extension.pybindings.portable_lib import ExecuTorchModule
from packaging.version import parse
from transformers import AutoTokenizer
from transformers.testing_utils import slow

from optimum.executorch import ExecuTorchModelForCausalLM

from ..utils import check_causal_lm_output_quality


os.environ["TOKENIZERS_PARALLELISM"] = "false"


class ExecuTorchModelIntegrationTest(unittest.TestCase):
    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)

    @slow
    @pytest.mark.run_slow
    @pytest.mark.skipif(
        parse(torchao.__version__) < parse("0.11.0"),
        reason="Quantization is only available on torchao >= 0.11.0.",
    )
    def test_gpt2sw3_text_generation_with_custom_sdpa_and_kv_cache_8da4w_8we(self):
        model_id = "AI-Sweden-Models/gpt-sw3-126m"
        prompt = "Träd är fina för att"
        tokenizer = AutoTokenizer.from_pretrained(model_id)
        model = ExecuTorchModelForCausalLM.from_pretrained(
            model_id,
            recipe="xnnpack",
            attn_implementation="custom_sdpa",
            use_custom_kv_cache=True,
            **{"qlinear": True, "qembedding": True},
        )
        self.assertIsInstance(model, ExecuTorchModelForCausalLM)
        self.assertIsInstance(model.model, ExecuTorchModule)
        generated_text = model.text_generation(
            tokenizer=tokenizer,
            prompt=prompt,
            max_seq_len=64,
        )
        logging.info(f"\nGenerated text:\n\t{generated_text}")
        generated_tokens = tokenizer(generated_text, return_tensors="pt").input_ids

        # Free memory before loading eager for quality check
        del model
        del tokenizer
        gc.collect()

        self.assertTrue(check_causal_lm_output_quality(model_id, generated_tokens))

tests/models/test_modeling_gptj.py

Lines changed: 74 additions & 0 deletions
@@ -0,0 +1,74 @@
# coding=utf-8
# Copyright 2024 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import gc
import logging
import os
import unittest

import pytest
import torchao
from executorch.extension.pybindings.portable_lib import ExecuTorchModule
from packaging.version import parse
from transformers import AutoConfig, AutoTokenizer
from transformers.testing_utils import slow

from optimum.executorch import ExecuTorchModelForCausalLM

from ..utils import check_causal_lm_output_quality


os.environ["TOKENIZERS_PARALLELISM"] = "false"


class ExecuTorchModelIntegrationTest(unittest.TestCase):
    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)

    @slow
    @pytest.mark.run_slow
    @pytest.mark.skipif(
        parse(torchao.__version__) < parse("0.11.0"),
        reason="Quantization is only available on torchao >= 0.11.0.",
    )
    def test_gptj_text_generation_with_8da4w_8we(self):
        model_id = "Milos/slovak-gpt-j-405M"
        prompt = "Tradičné jedlo na Orave sú"
        tokenizer = AutoTokenizer.from_pretrained(model_id)
        config = AutoConfig.from_pretrained(model_id)
        config.bos_token_id = tokenizer.bos_token_id
        config.eos_token_id = tokenizer.eos_token_id
        model = ExecuTorchModelForCausalLM.from_pretrained(
            model_id,
            config=config,
            recipe="xnnpack",
            **{"qlinear": True, "qembedding": True},
        )
        self.assertIsInstance(model, ExecuTorchModelForCausalLM)
        self.assertIsInstance(model.model, ExecuTorchModule)
        generated_text = model.text_generation(
            tokenizer=tokenizer,
            prompt=prompt,
            max_seq_len=64,
        )
        logging.info(f"\nGenerated text:\n\t{generated_text}")
        generated_tokens = tokenizer(generated_text, return_tensors="pt").input_ids

        # Free memory before loading eager for quality check
        del model
        del tokenizer
        gc.collect()

        self.assertTrue(check_causal_lm_output_quality(model_id, generated_tokens))
