17
17
from typing import Any , Dict
18
18
19
19
import numpy as np
20
+ from huggingface_hub .constants import HF_HUB_CACHE
20
21
from PIL import Image
21
22
from transformers import AutoTokenizer
22
23
from transformers .pipelines import Pipeline
23
24
24
25
from optimum .pipelines import pipeline
26
+ from optimum .utils .testing_utils import remove_directory
27
+
28
+
29
+ GENERATE_KWARGS = {"max_new_tokens" : 10 , "min_new_tokens" : 5 , "do_sample" : True }
25
30
26
31
27
32
class ORTPipelineTest (unittest .TestCase ):
@@ -33,13 +38,11 @@ def _create_dummy_text(self) -> str:
33
38
34
39
def _create_dummy_image (self ) -> Image .Image :
35
40
"""Create dummy image input for image-based tasks"""
36
- # Create a simple RGB image
37
41
np_image = np .random .randint (0 , 256 , (224 , 224 , 3 ), dtype = np .uint8 )
38
42
return Image .fromarray (np_image )
39
43
40
44
def _create_dummy_audio (self ) -> Dict [str , Any ]:
41
45
"""Create dummy audio input for audio-based tasks"""
42
- # Create a dummy audio array (16kHz sample rate, 1 second)
43
46
sample_rate = 16000
44
47
audio_array = np .random .randn (sample_rate ).astype (np .float32 )
45
48
return {"array" : audio_array , "sampling_rate" : sample_rate }
@@ -111,7 +114,7 @@ def test_text_generation_pipeline(self):
111
114
pipe = pipeline (task = "text-generation" , accelerator = "ort" )
112
115
self .assertIsInstance (pipe , Pipeline )
113
116
text = "The future of AI is"
114
- result = pipe (text , max_new_tokens = 50 , do_sample = False )
117
+ result = pipe (text , ** GENERATE_KWARGS )
115
118
116
119
self .assertIsInstance (result , list )
117
120
self .assertGreater (len (result ), 0 )
@@ -123,7 +126,7 @@ def test_summarization_pipeline(self):
123
126
pipe = pipeline (task = "summarization" , accelerator = "ort" )
124
127
self .assertIsInstance (pipe , Pipeline )
125
128
text = "The quick brown fox jumps over the lazy dog."
126
- result = pipe (text , max_new_tokens = 50 , min_new_tokens = 10 , do_sample = False )
129
+ result = pipe (text , ** GENERATE_KWARGS )
127
130
128
131
self .assertIsInstance (result , list )
129
132
self .assertGreater (len (result ), 0 )
@@ -134,7 +137,7 @@ def test_translation_pipeline(self):
134
137
pipe = pipeline (task = "translation_en_to_de" , accelerator = "ort" )
135
138
self .assertIsInstance (pipe , Pipeline )
136
139
text = "Hello, how are you?"
137
- result = pipe (text , max_new_tokens = 50 )
140
+ result = pipe (text , ** GENERATE_KWARGS )
138
141
139
142
self .assertIsInstance (result , list )
140
143
self .assertGreater (len (result ), 0 )
@@ -145,7 +148,7 @@ def test_text2text_generation_pipeline(self):
145
148
pipe = pipeline (task = "text2text-generation" , accelerator = "ort" )
146
149
self .assertIsInstance (pipe , Pipeline )
147
150
text = "translate English to German: Hello, how are you?"
148
- result = pipe (text , max_new_tokens = 50 )
151
+ result = pipe (text , ** GENERATE_KWARGS )
149
152
150
153
self .assertIsInstance (result , list )
151
154
self .assertGreater (len (result ), 0 )
@@ -194,7 +197,7 @@ def test_image_to_text_pipeline(self):
194
197
pipe = pipeline (task = "image-to-text" , accelerator = "ort" )
195
198
self .assertIsInstance (pipe , Pipeline )
196
199
image = self ._create_dummy_image ()
197
- result = pipe (image )
200
+ result = pipe (image , generate_kwargs = GENERATE_KWARGS )
198
201
199
202
self .assertIsInstance (result , list )
200
203
self .assertGreater (len (result ), 0 )
@@ -214,7 +217,7 @@ def test_image_to_image_pipeline(self):
214
217
# """Test automatic speech recognition ORT pipeline"""
215
218
# pipe = pipeline(task="automatic-speech-recognition", accelerator="ort")
216
219
# audio = self._create_dummy_audio()
217
- # result = pipe(audio)
220
+ # result = pipe(audio, generate_kwargs=GENERATE_KWARGS )
218
221
219
222
# self.assertIsInstance(result, dict)
220
223
# self.assertIn("text", result)
@@ -246,7 +249,7 @@ def test_pipeline_with_ort_model(self):
246
249
self .assertIsInstance (result [0 ], list )
247
250
self .assertIsInstance (result [0 ][0 ], list )
248
251
249
- def test_pipeline_with_custom_model_id (self ):
252
+ def test_pipeline_with_model_id (self ):
250
253
"""Test ORT pipeline with a custom model id"""
251
254
pipe = pipeline (task = "feature-extraction" , model = "distilbert-base-cased" , accelerator = "ort" )
252
255
self .assertIsInstance (pipe , Pipeline )
@@ -265,9 +268,12 @@ def test_pipeline_with_invalid_task(self):
265
268
def test_pipeline_with_invalid_accelerator (self ):
266
269
"""Test ORT pipeline with an unsupported accelerator"""
267
270
with self .assertRaises (ValueError ) as context :
268
- _ = pipeline (task = "text-classification " , accelerator = "invalid-accelerator" )
271
+ _ = pipeline (task = "feature-extraction " , accelerator = "invalid-accelerator" )
269
272
self .assertIn ("Accelerator invalid-accelerator not recognized" , str (context .exception ))
270
273
274
+ def tearDown (self ):
275
+ remove_directory (HF_HUB_CACHE )
276
+
271
277
272
278
if __name__ == "__main__" :
273
279
unittest .main ()
0 commit comments