Skip to content

Commit 302dfbf

Browse files
test on multiple
1 parent daabb4e commit 302dfbf

File tree

2 files changed

+22
-15
lines changed

2 files changed

+22
-15
lines changed

.github/workflows/test_pipelines.yml

Lines changed: 6 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -11,15 +11,17 @@ concurrency:
1111
cancel-in-progress: true
1212

1313
env:
14+
UV_SYSTEM_PYTHON: 1
15+
UV_TORCH_BACKEND: auto
1416
TRANSFORMERS_IS_CI: true
1517

1618
jobs:
1719
build:
1820
strategy:
1921
fail-fast: false
2022
matrix:
21-
runs-on: [ubuntu-22.04]
2223
python-version: [3.9]
24+
runs-on: [ubuntu-22.04, windows-2022, macos-14]
2325

2426
runs-on: ${{ matrix.runs-on }}
2527

@@ -34,10 +36,9 @@ jobs:
3436

3537
- name: Install dependencies
3638
run: |
37-
pip install --upgrade pip
38-
pip install --no-cache-dir torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cpu
39-
pip install --no-cache-dir optimum-onnx[onnxruntime]@git+https://github.com/huggingface/optimum-onnx.git
40-
pip install .[tests]
39+
pip install --upgrade pip uv
40+
uv pip install --no-cache-dir optimum-onnx[onnxruntime]@git+https://github.com/huggingface/optimum-onnx.git
41+
uv pip install --no-cache-dir .[tests]
4142
4243
- name: Test with pytest
4344
run: |

tests/pipelines/test_pipelines.py

Lines changed: 16 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -17,11 +17,16 @@
1717
from typing import Any, Dict
1818

1919
import numpy as np
20+
from huggingface_hub.constants import HF_HUB_CACHE
2021
from PIL import Image
2122
from transformers import AutoTokenizer
2223
from transformers.pipelines import Pipeline
2324

2425
from optimum.pipelines import pipeline
26+
from optimum.utils.testing_utils import remove_directory
27+
28+
29+
GENERATE_KWARGS = {"max_new_tokens": 10, "min_new_tokens": 5, "do_sample": True}
2530

2631

2732
class ORTPipelineTest(unittest.TestCase):
@@ -33,13 +38,11 @@ def _create_dummy_text(self) -> str:
3338

3439
def _create_dummy_image(self) -> Image.Image:
3540
"""Create dummy image input for image-based tasks"""
36-
# Create a simple RGB image
3741
np_image = np.random.randint(0, 256, (224, 224, 3), dtype=np.uint8)
3842
return Image.fromarray(np_image)
3943

4044
def _create_dummy_audio(self) -> Dict[str, Any]:
4145
"""Create dummy audio input for audio-based tasks"""
42-
# Create a dummy audio array (16kHz sample rate, 1 second)
4346
sample_rate = 16000
4447
audio_array = np.random.randn(sample_rate).astype(np.float32)
4548
return {"array": audio_array, "sampling_rate": sample_rate}
@@ -111,7 +114,7 @@ def test_text_generation_pipeline(self):
111114
pipe = pipeline(task="text-generation", accelerator="ort")
112115
self.assertIsInstance(pipe, Pipeline)
113116
text = "The future of AI is"
114-
result = pipe(text, max_new_tokens=50, do_sample=False)
117+
result = pipe(text, **GENERATE_KWARGS)
115118

116119
self.assertIsInstance(result, list)
117120
self.assertGreater(len(result), 0)
@@ -123,7 +126,7 @@ def test_summarization_pipeline(self):
123126
pipe = pipeline(task="summarization", accelerator="ort")
124127
self.assertIsInstance(pipe, Pipeline)
125128
text = "The quick brown fox jumps over the lazy dog."
126-
result = pipe(text, max_new_tokens=50, min_new_tokens=10, do_sample=False)
129+
result = pipe(text, **GENERATE_KWARGS)
127130

128131
self.assertIsInstance(result, list)
129132
self.assertGreater(len(result), 0)
@@ -134,7 +137,7 @@ def test_translation_pipeline(self):
134137
pipe = pipeline(task="translation_en_to_de", accelerator="ort")
135138
self.assertIsInstance(pipe, Pipeline)
136139
text = "Hello, how are you?"
137-
result = pipe(text, max_new_tokens=50)
140+
result = pipe(text, **GENERATE_KWARGS)
138141

139142
self.assertIsInstance(result, list)
140143
self.assertGreater(len(result), 0)
@@ -145,7 +148,7 @@ def test_text2text_generation_pipeline(self):
145148
pipe = pipeline(task="text2text-generation", accelerator="ort")
146149
self.assertIsInstance(pipe, Pipeline)
147150
text = "translate English to German: Hello, how are you?"
148-
result = pipe(text, max_new_tokens=50)
151+
result = pipe(text, **GENERATE_KWARGS)
149152

150153
self.assertIsInstance(result, list)
151154
self.assertGreater(len(result), 0)
@@ -194,7 +197,7 @@ def test_image_to_text_pipeline(self):
194197
pipe = pipeline(task="image-to-text", accelerator="ort")
195198
self.assertIsInstance(pipe, Pipeline)
196199
image = self._create_dummy_image()
197-
result = pipe(image)
200+
result = pipe(image, generate_kwargs=GENERATE_KWARGS)
198201

199202
self.assertIsInstance(result, list)
200203
self.assertGreater(len(result), 0)
@@ -214,7 +217,7 @@ def test_image_to_image_pipeline(self):
214217
# """Test automatic speech recognition ORT pipeline"""
215218
# pipe = pipeline(task="automatic-speech-recognition", accelerator="ort")
216219
# audio = self._create_dummy_audio()
217-
# result = pipe(audio)
220+
# result = pipe(audio, generate_kwargs=GENERATE_KWARGS)
218221

219222
# self.assertIsInstance(result, dict)
220223
# self.assertIn("text", result)
@@ -246,7 +249,7 @@ def test_pipeline_with_ort_model(self):
246249
self.assertIsInstance(result[0], list)
247250
self.assertIsInstance(result[0][0], list)
248251

249-
def test_pipeline_with_custom_model_id(self):
252+
def test_pipeline_with_model_id(self):
250253
"""Test ORT pipeline with a custom model id"""
251254
pipe = pipeline(task="feature-extraction", model="distilbert-base-cased", accelerator="ort")
252255
self.assertIsInstance(pipe, Pipeline)
@@ -265,9 +268,12 @@ def test_pipeline_with_invalid_task(self):
265268
def test_pipeline_with_invalid_accelerator(self):
266269
"""Test ORT pipeline with an unsupported accelerator"""
267270
with self.assertRaises(ValueError) as context:
268-
_ = pipeline(task="text-classification", accelerator="invalid-accelerator")
271+
_ = pipeline(task="feature-extraction", accelerator="invalid-accelerator")
269272
self.assertIn("Accelerator invalid-accelerator not recognized", str(context.exception))
270273

274+
def tearDown(self):
275+
remove_directory(HF_HUB_CACHE)
276+
271277

272278
if __name__ == "__main__":
273279
unittest.main()

0 commit comments

Comments
 (0)