Skip to content

Commit 78f01db

Browse files
committed
organize test cases
1 parent 481aeff commit 78f01db

File tree

1 file changed

+70
-68
lines changed

1 file changed

+70
-68
lines changed

shortfin/tests/apps/llm/components/kvcache/trie_attention_cache_test.py

Lines changed: 70 additions & 68 deletions
Original file line number | Diff line number | Diff line change
@@ -143,93 +143,95 @@ def print_node(node, depth=0):
143143
print_node(cache.root)
144144

145145

146-
# Test sequences for parameterization
147146
basic_sequences = [
148-
TokenSequence(tokens=[], description="empty_sequence", expected_pages=0),
149-
TokenSequence(
150-
tokens=list(range(TEST_PAGE_SIZE // 2)),
151-
description="partial_page",
152-
expected_pages=1,
153-
),
154-
TokenSequence(
155-
tokens=list(range(TEST_PAGE_SIZE)), description="exact_page", expected_pages=1
156-
),
157-
TokenSequence(
158-
tokens=list(range(TEST_PAGE_SIZE + 1)),
159-
description="overflow_page",
160-
expected_pages=2,
161-
),
162-
TokenSequence(
163-
tokens=list(range(TEST_PAGE_SIZE * 2)),
164-
description="multiple_pages",
165-
expected_pages=2,
166-
),
167-
]
168-
169-
reuse_sequences = [
170-
(list(range(TEST_PAGE_SIZE)), list(range(TEST_PAGE_SIZE)), "exact_match", 1, 1),
171-
(
172-
list(range(TEST_PAGE_SIZE * 2)),
173-
list(range(TEST_PAGE_SIZE * 2)),
174-
"multi_page_match",
175-
2,
176-
2,
177-
),
178-
(
179-
list(range(TEST_PAGE_SIZE * 2)),
180-
list(range(TEST_PAGE_SIZE)) + list(range(100, 100 + TEST_PAGE_SIZE)),
181-
"prefix_match",
182-
2,
183-
1,
184-
),
185-
(
186-
list(range(TEST_PAGE_SIZE)),
187-
list(range(50, 50 + TEST_PAGE_SIZE)),
188-
"no_match",
189-
1,
190-
0,
191-
),
147+
{"tokens": [], "description": "empty_sequence", "expected_pages": 0},
148+
{
149+
"tokens": list(range(TEST_PAGE_SIZE // 2)),
150+
"description": "partial_page",
151+
"expected_pages": 1,
152+
},
153+
{
154+
"tokens": list(range(TEST_PAGE_SIZE)),
155+
"description": "exact_page",
156+
"expected_pages": 1,
157+
},
158+
{
159+
"tokens": list(range(TEST_PAGE_SIZE + 1)),
160+
"description": "overflow_page",
161+
"expected_pages": 2,
162+
},
163+
{
164+
"tokens": list(range(TEST_PAGE_SIZE * 2)),
165+
"description": "multiple_pages",
166+
"expected_pages": 2,
167+
},
192168
]
193169

194170

195-
@pytest.mark.parametrize("seq", basic_sequences)
196-
def test_basic_allocation(trie_cache, seq):
171+
@pytest.mark.parametrize("test_sequence", basic_sequences)
172+
def test_basic_allocation(trie_cache, test_sequence):
197173
"""Test basic page allocation without reuse"""
198-
allocation = trie_cache.acquire_pages_for_tokens(seq.tokens, extra_token_slots=0)
199-
assert len(allocation.pages) == seq.expected_pages
174+
allocation = trie_cache.acquire_pages_for_tokens(
175+
test_sequence["tokens"], extra_token_slots=0
176+
)
177+
assert len(allocation.pages) == test_sequence["expected_pages"]
200178
assert allocation.number_of_published_pages == 0
201179
assert (
202180
len(allocation.pages) - allocation.number_of_published_pages
203-
== seq.expected_pages
181+
== test_sequence["expected_pages"]
204182
)
205183
allocation.publish_pages_for_tokens(allocation.tokens)
206184
allocation.release_pages()
207185

208186

209-
@pytest.mark.parametrize(
210-
"initial_tokens,reuse_tokens,description,total_pages,expected_cached",
211-
reuse_sequences,
212-
)
213-
def test_page_reuse(
214-
trie_cache,
215-
published_sequence,
216-
initial_tokens,
217-
reuse_tokens,
218-
description,
219-
total_pages,
220-
expected_cached,
221-
):
187+
reuse_sequences = [
188+
{
189+
"initial_tokens": list(range(TEST_PAGE_SIZE)),
190+
"reuse_tokens": list(range(TEST_PAGE_SIZE)),
191+
"description": "exact_match",
192+
"total_pages": 1,
193+
"expected_cached": 1,
194+
},
195+
{
196+
"initial_tokens": list(range(TEST_PAGE_SIZE * 2)),
197+
"reuse_tokens": list(range(TEST_PAGE_SIZE * 2)),
198+
"description": "multi_page_match",
199+
"total_pages": 2,
200+
"expected_cached": 2,
201+
},
202+
{
203+
"initial_tokens": list(range(TEST_PAGE_SIZE * 2)),
204+
"reuse_tokens": list(range(TEST_PAGE_SIZE))
205+
+ list(range(100, 100 + TEST_PAGE_SIZE)),
206+
"description": "prefix_match",
207+
"total_pages": 2,
208+
"expected_cached": 1,
209+
},
210+
{
211+
"initial_tokens": list(range(TEST_PAGE_SIZE)),
212+
"reuse_tokens": list(range(50, 50 + TEST_PAGE_SIZE)),
213+
"description": "no_match",
214+
"total_pages": 1,
215+
"expected_cached": 0,
216+
},
217+
]
218+
219+
220+
@pytest.mark.parametrize("test_sequences", reuse_sequences)
221+
def test_page_reuse(trie_cache, published_sequence, test_sequences):
222222
"""Test page reuse scenarios"""
223223
# Publish initial sequence
224-
published_sequence(initial_tokens)
224+
published_sequence(test_sequences["initial_tokens"])
225225

226226
# Try to reuse
227-
allocation = trie_cache.acquire_pages_for_tokens(reuse_tokens, extra_token_slots=0)
228-
assert len(allocation.pages) == total_pages
229-
assert allocation.number_of_published_pages == expected_cached
227+
allocation = trie_cache.acquire_pages_for_tokens(
228+
test_sequences["reuse_tokens"], extra_token_slots=0
229+
)
230+
assert len(allocation.pages) == test_sequences["total_pages"]
231+
assert allocation.number_of_published_pages == test_sequences["expected_cached"]
230232
assert (
231233
len(allocation.pages) - allocation.number_of_published_pages
232-
== total_pages - expected_cached
234+
== test_sequences["total_pages"] - test_sequences["expected_cached"]
233235
)
234236
allocation.publish_pages_for_tokens(allocation.tokens)
235237
allocation.release_pages()

0 commit comments

Comments
 (0)