Commit 308d53a

missed out model hf fix
1 parent: 20ddca1 · commit: 308d53a

File tree

1 file changed, +8 −8 lines changed

tests/torch/models_hub_test/test_hf_transformers.py

Lines changed: 8 additions & 8 deletions
@@ -181,7 +181,7 @@ def forward(self, x):
 
             tokenizer = AutoTokenizer.from_pretrained(name)
             config = AutoConfig.from_pretrained(name, torchscript=True)
-            model = AutoModelForCausalLM.from_config(config)
+            model = AutoModelForCausalLM.from_config(config, attn_implementation="eager")
             text = "Replace me by any text you'd like."
             encoded_input = tokenizer(text, return_tensors="pt")
             inputs_dict = dict(encoded_input)
@@ -199,7 +199,7 @@ def forward(self, x):
 
             tokenizer = AutoTokenizer.from_pretrained(name)
             config = AutoConfig.from_pretrained(name, torchscript=True)
-            model = AutoModelForMaskedLM.from_config(config)
+            model = AutoModelForMaskedLM.from_config(config, attn_implementation="eager")
             text = "Replace me by any text you'd like."
             encoded_input = tokenizer(text, return_tensors="pt")
             example = dict(encoded_input)
@@ -209,7 +209,7 @@ def forward(self, x):
 
             processor = AutoProcessor.from_pretrained(name)
             config = AutoConfig.from_pretrained(name, torchscript=True)
-            model = AutoModelForImageClassification.from_config(config)
+            model = AutoModelForImageClassification.from_config(config, attn_implementation="eager")
             encoded_input = processor(images=self.image, return_tensors="pt")
             example = dict(encoded_input)
         elif auto_model == "AutoModelForSeq2SeqLM":
@@ -218,7 +218,7 @@ def forward(self, x):
 
             tokenizer = AutoTokenizer.from_pretrained(name)
             config = AutoConfig.from_pretrained(name, torchscript=True)
-            model = AutoModelForSeq2SeqLM.from_config(config)
+            model = AutoModelForSeq2SeqLM.from_config(config, attn_implementation="eager")
             inputs = tokenizer("Studies have been shown that owning a dog is good for you", return_tensors="pt")
             decoder_inputs = tokenizer(
                 "<pad> Studien haben gezeigt dass es hilfreich ist einen Hund zu besitzen",
@@ -232,7 +232,7 @@ def forward(self, x):
 
             processor = AutoProcessor.from_pretrained(name)
             config = AutoConfig.from_pretrained(name, torchscript=True)
-            model = AutoModelForSpeechSeq2Seq.from_config(config)
+            model = AutoModelForSpeechSeq2Seq.from_config(config, attn_implementation="eager")
             inputs = processor(torch.randn(1000).numpy(), sampling_rate=16000, return_tensors="pt")
             example = dict(inputs)
         elif auto_model == "AutoModelForCTC":
@@ -241,7 +241,7 @@ def forward(self, x):
 
             processor = AutoProcessor.from_pretrained(name)
             config = AutoConfig.from_pretrained(name, torchscript=True)
-            model = AutoModelForCTC.from_config(config)
+            model = AutoModelForCTC.from_config(config, attn_implementation="eager")
             input_values = processor(torch.randn(1000).numpy(), return_tensors="pt")
             example = dict(input_values)
         elif auto_model == "AutoModelForTableQuestionAnswering":
@@ -251,7 +251,7 @@ def forward(self, x):
 
             tokenizer = AutoTokenizer.from_pretrained(name)
             config = AutoConfig.from_pretrained(name, torchscript=True)
-            model = AutoModelForTableQuestionAnswering.from_config(config)
+            model = AutoModelForTableQuestionAnswering.from_config(config, attn_implementation="eager")
             data = {
                 "Actors": ["Brad Pitt", "Leonardo Di Caprio", "George Clooney"],
                 "Number of movies": ["87", "53", "69"],
@@ -304,7 +304,7 @@ def forward(self, x):
             from transformers import AutoModel
 
             config = AutoConfig.from_pretrained(name, torchscript=True)
-            model = AutoModel.from_config(config)
+            model = AutoModel.from_config(config, attn_implementation="eager")
             if hasattr(model, "set_default_language"):
                 model.set_default_language("en_XX")
             if example is None:
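
Every hunk applies the same one-line fix: pass attn_implementation="eager" to from_config, so the randomly initialized model uses the classic attention path instead of the SDPA implementation that newer transformers releases select by default, presumably to keep tracing of torchscript=True configs working. A minimal sketch of the pattern outside the test harness, assuming a small checkpoint name chosen purely for illustration (it is not one of the models exercised by this test):

import torch
from transformers import AutoConfig, AutoModelForCausalLM, AutoTokenizer

# Hypothetical tiny checkpoint, used here only to keep the sketch lightweight.
name = "hf-internal-testing/tiny-random-gpt2"

tokenizer = AutoTokenizer.from_pretrained(name)
config = AutoConfig.from_pretrained(name, torchscript=True)
# Same call shape as in the diff: build from config (random weights) with eager attention.
model = AutoModelForCausalLM.from_config(config, attn_implementation="eager")
model.eval()

encoded_input = tokenizer("Replace me by any text you'd like.", return_tensors="pt")
inputs_dict = dict(encoded_input)

# torchscript=True makes the model return tuples rather than ModelOutput objects,
# so it can be traced directly over the keyword inputs.
traced = torch.jit.trace(model, example_kwarg_inputs=inputs_dict, strict=False)

The masked-LM, image-classification, speech, CTC, table-QA, and plain AutoModel branches in the diff follow the same shape, differing only in the Auto class and the example inputs they prepare.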
