Skip to content

Commit 4a3be57

Browse files
committed
Improves CachedFetcher
1 parent a11f37d commit 4a3be57

File tree

4 files changed

+112
-68
lines changed

4 files changed

+112
-68
lines changed

async_substrate_interface/async_substrate.py

Lines changed: 50 additions & 40 deletions
Original file line numberDiff line numberDiff line change
@@ -58,7 +58,11 @@
5858
get_next_id,
5959
rng as random,
6060
)
61-
from async_substrate_interface.utils.cache import async_sql_lru_cache, CachedFetcher
61+
from async_substrate_interface.utils.cache import (
62+
async_sql_lru_cache,
63+
CachedFetcher,
64+
cached_fetcher,
65+
)
6266
from async_substrate_interface.utils.decoding import (
6367
_determine_if_old_runtime_call,
6468
_bt_decode_to_dict_or_list,
@@ -794,12 +798,6 @@ def __init__(
794798
self.registry_type_map = {}
795799
self.type_id_to_name = {}
796800
self._mock = _mock
797-
self._block_hash_fetcher = CachedFetcher(512, self._get_block_hash)
798-
self._parent_hash_fetcher = CachedFetcher(512, self._get_parent_block_hash)
799-
self._runtime_info_fetcher = CachedFetcher(16, self._get_block_runtime_info)
800-
self._runtime_version_for_fetcher = CachedFetcher(
801-
512, self._get_block_runtime_version_for
802-
)
803801

804802
async def __aenter__(self):
805803
if not self._mock:
@@ -1044,35 +1042,7 @@ async def init_runtime(
10441042
if not runtime:
10451043
self.last_block_hash = block_hash
10461044

1047-
runtime_block_hash = await self.get_parent_block_hash(block_hash)
1048-
1049-
runtime_info = await self.get_block_runtime_info(runtime_block_hash)
1050-
1051-
metadata, (metadata_v15, registry) = await asyncio.gather(
1052-
self.get_block_metadata(block_hash=runtime_block_hash, decode=True),
1053-
self._load_registry_at_block(block_hash=runtime_block_hash),
1054-
)
1055-
if metadata is None:
1056-
# does this ever happen?
1057-
raise SubstrateRequestException(
1058-
f"No metadata for block '{runtime_block_hash}'"
1059-
)
1060-
logger.debug(
1061-
f"Retrieved metadata and metadata v15 for {runtime_version} from Substrate node"
1062-
)
1063-
1064-
runtime = Runtime(
1065-
chain=self.chain,
1066-
runtime_config=self.runtime_config,
1067-
metadata=metadata,
1068-
type_registry=self.type_registry,
1069-
metadata_v15=metadata_v15,
1070-
runtime_info=runtime_info,
1071-
registry=registry,
1072-
)
1073-
self.runtime_cache.add_item(
1074-
runtime_version=runtime_version, runtime=runtime
1075-
)
1045+
runtime = await self.get_runtime_for_version(runtime_version, block_hash)
10761046

10771047
self.load_runtime(runtime)
10781048

@@ -1086,6 +1056,42 @@ async def init_runtime(
10861056
self.ss58_format = ss58_prefix_constant
10871057
return runtime
10881058

1059+
@cached_fetcher(max_size=16, cache_key_index=0)
1060+
async def get_runtime_for_version(
1061+
self, runtime_version: int, block_hash: Optional[str] = None
1062+
) -> Runtime:
1063+
return await self._get_runtime_for_version(runtime_version, block_hash)
1064+
1065+
async def _get_runtime_for_version(
1066+
self, runtime_version: int, block_hash: Optional[str] = None
1067+
) -> Runtime:
1068+
runtime_block_hash = await self.get_parent_block_hash(block_hash)
1069+
runtime_info, metadata, (metadata_v15, registry) = await asyncio.gather(
1070+
self.get_block_runtime_info(runtime_block_hash),
1071+
self.get_block_metadata(block_hash=runtime_block_hash, decode=True),
1072+
self._load_registry_at_block(block_hash=runtime_block_hash),
1073+
)
1074+
if metadata is None:
1075+
# does this ever happen?
1076+
raise SubstrateRequestException(
1077+
f"No metadata for block '{runtime_block_hash}'"
1078+
)
1079+
logger.debug(
1080+
f"Retrieved metadata and metadata v15 for {runtime_version} from Substrate node"
1081+
)
1082+
1083+
runtime = Runtime(
1084+
chain=self.chain,
1085+
runtime_config=self.runtime_config,
1086+
metadata=metadata,
1087+
type_registry=self.type_registry,
1088+
metadata_v15=metadata_v15,
1089+
runtime_info=runtime_info,
1090+
registry=registry,
1091+
)
1092+
self.runtime_cache.add_item(runtime_version=runtime_version, runtime=runtime)
1093+
return runtime
1094+
10891095
async def create_storage_key(
10901096
self,
10911097
pallet: str,
@@ -1921,8 +1927,9 @@ async def get_metadata(self, block_hash=None) -> MetadataV15:
19211927

19221928
return runtime.metadata_v15
19231929

1930+
@cached_fetcher(max_size=512)
19241931
async def get_parent_block_hash(self, block_hash):
1925-
return await self._parent_hash_fetcher.execute(block_hash)
1932+
return await self._get_parent_block_hash(block_hash)
19261933

19271934
async def _get_parent_block_hash(self, block_hash):
19281935
block_header = await self.rpc_request("chain_getHeader", [block_hash])
@@ -1967,8 +1974,9 @@ async def get_storage_by_key(self, block_hash: str, storage_key: str) -> Any:
19671974
"Unknown error occurred during retrieval of events"
19681975
)
19691976

1977+
@cached_fetcher(max_size=16)
19701978
async def get_block_runtime_info(self, block_hash: str) -> dict:
1971-
return await self._runtime_info_fetcher.execute(block_hash)
1979+
return await self._get_block_runtime_info(block_hash)
19721980

19731981
get_block_runtime_version = get_block_runtime_info
19741982

@@ -1979,8 +1987,9 @@ async def _get_block_runtime_info(self, block_hash: str) -> dict:
19791987
response = await self.rpc_request("state_getRuntimeVersion", [block_hash])
19801988
return response.get("result")
19811989

1990+
@cached_fetcher(max_size=512)
19821991
async def get_block_runtime_version_for(self, block_hash: str):
1983-
return await self._runtime_version_for_fetcher.execute(block_hash)
1992+
return await self._get_block_runtime_version_for(block_hash)
19841993

19851994
async def _get_block_runtime_version_for(self, block_hash: str):
19861995
"""
@@ -2296,8 +2305,9 @@ async def rpc_request(
22962305
else:
22972306
raise SubstrateRequestException(result[payload_id][0])
22982307

2308+
@cached_fetcher(max_size=512)
22992309
async def get_block_hash(self, block_id: int) -> str:
2300-
return await self._block_hash_fetcher.execute(block_id)
2310+
return await self._get_block_hash(block_id)
23012311

23022312
async def _get_block_hash(self, block_id: int) -> str:
23032313
return (await self.rpc_request("chain_getBlockHash", [block_id]))["result"]

async_substrate_interface/utils/cache.py

Lines changed: 53 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,12 @@
11
import asyncio
22
from collections import OrderedDict
33
import functools
4+
import logging
45
import os
56
import pickle
67
import sqlite3
78
from pathlib import Path
8-
from typing import Callable, Any
9-
10-
import asyncstdlib as a
9+
from typing import Callable, Any, Awaitable, Hashable
1110

1211

1312
USE_CACHE = True if os.getenv("NO_CACHE") != "1" else False
@@ -19,6 +18,8 @@
1918
else ":memory:"
2019
)
2120

21+
logger = logging.getLogger("async_substrate_interface")
22+
2223

2324
def _ensure_dir():
2425
path = Path(CACHE_LOCATION).parent
@@ -70,7 +71,7 @@ def _retrieve_from_cache(c, table_name, key, chain):
7071
if result is not None:
7172
return pickle.loads(result[0])
7273
except (pickle.PickleError, sqlite3.Error) as e:
73-
print(f"Cache error: {str(e)}")
74+
logger.exception("Cache error", exc_info=e)
7475
pass
7576

7677

@@ -82,7 +83,7 @@ def _insert_into_cache(c, conn, table_name, key, result, chain):
8283
)
8384
conn.commit()
8485
except (pickle.PickleError, sqlite3.Error) as e:
85-
print(f"Cache error: {str(e)}")
86+
logger.exception("Cache error", exc_info=e)
8687
pass
8788

8889

@@ -128,7 +129,7 @@ def inner(self, *args, **kwargs):
128129

129130
def async_sql_lru_cache(maxsize=None):
130131
def decorator(func):
131-
@a.lru_cache(maxsize=maxsize)
132+
@cached_fetcher(max_size=maxsize)
132133
async def inner(self, *args, **kwargs):
133134
c, conn, table_name, key, result, chain, local_chain = (
134135
_shared_inner_fn_logic(func, self, args, kwargs)
@@ -167,31 +168,65 @@ def get(self, key):
167168

168169

169170
class CachedFetcher:
170-
def __init__(self, max_size: int, method: Callable):
171-
self._inflight: dict[int, asyncio.Future] = {}
171+
def __init__(
172+
self,
173+
max_size: int,
174+
method: Callable[..., Awaitable[Any]],
175+
cache_key_index: int = 0,
176+
):
177+
self._inflight: dict[Hashable, asyncio.Future] = {}
172178
self._method = method
173179
self._cache = LRUCache(max_size=max_size)
180+
self._cache_key_index = cache_key_index
174181

175-
async def execute(self, single_arg: Any) -> str:
176-
if item := self._cache.get(single_arg):
182+
async def __call__(self, *args: Any) -> Any:
183+
key = args[self._cache_key_index]
184+
if item := self._cache.get(key):
177185
return item
178186

179-
if single_arg in self._inflight:
180-
result = await self._inflight[single_arg]
181-
return result
187+
if key in self._inflight:
188+
return await self._inflight[key]
182189

183190
loop = asyncio.get_running_loop()
184191
future = loop.create_future()
185-
self._inflight[single_arg] = future
192+
self._inflight[key] = future
186193

187194
try:
188-
result = await self._method(single_arg)
189-
self._cache.set(single_arg, result)
195+
result = await self._method(*args)
196+
self._cache.set(key, result)
190197
future.set_result(result)
191198
return result
192199
except Exception as e:
193-
# Propagate errors
194200
future.set_exception(e)
195201
raise
196202
finally:
197-
self._inflight.pop(single_arg, None)
203+
self._inflight.pop(key, None)
204+
205+
206+
class CachedFetcherMethod:
207+
def __init__(self, method, max_size: int, cache_key_index: int):
208+
self.method = method
209+
self.max_size = max_size
210+
self.cache_key_index = cache_key_index
211+
self._instances = {}
212+
213+
def __get__(self, instance, owner):
214+
if instance is None:
215+
return self
216+
217+
# Cache per-instance
218+
if instance not in self._instances:
219+
bound_method = self.method.__get__(instance, owner)
220+
self._instances[instance] = CachedFetcher(
221+
max_size=self.max_size,
222+
method=bound_method,
223+
cache_key_index=self.cache_key_index,
224+
)
225+
return self._instances[instance]
226+
227+
228+
def cached_fetcher(max_size: int, cache_key_index: int = 0):
229+
def wrapper(method):
230+
return CachedFetcherMethod(method, max_size, cache_key_index)
231+
232+
return wrapper

pyproject.toml

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,6 @@ keywords = ["substrate", "development", "bittensor"]
88

99
dependencies = [
1010
"wheel",
11-
"asyncstdlib~=3.13.0",
1211
"bt-decode==v0.6.0",
1312
"scalecodec~=1.2.11",
1413
"websockets>=14.1",

tests/unit_tests/test_cache.py

Lines changed: 9 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -13,18 +13,18 @@ async def test_cached_fetcher_fetches_and_caches():
1313
fetcher = CachedFetcher(max_size=2, method=mock_method)
1414

1515
# First call should trigger the method
16-
result1 = await fetcher.execute("key1")
16+
result1 = await fetcher("key1")
1717
assert result1 == "result_key1"
1818
mock_method.assert_awaited_once_with("key1")
1919

2020
# Second call with the same key should use the cache
21-
result2 = await fetcher.execute("key1")
21+
result2 = await fetcher("key1")
2222
assert result2 == "result_key1"
2323
# Ensure the method was NOT called again
2424
assert mock_method.await_count == 1
2525

2626
# Third call with a new key triggers a method call
27-
result3 = await fetcher.execute("key2")
27+
result3 = await fetcher("key2")
2828
assert result3 == "result_key2"
2929
assert mock_method.await_count == 2
3030

@@ -42,11 +42,11 @@ async def slow_method(x):
4242
fetcher = CachedFetcher(max_size=2, method=slow_method)
4343

4444
# Start first request
45-
task1 = asyncio.create_task(fetcher.execute("key1"))
45+
task1 = asyncio.create_task(fetcher("key1"))
4646
await asyncio.sleep(0.1) # Let the task start and be inflight
4747

4848
# Second request for the same key while the first is in-flight
49-
task2 = asyncio.create_task(fetcher.execute("key1"))
49+
task2 = asyncio.create_task(fetcher("key1"))
5050
await asyncio.sleep(0.1)
5151

5252
# Release the inflight request
@@ -65,7 +65,7 @@ async def error_method(x):
6565
fetcher = CachedFetcher(max_size=2, method=error_method)
6666

6767
with pytest.raises(ValueError, match="Boom!"):
68-
await fetcher.execute("key1")
68+
await fetcher("key1")
6969

7070

7171
@pytest.mark.asyncio
@@ -75,12 +75,12 @@ async def test_cached_fetcher_eviction():
7575
fetcher = CachedFetcher(max_size=2, method=mock_method)
7676

7777
# Fill cache
78-
await fetcher.execute("key1")
79-
await fetcher.execute("key2")
78+
await fetcher("key1")
79+
await fetcher("key2")
8080
assert list(fetcher._cache.cache.keys()) == list(fetcher._cache.cache.keys())
8181

8282
# Insert a new key to trigger eviction
83-
await fetcher.execute("key3")
83+
await fetcher("key3")
8484
# key1 should be evicted
8585
assert "key1" not in fetcher._cache.cache
8686
assert "key2" in fetcher._cache.cache

0 commit comments

Comments
 (0)