Skip to content

Commit 8e32704

Browse files
authored
Merge pull request #174 from danmcp/modeladapterunits
Add model adapter unit tests
2 parents f4f8664 + c7b76b5 commit 8e32704

File tree

1 file changed

+66
-0
lines changed

1 file changed

+66
-0
lines changed

tests/test_mt_bench_model_adapter.py

Lines changed: 66 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,66 @@
1+
# SPDX-License-Identifier: Apache-2.0
2+
3+
# Third Party
4+
import pytest
5+
6+
# First Party
7+
from instructlab.eval.mt_bench_model_adapter import (
8+
GraniteAdapter,
9+
MistralAdapter,
10+
get_conversation_template,
11+
get_model_adapter,
12+
)
13+
14+
MISTRAL_DEFAULT_MODEL_NAME = "mistral"
15+
EXAMPLE_MISTRAL_MODEL_PATHS = [
16+
"mistral",
17+
"mistralai/Mixtral-8x7B-Instruct-v0.1",
18+
"/cache/instructlab/models/mistral-7b-instruct-v0.2.Q4_K_M.gguf",
19+
"prometheus-eval/prometheus-8x7b-v2.0",
20+
"/cache/instructlab/models/prometheus-eval/prometheus-8x7b-v2.0",
21+
]
22+
23+
GRANITE_DEFAULT_MODEL_NAME = "granite"
24+
EXAMPLE_GRANITE_MODEL_PATHS = [
25+
"granite",
26+
"instructlab/granite-7b-lab",
27+
"/cache/instructlab/models/instructlab/granite-7b-lab.gguf",
28+
"instructlab/granite-8b-lab",
29+
]
30+
31+
TEST_TUPLES = [
32+
(
33+
MISTRAL_DEFAULT_MODEL_NAME,
34+
EXAMPLE_MISTRAL_MODEL_PATHS,
35+
MistralAdapter,
36+
MISTRAL_DEFAULT_MODEL_NAME,
37+
),
38+
(
39+
GRANITE_DEFAULT_MODEL_NAME,
40+
EXAMPLE_GRANITE_MODEL_PATHS,
41+
GraniteAdapter,
42+
"ibm-generic",
43+
),
44+
]
45+
46+
47+
def test_get_model_adapter():
48+
for model, model_paths, adapter, _ in TEST_TUPLES:
49+
for model_path in model_paths:
50+
assert isinstance(get_model_adapter(model_path, model), adapter)
51+
52+
# Test default adapter overrides as expected
53+
assert isinstance(get_model_adapter("", MISTRAL_DEFAULT_MODEL_NAME), MistralAdapter)
54+
55+
56+
def test_get_model_adapter_not_found():
57+
with pytest.raises(ValueError):
58+
get_model_adapter("unknown", "unknown")
59+
60+
61+
def test_get_conversation_template():
62+
for model, model_paths, _, conv_template_name in TEST_TUPLES:
63+
for model_path in model_paths:
64+
assert (
65+
conv_template_name == get_conversation_template(model_path, model).name
66+
)

0 commit comments

Comments
 (0)