Skip to content

Commit 2d4b558

Browse files
authored
Merge pull request #140 from opentensor/feat/thewhaleking/enhance-caching
Improve CachedFetcher
2 parents a11f37d + 4808e0e commit 2d4b558

File tree

4 files changed

+204
-76
lines changed

4 files changed

+204
-76
lines changed

async_substrate_interface/async_substrate.py

Lines changed: 81 additions & 47 deletions
Original file line numberDiff line numberDiff line change
@@ -58,7 +58,10 @@
5858
get_next_id,
5959
rng as random,
6060
)
61-
from async_substrate_interface.utils.cache import async_sql_lru_cache, CachedFetcher
61+
from async_substrate_interface.utils.cache import (
62+
async_sql_lru_cache,
63+
cached_fetcher,
64+
)
6265
from async_substrate_interface.utils.decoding import (
6366
_determine_if_old_runtime_call,
6467
_bt_decode_to_dict_or_list,
@@ -794,12 +797,6 @@ def __init__(
794797
self.registry_type_map = {}
795798
self.type_id_to_name = {}
796799
self._mock = _mock
797-
self._block_hash_fetcher = CachedFetcher(512, self._get_block_hash)
798-
self._parent_hash_fetcher = CachedFetcher(512, self._get_parent_block_hash)
799-
self._runtime_info_fetcher = CachedFetcher(16, self._get_block_runtime_info)
800-
self._runtime_version_for_fetcher = CachedFetcher(
801-
512, self._get_block_runtime_version_for
802-
)
803800

804801
async def __aenter__(self):
805802
if not self._mock:
@@ -1044,35 +1041,7 @@ async def init_runtime(
10441041
if not runtime:
10451042
self.last_block_hash = block_hash
10461043

1047-
runtime_block_hash = await self.get_parent_block_hash(block_hash)
1048-
1049-
runtime_info = await self.get_block_runtime_info(runtime_block_hash)
1050-
1051-
metadata, (metadata_v15, registry) = await asyncio.gather(
1052-
self.get_block_metadata(block_hash=runtime_block_hash, decode=True),
1053-
self._load_registry_at_block(block_hash=runtime_block_hash),
1054-
)
1055-
if metadata is None:
1056-
# does this ever happen?
1057-
raise SubstrateRequestException(
1058-
f"No metadata for block '{runtime_block_hash}'"
1059-
)
1060-
logger.debug(
1061-
f"Retrieved metadata and metadata v15 for {runtime_version} from Substrate node"
1062-
)
1063-
1064-
runtime = Runtime(
1065-
chain=self.chain,
1066-
runtime_config=self.runtime_config,
1067-
metadata=metadata,
1068-
type_registry=self.type_registry,
1069-
metadata_v15=metadata_v15,
1070-
runtime_info=runtime_info,
1071-
registry=registry,
1072-
)
1073-
self.runtime_cache.add_item(
1074-
runtime_version=runtime_version, runtime=runtime
1075-
)
1044+
runtime = await self.get_runtime_for_version(runtime_version, block_hash)
10761045

10771046
self.load_runtime(runtime)
10781047

@@ -1086,6 +1055,51 @@ async def init_runtime(
10861055
self.ss58_format = ss58_prefix_constant
10871056
return runtime
10881057

1058+
@cached_fetcher(max_size=16, cache_key_index=0)
1059+
async def get_runtime_for_version(
1060+
self, runtime_version: int, block_hash: Optional[str] = None
1061+
) -> Runtime:
1062+
"""
1063+
Retrieves the `Runtime` for a given runtime version at a given block hash.
1064+
Args:
1065+
runtime_version: version of the runtime (from `get_block_runtime_version_for`)
1066+
block_hash: hash of the block to query
1067+
1068+
Returns:
1069+
Runtime object for the given runtime version
1070+
"""
1071+
return await self._get_runtime_for_version(runtime_version, block_hash)
1072+
1073+
async def _get_runtime_for_version(
1074+
self, runtime_version: int, block_hash: Optional[str] = None
1075+
) -> Runtime:
1076+
runtime_block_hash = await self.get_parent_block_hash(block_hash)
1077+
runtime_info, metadata, (metadata_v15, registry) = await asyncio.gather(
1078+
self.get_block_runtime_info(runtime_block_hash),
1079+
self.get_block_metadata(block_hash=runtime_block_hash, decode=True),
1080+
self._load_registry_at_block(block_hash=runtime_block_hash),
1081+
)
1082+
if metadata is None:
1083+
# does this ever happen?
1084+
raise SubstrateRequestException(
1085+
f"No metadata for block '{runtime_block_hash}'"
1086+
)
1087+
logger.debug(
1088+
f"Retrieved metadata and metadata v15 for {runtime_version} from Substrate node"
1089+
)
1090+
1091+
runtime = Runtime(
1092+
chain=self.chain,
1093+
runtime_config=self.runtime_config,
1094+
metadata=metadata,
1095+
type_registry=self.type_registry,
1096+
metadata_v15=metadata_v15,
1097+
runtime_info=runtime_info,
1098+
registry=registry,
1099+
)
1100+
self.runtime_cache.add_item(runtime_version=runtime_version, runtime=runtime)
1101+
return runtime
1102+
10891103
async def create_storage_key(
10901104
self,
10911105
pallet: str,
@@ -1921,10 +1935,19 @@ async def get_metadata(self, block_hash=None) -> MetadataV15:
19211935

19221936
return runtime.metadata_v15
19231937

1924-
async def get_parent_block_hash(self, block_hash):
1925-
return await self._parent_hash_fetcher.execute(block_hash)
1938+
@cached_fetcher(max_size=512)
1939+
async def get_parent_block_hash(self, block_hash) -> str:
1940+
"""
1941+
Retrieves the block hash of the parent of the given block hash
1942+
Args:
1943+
block_hash: hash of the block to query
1944+
1945+
Returns:
1946+
Hash of the parent block, or the original block hash (if it has no parent)
1947+
"""
1948+
return await self._get_parent_block_hash(block_hash)
19261949

1927-
async def _get_parent_block_hash(self, block_hash):
1950+
async def _get_parent_block_hash(self, block_hash) -> str:
19281951
block_header = await self.rpc_request("chain_getHeader", [block_hash])
19291952

19301953
if block_header["result"] is None:
@@ -1967,25 +1990,27 @@ async def get_storage_by_key(self, block_hash: str, storage_key: str) -> Any:
19671990
"Unknown error occurred during retrieval of events"
19681991
)
19691992

1993+
@cached_fetcher(max_size=16)
19701994
async def get_block_runtime_info(self, block_hash: str) -> dict:
1971-
return await self._runtime_info_fetcher.execute(block_hash)
1995+
"""
1996+
Retrieve the runtime info of the given block_hash
1997+
"""
1998+
return await self._get_block_runtime_info(block_hash)
19721999

19732000
get_block_runtime_version = get_block_runtime_info
19742001

19752002
async def _get_block_runtime_info(self, block_hash: str) -> dict:
1976-
"""
1977-
Retrieve the runtime info of given block_hash
1978-
"""
19792003
response = await self.rpc_request("state_getRuntimeVersion", [block_hash])
19802004
return response.get("result")
19812005

2006+
@cached_fetcher(max_size=512)
19822007
async def get_block_runtime_version_for(self, block_hash: str):
1983-
return await self._runtime_version_for_fetcher.execute(block_hash)
1984-
1985-
async def _get_block_runtime_version_for(self, block_hash: str):
19862008
"""
19872009
Retrieve the runtime version of the parent of a given block_hash
19882010
"""
2011+
return await self._get_block_runtime_version_for(block_hash)
2012+
2013+
async def _get_block_runtime_version_for(self, block_hash: str):
19892014
parent_block_hash = await self.get_parent_block_hash(block_hash)
19902015
runtime_info = await self.get_block_runtime_info(parent_block_hash)
19912016
if runtime_info is None:
@@ -2296,8 +2321,17 @@ async def rpc_request(
22962321
else:
22972322
raise SubstrateRequestException(result[payload_id][0])
22982323

2324+
@cached_fetcher(max_size=512)
22992325
async def get_block_hash(self, block_id: int) -> str:
2300-
return await self._block_hash_fetcher.execute(block_id)
2326+
"""
2327+
Retrieves the hash of the specified block number
2328+
Args:
2329+
block_id: block number
2330+
2331+
Returns:
2332+
Hash of the block
2333+
"""
2334+
return await self._get_block_hash(block_id)
23012335

23022336
async def _get_block_hash(self, block_id: int) -> str:
23032337
return (await self.rpc_request("chain_getBlockHash", [block_id]))["result"]

async_substrate_interface/utils/cache.py

Lines changed: 114 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -1,14 +1,13 @@
11
import asyncio
2+
import inspect
23
from collections import OrderedDict
34
import functools
5+
import logging
46
import os
57
import pickle
68
import sqlite3
79
from pathlib import Path
8-
from typing import Callable, Any
9-
10-
import asyncstdlib as a
11-
10+
from typing import Callable, Any, Awaitable, Hashable, Optional
1211

1312
USE_CACHE = True if os.getenv("NO_CACHE") != "1" else False
1413
CACHE_LOCATION = (
@@ -19,6 +18,8 @@
1918
else ":memory:"
2019
)
2120

21+
logger = logging.getLogger("async_substrate_interface")
22+
2223

2324
def _ensure_dir():
2425
path = Path(CACHE_LOCATION).parent
@@ -70,7 +71,7 @@ def _retrieve_from_cache(c, table_name, key, chain):
7071
if result is not None:
7172
return pickle.loads(result[0])
7273
except (pickle.PickleError, sqlite3.Error) as e:
73-
print(f"Cache error: {str(e)}")
74+
logger.exception("Cache error", exc_info=e)
7475
pass
7576

7677

@@ -82,7 +83,7 @@ def _insert_into_cache(c, conn, table_name, key, result, chain):
8283
)
8384
conn.commit()
8485
except (pickle.PickleError, sqlite3.Error) as e:
85-
print(f"Cache error: {str(e)}")
86+
logger.exception("Cache error", exc_info=e)
8687
pass
8788

8889

@@ -128,7 +129,7 @@ def inner(self, *args, **kwargs):
128129

129130
def async_sql_lru_cache(maxsize=None):
130131
def decorator(func):
131-
@a.lru_cache(maxsize=maxsize)
132+
@cached_fetcher(max_size=maxsize)
132133
async def inner(self, *args, **kwargs):
133134
c, conn, table_name, key, result, chain, local_chain = (
134135
_shared_inner_fn_logic(func, self, args, kwargs)
@@ -147,6 +148,10 @@ async def inner(self, *args, **kwargs):
147148

148149

149150
class LRUCache:
151+
"""
152+
Basic Least-Recently-Used Cache, with simple methods `set` and `get`
153+
"""
154+
150155
def __init__(self, max_size: int):
151156
self.max_size = max_size
152157
self.cache = OrderedDict()
@@ -167,31 +172,121 @@ def get(self, key):
167172

168173

169174
class CachedFetcher:
170-
def __init__(self, max_size: int, method: Callable):
171-
self._inflight: dict[int, asyncio.Future] = {}
175+
"""
176+
Async caching class that allows the standard async LRU cache system, but also allows for concurrent
177+
asyncio calls (with the same args) to use the same result of a single call.
178+
179+
This should only be used for asyncio calls where the result is immutable.
180+
181+
Concept and usage:
182+
```
183+
async def fetch(self, block_hash: str) -> str:
184+
return await some_resource(block_hash)
185+
186+
a1, a2, b = await asyncio.gather(fetch("a"), fetch("a"), fetch("b"))
187+
```
188+
189+
Here, you are making three requests, but you really only need to make two I/O requests
190+
(one for "a", one for "b"), and while you wouldn't typically make a request like this directly, it's very
191+
common in using this library to inadvertently make these requests by gathering multiple resources that depend
192+
on the calls like this under the hood.
193+
194+
By using
195+
196+
```
197+
@cached_fetcher(max_size=512)
198+
async def fetch(self, block_hash: str) -> str:
199+
return await some_resource(block_hash)
200+
201+
a1, a2, b = await asyncio.gather(fetch("a"), fetch("a"), fetch("b"))
202+
```
203+
204+
You are only making two I/O calls, and a2 will simply use the result of a1 when it lands.
205+
"""
206+
207+
def __init__(
208+
self,
209+
max_size: int,
210+
method: Callable[..., Awaitable[Any]],
211+
cache_key_index: Optional[int] = 0,
212+
):
213+
"""
214+
Args:
215+
max_size: max size of the cache (in items)
216+
method: the function to cache
217+
cache_key_index: if the method takes multiple args, this is the index of that cache key in the args list
218+
(default is the first arg). By setting this to `None`, it will use all args as the cache key.
219+
"""
220+
self._inflight: dict[Hashable, asyncio.Future] = {}
172221
self._method = method
173222
self._cache = LRUCache(max_size=max_size)
223+
self._cache_key_index = cache_key_index
174224

175-
async def execute(self, single_arg: Any) -> str:
176-
if item := self._cache.get(single_arg):
225+
def make_cache_key(self, args: tuple, kwargs: dict) -> Hashable:
226+
bound = inspect.signature(self._method).bind(*args, **kwargs)
227+
bound.apply_defaults()
228+
229+
if self._cache_key_index is not None:
230+
key_name = list(bound.arguments)[self._cache_key_index]
231+
return bound.arguments[key_name]
232+
233+
return (tuple(bound.arguments.items()),)
234+
235+
async def __call__(self, *args: Any, **kwargs: Any) -> Any:
236+
key = self.make_cache_key(args, kwargs)
237+
238+
if item := self._cache.get(key):
177239
return item
178240

179-
if single_arg in self._inflight:
180-
result = await self._inflight[single_arg]
181-
return result
241+
if key in self._inflight:
242+
return await self._inflight[key]
182243

183244
loop = asyncio.get_running_loop()
184245
future = loop.create_future()
185-
self._inflight[single_arg] = future
246+
self._inflight[key] = future
186247

187248
try:
188-
result = await self._method(single_arg)
189-
self._cache.set(single_arg, result)
249+
result = await self._method(*args, **kwargs)
250+
self._cache.set(key, result)
190251
future.set_result(result)
191252
return result
192253
except Exception as e:
193-
# Propagate errors
194254
future.set_exception(e)
195255
raise
196256
finally:
197-
self._inflight.pop(single_arg, None)
257+
self._inflight.pop(key, None)
258+
259+
260+
class _CachedFetcherMethod:
261+
"""
262+
Helper class for using CachedFetcher with method caches (rather than functions)
263+
"""
264+
265+
def __init__(self, method, max_size: int, cache_key_index: int):
266+
self.method = method
267+
self.max_size = max_size
268+
self.cache_key_index = cache_key_index
269+
self._instances = {}
270+
271+
def __get__(self, instance, owner):
272+
if instance is None:
273+
return self
274+
275+
# Cache per-instance
276+
if instance not in self._instances:
277+
bound_method = self.method.__get__(instance, owner)
278+
self._instances[instance] = CachedFetcher(
279+
max_size=self.max_size,
280+
method=bound_method,
281+
cache_key_index=self.cache_key_index,
282+
)
283+
return self._instances[instance]
284+
285+
286+
def cached_fetcher(max_size: int, cache_key_index: int = 0):
287+
"""Wrapper for CachedFetcher. See example in CachedFetcher docstring."""
288+
289+
def wrapper(method):
290+
return _CachedFetcherMethod(method, max_size, cache_key_index)
291+
292+
return wrapper

pyproject.toml

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,6 @@ keywords = ["substrate", "development", "bittensor"]
88

99
dependencies = [
1010
"wheel",
11-
"asyncstdlib~=3.13.0",
1211
"bt-decode==v0.6.0",
1312
"scalecodec~=1.2.11",
1413
"websockets>=14.1",

0 commit comments

Comments
 (0)