From 78f01db615b3371be259ada37285c0e3b613e34e Mon Sep 17 00:00:00 2001
From: Cedar
Date: Mon, 2 Dec 2024 14:27:08 -0800
Subject: [PATCH] organize test cases

---
NOTE(review): line breaks and indentation were reconstructed from a
whitespace-collapsed copy of this patch. The reconstructed hunk has 68 removed,
70 added and 25 context lines, matching the hunk header (93 old / 95 new) and
the diffstat. The leading indentation of the first context line,
`print_node(cache.root)`, is assumed to be 8 spaces -- confirm against the
pre-image blob ce0025419 before applying.
NOTE(review): the "description" keys are not read by either test in this hunk;
consider passing them to `pytest.mark.parametrize(..., ids=...)` in a follow-up.

 .../kvcache/trie_attention_cache_test.py      | 138 +++++++++---------
 1 file changed, 70 insertions(+), 68 deletions(-)

diff --git a/shortfin/tests/apps/llm/components/kvcache/trie_attention_cache_test.py b/shortfin/tests/apps/llm/components/kvcache/trie_attention_cache_test.py
index ce0025419..0f49efda8 100644
--- a/shortfin/tests/apps/llm/components/kvcache/trie_attention_cache_test.py
+++ b/shortfin/tests/apps/llm/components/kvcache/trie_attention_cache_test.py
@@ -143,93 +143,95 @@ def print_node(node, depth=0):
         print_node(cache.root)
 
 
-# Test sequences for parameterization
 basic_sequences = [
-    TokenSequence(tokens=[], description="empty_sequence", expected_pages=0),
-    TokenSequence(
-        tokens=list(range(TEST_PAGE_SIZE // 2)),
-        description="partial_page",
-        expected_pages=1,
-    ),
-    TokenSequence(
-        tokens=list(range(TEST_PAGE_SIZE)), description="exact_page", expected_pages=1
-    ),
-    TokenSequence(
-        tokens=list(range(TEST_PAGE_SIZE + 1)),
-        description="overflow_page",
-        expected_pages=2,
-    ),
-    TokenSequence(
-        tokens=list(range(TEST_PAGE_SIZE * 2)),
-        description="multiple_pages",
-        expected_pages=2,
-    ),
-]
-
-reuse_sequences = [
-    (list(range(TEST_PAGE_SIZE)), list(range(TEST_PAGE_SIZE)), "exact_match", 1, 1),
-    (
-        list(range(TEST_PAGE_SIZE * 2)),
-        list(range(TEST_PAGE_SIZE * 2)),
-        "multi_page_match",
-        2,
-        2,
-    ),
-    (
-        list(range(TEST_PAGE_SIZE * 2)),
-        list(range(TEST_PAGE_SIZE)) + list(range(100, 100 + TEST_PAGE_SIZE)),
-        "prefix_match",
-        2,
-        1,
-    ),
-    (
-        list(range(TEST_PAGE_SIZE)),
-        list(range(50, 50 + TEST_PAGE_SIZE)),
-        "no_match",
-        1,
-        0,
-    ),
+    {"tokens": [], "description": "empty_sequence", "expected_pages": 0},
+    {
+        "tokens": list(range(TEST_PAGE_SIZE // 2)),
+        "description": "partial_page",
+        "expected_pages": 1,
+    },
+    {
+        "tokens": list(range(TEST_PAGE_SIZE)),
+        "description": "exact_page",
+        "expected_pages": 1,
+    },
+    {
+        "tokens": list(range(TEST_PAGE_SIZE + 1)),
+        "description": "overflow_page",
+        "expected_pages": 2,
+    },
+    {
+        "tokens": list(range(TEST_PAGE_SIZE * 2)),
+        "description": "multiple_pages",
+        "expected_pages": 2,
+    },
 ]
 
 
-@pytest.mark.parametrize("seq", basic_sequences)
-def test_basic_allocation(trie_cache, seq):
+@pytest.mark.parametrize("test_sequence", basic_sequences)
+def test_basic_allocation(trie_cache, test_sequence):
     """Test basic page allocation without reuse"""
-    allocation = trie_cache.acquire_pages_for_tokens(seq.tokens, extra_token_slots=0)
-    assert len(allocation.pages) == seq.expected_pages
+    allocation = trie_cache.acquire_pages_for_tokens(
+        test_sequence["tokens"], extra_token_slots=0
+    )
+    assert len(allocation.pages) == test_sequence["expected_pages"]
     assert allocation.number_of_published_pages == 0
     assert (
         len(allocation.pages) - allocation.number_of_published_pages
-        == seq.expected_pages
+        == test_sequence["expected_pages"]
     )
     allocation.publish_pages_for_tokens(allocation.tokens)
     allocation.release_pages()
 
 
-@pytest.mark.parametrize(
-    "initial_tokens,reuse_tokens,description,total_pages,expected_cached",
-    reuse_sequences,
-)
-def test_page_reuse(
-    trie_cache,
-    published_sequence,
-    initial_tokens,
-    reuse_tokens,
-    description,
-    total_pages,
-    expected_cached,
-):
+reuse_sequences = [
+    {
+        "initial_tokens": list(range(TEST_PAGE_SIZE)),
+        "reuse_tokens": list(range(TEST_PAGE_SIZE)),
+        "description": "exact_match",
+        "total_pages": 1,
+        "expected_cached": 1,
+    },
+    {
+        "initial_tokens": list(range(TEST_PAGE_SIZE * 2)),
+        "reuse_tokens": list(range(TEST_PAGE_SIZE * 2)),
+        "description": "multi_page_match",
+        "total_pages": 2,
+        "expected_cached": 2,
+    },
+    {
+        "initial_tokens": list(range(TEST_PAGE_SIZE * 2)),
+        "reuse_tokens": list(range(TEST_PAGE_SIZE))
+        + list(range(100, 100 + TEST_PAGE_SIZE)),
+        "description": "prefix_match",
+        "total_pages": 2,
+        "expected_cached": 1,
+    },
+    {
+        "initial_tokens": list(range(TEST_PAGE_SIZE)),
+        "reuse_tokens": list(range(50, 50 + TEST_PAGE_SIZE)),
+        "description": "no_match",
+        "total_pages": 1,
+        "expected_cached": 0,
+    },
+]
+
+
+@pytest.mark.parametrize("test_sequences", reuse_sequences)
+def test_page_reuse(trie_cache, published_sequence, test_sequences):
     """Test page reuse scenarios"""
     # Publish initial sequence
-    published_sequence(initial_tokens)
+    published_sequence(test_sequences["initial_tokens"])
 
     # Try to reuse
-    allocation = trie_cache.acquire_pages_for_tokens(reuse_tokens, extra_token_slots=0)
-    assert len(allocation.pages) == total_pages
-    assert allocation.number_of_published_pages == expected_cached
+    allocation = trie_cache.acquire_pages_for_tokens(
+        test_sequences["reuse_tokens"], extra_token_slots=0
+    )
+    assert len(allocation.pages) == test_sequences["total_pages"]
+    assert allocation.number_of_published_pages == test_sequences["expected_cached"]
     assert (
         len(allocation.pages) - allocation.number_of_published_pages
-        == total_pages - expected_cached
+        == test_sequences["total_pages"] - test_sequences["expected_cached"]
     )
     allocation.publish_pages_for_tokens(allocation.tokens)
     allocation.release_pages()