Skip to content

Commit 4808e0e

Browse files
committed
Handle kwargs in Cache
1 parent ac39608 commit 4808e0e

File tree

1 file changed

+19
-8
lines changed

1 file changed

+19
-8
lines changed

async_substrate_interface/utils/cache.py

Lines changed: 19 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,13 @@
11
import asyncio
2+
import inspect
23
from collections import OrderedDict
34
import functools
45
import logging
56
import os
67
import pickle
78
import sqlite3
89
from pathlib import Path
9-
from typing import Callable, Any, Awaitable, Hashable
10-
10+
from typing import Callable, Any, Awaitable, Hashable, Optional
1111

1212
USE_CACHE = True if os.getenv("NO_CACHE") != "1" else False
1313
CACHE_LOCATION = (
@@ -208,22 +208,33 @@ def __init__(
208208
self,
209209
max_size: int,
210210
method: Callable[..., Awaitable[Any]],
211-
cache_key_index: int = 0,
211+
cache_key_index: Optional[int] = 0,
212212
):
213213
"""
214214
Args:
215215
max_size: max size of the cache (in items)
216216
method: the function to cache
217-
cache_key_index: if the method takes multiple args, only one will be used as the cache key. This is the
218-
index of that cache key in the args list (default is the first arg)
217+
cache_key_index: if the method takes multiple args, this is the index of that cache key in the args list
218+
(default is the first arg). By setting this to `None`, it will use all args as the cache key.
219219
"""
220220
self._inflight: dict[Hashable, asyncio.Future] = {}
221221
self._method = method
222222
self._cache = LRUCache(max_size=max_size)
223223
self._cache_key_index = cache_key_index
224224

225-
async def __call__(self, *args: Any) -> Any:
226-
key = args[self._cache_key_index]
225+
def make_cache_key(self, args: tuple, kwargs: dict) -> Hashable:
226+
bound = inspect.signature(self._method).bind(*args, **kwargs)
227+
bound.apply_defaults()
228+
229+
if self._cache_key_index is not None:
230+
key_name = list(bound.arguments)[self._cache_key_index]
231+
return bound.arguments[key_name]
232+
233+
return (tuple(bound.arguments.items()),)
234+
235+
async def __call__(self, *args: Any, **kwargs: Any) -> Any:
236+
key = self.make_cache_key(args, kwargs)
237+
227238
if item := self._cache.get(key):
228239
return item
229240

@@ -235,7 +246,7 @@ async def __call__(self, *args: Any) -> Any:
235246
self._inflight[key] = future
236247

237248
try:
238-
result = await self._method(*args)
249+
result = await self._method(*args, **kwargs)
239250
self._cache.set(key, result)
240251
future.set_result(result)
241252
return result

0 commit comments

Comments
 (0)