diff --git a/async_substrate_interface/async_substrate.py b/async_substrate_interface/async_substrate.py index 177e0e2..eeadf18 100644 --- a/async_substrate_interface/async_substrate.py +++ b/async_substrate_interface/async_substrate.py @@ -2762,19 +2762,34 @@ async def rpc_request( logger.error(f"Substrate Request Exception: {result[payload_id]}") raise SubstrateRequestException(result[payload_id][0]) - @cached_fetcher(max_size=SUBSTRATE_CACHE_METHOD_SIZE) - async def get_block_hash(self, block_id: int) -> str: + async def get_block_hash(self, block_id: Optional[int]) -> str: """ - Retrieves the hash of the specified block number + Retrieves the hash of the specified block number, or the chaintip if None Args: block_id: block number Returns: Hash of the block """ + if block_id is None: + return await self.get_chain_head() + else: + if (block_hash := self.runtime_cache.blocks.get(block_id)) is not None: + return block_hash + + block_hash = await self._cached_get_block_hash(block_id) + self.runtime_cache.add_item(block_hash=block_hash, block=block_id) + return block_hash + + @cached_fetcher(max_size=SUBSTRATE_CACHE_METHOD_SIZE) + async def _cached_get_block_hash(self, block_id: int) -> str: + """ + The design of this method is as such, because it allows for an easy drop-in for a different cache, such + as is the case with DiskCachedAsyncSubstrateInterface._cached_get_block_hash + """ return await self._get_block_hash(block_id) - async def _get_block_hash(self, block_id: int) -> str: + async def _get_block_hash(self, block_id: Optional[int]) -> str: return (await self.rpc_request("chain_getBlockHash", [block_id]))["result"] async def get_chain_head(self) -> str: @@ -4250,13 +4265,25 @@ async def get_metadata_event( async def get_block_number(self, block_hash: Optional[str] = None) -> int: """Async version of `substrateinterface.base.get_block_number` method.""" - response = await self.rpc_request("chain_getHeader", [block_hash]) + if block_hash is None: + return await self._get_block_number(None) + if (block := self.runtime_cache.blocks_reverse.get(block_hash)) is not None: + return block + block = await self._cached_get_block_number(block_hash) + self.runtime_cache.add_item(block_hash=block_hash, block=block) + return block - if response["result"]: - return int(response["result"]["number"], 16) - raise SubstrateRequestException( - f"Unable to retrieve block number for {block_hash}" - ) + @cached_fetcher(max_size=SUBSTRATE_CACHE_METHOD_SIZE) + async def _cached_get_block_number(self, block_hash: str) -> int: + """ + The design of this method is as such, because it allows for an easy drop-in for a different cache, such + as is the case with DiskCachedAsyncSubstrateInterface._cached_get_block_number + """ + return await self._get_block_number(block_hash=block_hash) + + async def _get_block_number(self, block_hash: Optional[str]) -> int: + response = await self.rpc_request("chain_getHeader", [block_hash]) + return int(response["result"]["number"], 16) async def close(self): """ @@ -4351,9 +4378,13 @@ async def get_block_runtime_version_for(self, block_hash: str): return await self._get_block_runtime_version_for(block_hash) @async_sql_lru_cache(maxsize=SUBSTRATE_CACHE_METHOD_SIZE) - async def get_block_hash(self, block_id: int) -> str: + async def _cached_get_block_hash(self, block_id: int) -> str: return await self._get_block_hash(block_id) + @async_sql_lru_cache(maxsize=SUBSTRATE_CACHE_METHOD_SIZE) + async def _cached_get_block_number(self, block_hash: str) -> int: + return await self._get_block_number(block_hash=block_hash) + async def get_async_substrate_interface( url: str, diff --git a/async_substrate_interface/sync_substrate.py b/async_substrate_interface/sync_substrate.py index 5b6db72..2172b74 100644 --- a/async_substrate_interface/sync_substrate.py +++ b/async_substrate_interface/sync_substrate.py @@ -2052,8 +2052,21 @@ def rpc_request( else: raise SubstrateRequestException(result[payload_id][0]) + def get_block_hash(self, block_id: Optional[int]) -> str: + """ + Retrieves the block hash for a given block number, or the chaintip hash if None + """ + if block_id is None: + return self.get_chain_head() + else: + if (block_hash := self.runtime_cache.blocks.get(block_id)) is not None: + return block_hash + block_hash = self._get_block_hash(block_id) + self.runtime_cache.add_item(block_hash=block_hash, block=block_id) + return block_hash + @functools.lru_cache(maxsize=SUBSTRATE_CACHE_METHOD_SIZE) - def get_block_hash(self, block_id: int) -> str: + def _get_block_hash(self, block_id: int) -> str: return self.rpc_request("chain_getBlockHash", [block_id])["result"] def get_chain_head(self) -> str: @@ -3380,15 +3393,27 @@ def get_metadata_event( return self._get_metadata_event(module_name, event_name, runtime) def get_block_number(self, block_hash: Optional[str] = None) -> int: - """Async version of `substrateinterface.base.get_block_number` method.""" - response = self.rpc_request("chain_getHeader", [block_hash]) - - if response["result"]: - return int(response["result"]["number"], 16) + """ + Retrieves the block number for a given block hash or chaintip. + """ + if block_hash is None: + return self._get_block_number(None) else: - raise SubstrateRequestException( - f"Unable to determine block number for {block_hash}" - ) + if ( + block_number := self.runtime_cache.blocks_reverse.get(block_hash) + ) is not None: + return block_number + block_number = self._cached_get_block_number(block_hash=block_hash) + self.runtime_cache.add_item(block_hash=block_hash, block=block_number) + return block_number + + @functools.lru_cache(maxsize=SUBSTRATE_CACHE_METHOD_SIZE) + def _cached_get_block_number(self, block_hash: Optional[str]) -> int: + return self._get_block_number(block_hash=block_hash) + + def _get_block_number(self, block_hash: Optional[str]) -> int: + response = self.rpc_request("chain_getHeader", [block_hash]) + return int(response["result"]["number"], 16) def close(self): """ @@ -3404,6 +3429,7 @@ def close(self): self.get_block_runtime_info.cache_clear() self.get_block_runtime_version_for.cache_clear() self.supports_rpc_method.cache_clear() - self.get_block_hash.cache_clear() + self._get_block_hash.cache_clear() + self._cached_get_block_number.cache_clear() encode_scale = SubstrateMixin._encode_scale diff --git a/async_substrate_interface/types.py b/async_substrate_interface/types.py index 842e260..7af5e83 100644 --- a/async_substrate_interface/types.py +++ b/async_substrate_interface/types.py @@ -2,7 +2,7 @@ import logging import os from abc import ABC -from collections import defaultdict, deque +from collections import defaultdict, deque, OrderedDict from collections.abc import Iterable from contextlib import suppress from dataclasses import dataclass @@ -48,6 +48,8 @@ class RuntimeCache: def __init__(self, known_versions: Optional[Sequence[tuple[int, int]]] = None): # {block: block_hash, ...} self.blocks: LRUCache = LRUCache(max_size=SUBSTRATE_CACHE_METHOD_SIZE) + # {block_hash: block, ...} + self.blocks_reverse: LRUCache = LRUCache(max_size=SUBSTRATE_CACHE_METHOD_SIZE) # {block_hash: specVersion, ...} self.block_hashes: LRUCache = LRUCache(max_size=SUBSTRATE_CACHE_METHOD_SIZE) # {specVersion: Runtime, ...} @@ -87,7 +89,7 @@ def add_known_versions(self, known_versions: Sequence[tuple[int, int]]): def add_item( self, - runtime: "Runtime", + runtime: Optional["Runtime"] = None, block: Optional[int] = None, block_hash: Optional[str] = None, runtime_version: Optional[int] = None, @@ -95,13 +97,15 @@ def add_item( """ Adds a Runtime object to the cache mapped to its version, block number, and/or block hash. """ - self.last_used = runtime + if runtime is not None: + self.last_used = runtime + if runtime_version is not None: + self.versions.set(runtime_version, runtime) if block is not None and block_hash is not None: self.blocks.set(block, block_hash) + self.blocks_reverse.set(block_hash, block) if block_hash is not None and runtime_version is not None: self.block_hashes.set(block_hash, runtime_version) - if runtime_version is not None: - self.versions.set(runtime_version, runtime) def retrieve( self, @@ -114,16 +118,24 @@ def retrieve( Retrieval happens in this order. If no Runtime is found mapped to any of your supplied keys, returns `None`. """ # No reason to do this lookup if the runtime version is already supplied in this call - if block is not None and runtime_version is None and self._known_version_blocks: - # _known_version_blocks excludes the last item (see note in `add_known_versions`) - idx = bisect.bisect_right(self._known_version_blocks, block) - 1 - if idx >= 0: - runtime_version = self.known_versions[idx][1] + if runtime_version is None and self._known_version_blocks: + if block is not None: + block_ = block + elif block_hash is not None: + block_ = self.blocks_reverse.get(block_hash) + else: + block_ = None + if block_ is not None: + # _known_version_blocks excludes the last item (see note in `add_known_versions`) + idx = bisect.bisect_right(self._known_version_blocks, block_) - 1 + if idx >= 0: + runtime_version = self.known_versions[idx][1] runtime = None if block is not None: if block_hash is not None: self.blocks.set(block, block_hash) + self.blocks_reverse.set(block_hash, block) if runtime_version is not None: self.block_hashes.set(block_hash, runtime_version) with suppress(AttributeError): @@ -158,6 +170,9 @@ async def load_from_disk(self, chain_endpoint: str): else: logger.debug("Found runtime mappings in disk cache") self.blocks.cache = block_mapping + self.blocks_reverse.cache = OrderedDict( + {v: k for k, v in block_mapping.items()} + ) self.block_hashes.cache = block_hash_mapping for x, y in runtime_version_mapping.items(): self.versions.cache[x] = Runtime.deserialize(y) diff --git a/async_substrate_interface/utils/cache.py b/async_substrate_interface/utils/cache.py index 431a430..8de077b 100644 --- a/async_substrate_interface/utils/cache.py +++ b/async_substrate_interface/utils/cache.py @@ -108,7 +108,9 @@ async def __call__(self, chain, other_self, func, args, kwargs) -> Optional[Any] await self._db.commit() return result - async def load_runtime_cache(self, chain: str) -> tuple[dict, dict, dict]: + async def load_runtime_cache( + self, chain: str + ) -> tuple[OrderedDict[int, str], OrderedDict[str, int], OrderedDict[int, dict]]: async with self._lock: if not self._db: _ensure_dir() @@ -125,7 +127,7 @@ async def load_runtime_cache(self, chain: str) -> tuple[dict, dict, dict]: async with self._lock: local_chain = await self._create_if_not_exists(chain, table) if local_chain: - return {}, {}, {} + return block_mapping, block_hash_mapping, version_mapping for table_name, mapping in tables.items(): try: async with self._lock: @@ -143,7 +145,7 @@ async def load_runtime_cache(self, chain: str) -> tuple[dict, dict, dict]: mapping[key] = runtime except (pickle.PickleError, sqlite3.Error) as e: logger.exception("Cache error", exc_info=e) - return {}, {}, {} + return block_mapping, block_hash_mapping, version_mapping return block_mapping, block_hash_mapping, version_mapping async def dump_runtime_cache( diff --git a/tests/unit_tests/asyncio_/test_env_vars.py b/tests/unit_tests/asyncio_/test_env_vars.py index 10a0933..3e35565 100644 --- a/tests/unit_tests/asyncio_/test_env_vars.py +++ b/tests/unit_tests/asyncio_/test_env_vars.py @@ -10,7 +10,7 @@ def test_env_vars(monkeypatch): assert asi.get_block_runtime_info._max_size == 9 assert asi.get_parent_block_hash._max_size == 10 assert asi.get_block_runtime_version_for._max_size == 10 - assert asi.get_block_hash._max_size == 10 + assert asi._cached_get_block_hash._max_size == 10 def test_defaults(): @@ -20,4 +20,4 @@ def test_defaults(): assert asi.get_block_runtime_info._max_size == 16 assert asi.get_parent_block_hash._max_size == 512 assert asi.get_block_runtime_version_for._max_size == 512 - assert asi.get_block_hash._max_size == 512 + assert asi._cached_get_block_hash._max_size == 512 diff --git a/tests/unit_tests/asyncio_/test_substrate_interface.py b/tests/unit_tests/asyncio_/test_substrate_interface.py index 721804b..c6f8544 100644 --- a/tests/unit_tests/asyncio_/test_substrate_interface.py +++ b/tests/unit_tests/asyncio_/test_substrate_interface.py @@ -175,3 +175,73 @@ async def test_memory_leak(): f"Loop {i}: diff={total_diff / 1024:.2f} KiB, current={current / 1024:.2f} KiB, " f"peak={peak / 1024:.2f} KiB" ) + + +class TestGetBlockHash: + @pytest.fixture + def substrate(self): + s = AsyncSubstrateInterface("ws://localhost", _mock=True) + s.runtime_cache = MagicMock() + s._cached_get_block_hash = AsyncMock(return_value="0xCACHED") + s.get_chain_head = AsyncMock(return_value="0xHEAD") + return s + + @pytest.mark.asyncio + async def test_none_block_id_returns_chain_head(self, substrate): + result = await substrate.get_block_hash(None) + assert result == "0xHEAD" + substrate.get_chain_head.assert_awaited_once() + substrate._cached_get_block_hash.assert_not_awaited() + + @pytest.mark.asyncio + async def test_cache_hit_returns_cached_hash(self, substrate): + substrate.runtime_cache.blocks.get.return_value = "0xFROMCACHE" + result = await substrate.get_block_hash(42) + assert result == "0xFROMCACHE" + substrate.runtime_cache.blocks.get.assert_called_once_with(42) + substrate._cached_get_block_hash.assert_not_awaited() + + @pytest.mark.asyncio + async def test_cache_miss_fetches_and_stores(self, substrate): + substrate.runtime_cache.blocks.get.return_value = None + result = await substrate.get_block_hash(42) + assert result == "0xCACHED" + substrate._cached_get_block_hash.assert_awaited_once_with(42) + substrate.runtime_cache.add_item.assert_called_once_with( + block_hash="0xCACHED", block=42 + ) + + +class TestGetBlockNumber: + @pytest.fixture + def substrate(self): + s = AsyncSubstrateInterface("ws://localhost", _mock=True) + s.runtime_cache = MagicMock() + s._cached_get_block_number = AsyncMock(return_value=100) + s._get_block_number = AsyncMock(return_value=99) + return s + + @pytest.mark.asyncio + async def test_none_block_hash_calls_get_block_number_directly(self, substrate): + result = await substrate.get_block_number(None) + assert result == 99 + substrate._get_block_number.assert_awaited_once_with(None) + substrate._cached_get_block_number.assert_not_awaited() + + @pytest.mark.asyncio + async def test_cache_hit_returns_cached_number(self, substrate): + substrate.runtime_cache.blocks_reverse.get.return_value = 42 + result = await substrate.get_block_number("0xABC") + assert result == 42 + substrate.runtime_cache.blocks_reverse.get.assert_called_once_with("0xABC") + substrate._cached_get_block_number.assert_not_awaited() + + @pytest.mark.asyncio + async def test_cache_miss_fetches_and_stores(self, substrate): + substrate.runtime_cache.blocks_reverse.get.return_value = None + result = await substrate.get_block_number("0xABC") + assert result == 100 + substrate._cached_get_block_number.assert_awaited_once_with("0xABC") + substrate.runtime_cache.add_item.assert_called_once_with( + block_hash="0xABC", block=100 + ) diff --git a/tests/unit_tests/sync/test_env_vars.py b/tests/unit_tests/sync/test_env_vars.py index 05d5ded..e53991c 100644 --- a/tests/unit_tests/sync/test_env_vars.py +++ b/tests/unit_tests/sync/test_env_vars.py @@ -10,7 +10,7 @@ def test_env_vars(monkeypatch): assert asi.get_block_runtime_info.cache_parameters()["maxsize"] == 9 assert asi.get_parent_block_hash.cache_parameters()["maxsize"] == 10 assert asi.get_block_runtime_version_for.cache_parameters()["maxsize"] == 10 - assert asi.get_block_hash.cache_parameters()["maxsize"] == 10 + assert asi._get_block_hash.cache_parameters()["maxsize"] == 10 def test_defaults(): @@ -20,4 +20,4 @@ def test_defaults(): assert asi.get_block_runtime_info.cache_parameters()["maxsize"] == 16 assert asi.get_parent_block_hash.cache_parameters()["maxsize"] == 512 assert asi.get_block_runtime_version_for.cache_parameters()["maxsize"] == 512 - assert asi.get_block_hash.cache_parameters()["maxsize"] == 512 + assert asi._get_block_hash.cache_parameters()["maxsize"] == 512 diff --git a/tests/unit_tests/sync/test_substrate_interface.py b/tests/unit_tests/sync/test_substrate_interface.py index 54a5b7d..491ace4 100644 --- a/tests/unit_tests/sync/test_substrate_interface.py +++ b/tests/unit_tests/sync/test_substrate_interface.py @@ -122,3 +122,71 @@ def test_memory_leak(): f"Loop {i}: diff={total_diff / 1024:.2f} KiB, current={current / 1024:.2f} KiB, " f"peak={peak / 1024:.2f} KiB" ) + + +class TestGetBlockHash: + def _make_substrate(self): + s = SubstrateInterface("ws://localhost", _mock=True) + s.runtime_cache = MagicMock() + s._get_block_hash = MagicMock(return_value="0xCACHED") + s.get_chain_head = MagicMock(return_value="0xHEAD") + return s + + def test_none_block_id_returns_chain_head(self): + substrate = self._make_substrate() + result = substrate.get_block_hash(None) + assert result == "0xHEAD" + substrate.get_chain_head.assert_called_once() + substrate._get_block_hash.assert_not_called() + + def test_cache_hit_returns_cached_hash(self): + substrate = self._make_substrate() + substrate.runtime_cache.blocks.get.return_value = "0xFROMCACHE" + result = substrate.get_block_hash(42) + assert result == "0xFROMCACHE" + substrate.runtime_cache.blocks.get.assert_called_once_with(42) + substrate._get_block_hash.assert_not_called() + + def test_cache_miss_fetches_and_stores(self): + substrate = self._make_substrate() + substrate.runtime_cache.blocks.get.return_value = None + result = substrate.get_block_hash(42) + assert result == "0xCACHED" + substrate._get_block_hash.assert_called_once_with(42) + substrate.runtime_cache.add_item.assert_called_once_with( + block_hash="0xCACHED", block=42 + ) + + +class TestGetBlockNumber: + def _make_substrate(self): + s = SubstrateInterface("ws://localhost", _mock=True) + s.runtime_cache = MagicMock() + s._cached_get_block_number = MagicMock(return_value=100) + s._get_block_number = MagicMock(return_value=99) + return s + + def test_none_block_hash_calls_get_block_number_directly(self): + substrate = self._make_substrate() + result = substrate.get_block_number(None) + assert result == 99 + substrate._get_block_number.assert_called_once_with(None) + substrate._cached_get_block_number.assert_not_called() + + def test_cache_hit_returns_cached_number(self): + substrate = self._make_substrate() + substrate.runtime_cache.blocks_reverse.get.return_value = 42 + result = substrate.get_block_number("0xABC") + assert result == 42 + substrate.runtime_cache.blocks_reverse.get.assert_called_once_with("0xABC") + substrate._cached_get_block_number.assert_not_called() + + def test_cache_miss_fetches_and_stores(self): + substrate = self._make_substrate() + substrate.runtime_cache.blocks_reverse.get.return_value = None + result = substrate.get_block_number("0xABC") + assert result == 100 + substrate._cached_get_block_number.assert_called_once_with(block_hash="0xABC") + substrate.runtime_cache.add_item.assert_called_once_with( + block_hash="0xABC", block=100 + )