53 changes: 42 additions & 11 deletions async_substrate_interface/async_substrate.py
@@ -2762,19 +2762,34 @@ async def rpc_request(
logger.error(f"Substrate Request Exception: {result[payload_id]}")
raise SubstrateRequestException(result[payload_id][0])

@cached_fetcher(max_size=SUBSTRATE_CACHE_METHOD_SIZE)
async def get_block_hash(self, block_id: int) -> str:
async def get_block_hash(self, block_id: Optional[int]) -> str:
"""
Retrieves the hash of the specified block number
Retrieves the hash of the specified block number, or the chaintip if None
Args:
block_id: block number

Returns:
Hash of the block
"""
if block_id is None:
return await self.get_chain_head()
else:
if (block_hash := self.runtime_cache.blocks.get(block_id)) is not None:
return block_hash

block_hash = await self._cached_get_block_hash(block_id)
self.runtime_cache.add_item(block_hash=block_hash, block=block_id)
return block_hash

@cached_fetcher(max_size=SUBSTRATE_CACHE_METHOD_SIZE)
async def _cached_get_block_hash(self, block_id: int) -> str:
"""
The design of this method is as such, because it allows for an easy drop-in for a different cache, such
as is the case with DiskCachedAsyncSubstrateInterface._cached_get_block_hash
"""
return await self._get_block_hash(block_id)

async def _get_block_hash(self, block_id: int) -> str:
async def _get_block_hash(self, block_id: Optional[int]) -> str:
return (await self.rpc_request("chain_getBlockHash", [block_id]))["result"]

async def get_chain_head(self) -> str:
@@ -4250,13 +4265,25 @@ async def get_metadata_event(

async def get_block_number(self, block_hash: Optional[str] = None) -> int:
"""Async version of `substrateinterface.base.get_block_number` method."""
response = await self.rpc_request("chain_getHeader", [block_hash])
if block_hash is None:
return await self._get_block_number(None)
if (block := self.runtime_cache.blocks_reverse.get(block_hash)) is not None:
return block
block = await self._cached_get_block_number(block_hash)
self.runtime_cache.add_item(block_hash=block_hash, block=block)
return block

if response["result"]:
return int(response["result"]["number"], 16)
raise SubstrateRequestException(
f"Unable to retrieve block number for {block_hash}"
)
@cached_fetcher(max_size=SUBSTRATE_CACHE_METHOD_SIZE)
async def _cached_get_block_number(self, block_hash: str) -> int:
"""
The design of this method is as such, because it allows for an easy drop-in for a different cache, such
as is the case with DiskCachedAsyncSubstrateInterface._cached_get_block_number
"""
return await self._get_block_number(block_hash=block_hash)

async def _get_block_number(self, block_hash: Optional[str]) -> int:
response = await self.rpc_request("chain_getHeader", [block_hash])
return int(response["result"]["number"], 16)

async def close(self):
"""
@@ -4351,9 +4378,13 @@ async def get_block_runtime_version_for(self, block_hash: str):
return await self._get_block_runtime_version_for(block_hash)

@async_sql_lru_cache(maxsize=SUBSTRATE_CACHE_METHOD_SIZE)
async def get_block_hash(self, block_id: int) -> str:
async def _cached_get_block_hash(self, block_id: int) -> str:
return await self._get_block_hash(block_id)

@async_sql_lru_cache(maxsize=SUBSTRATE_CACHE_METHOD_SIZE)
async def _cached_get_block_number(self, block_hash: str) -> int:
return await self._get_block_number(block_hash=block_hash)


async def get_async_substrate_interface(
url: str,
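
Note on the async_substrate.py changes above: block-hash lookup now goes through two layers. get_block_hash() first consults the shared RuntimeCache, and only on a miss falls back to _cached_get_block_hash(), which is the single method a subclass such as DiskCachedAsyncSubstrateInterface overrides to swap in a different cache backend. The sketch below illustrates that lookup order only; ToyLRU and TwoLayerClient are hypothetical stand-ins, not the library's LRUCache, RuntimeCache, or cached_fetcher.

import asyncio
from collections import OrderedDict
from typing import Optional


class ToyLRU:
    # Hypothetical stand-in for the in-memory block -> block_hash mapping.
    def __init__(self, max_size: int = 512):
        self._data: OrderedDict[int, str] = OrderedDict()
        self._max_size = max_size

    def get(self, key: int) -> Optional[str]:
        if key in self._data:
            self._data.move_to_end(key)
            return self._data[key]
        return None

    def set(self, key: int, value: str) -> None:
        self._data[key] = value
        self._data.move_to_end(key)
        if len(self._data) > self._max_size:
            self._data.popitem(last=False)


class TwoLayerClient:
    # Hypothetical illustration of the lookup order used by get_block_hash above.
    def __init__(self):
        self.blocks = ToyLRU()  # first layer: shared block -> hash map

    async def get_block_hash(self, block_id: Optional[int]) -> str:
        if block_id is None:
            return await self.get_chain_head()  # chaintip is never cached
        if (cached := self.blocks.get(block_id)) is not None:
            return cached
        block_hash = await self._cached_get_block_hash(block_id)
        self.blocks.set(block_id, block_hash)
        return block_hash

    async def _cached_get_block_hash(self, block_id: int) -> str:
        # Second layer: the overridable cached fetcher. The base class wraps this in
        # an in-memory cache; a disk-cached subclass replaces only this method.
        return await self._get_block_hash(block_id)

    async def _get_block_hash(self, block_id: int) -> str:
        # Stand-in for the chain_getBlockHash RPC round trip.
        return f"0x{block_id:064x}"

    async def get_chain_head(self) -> str:
        return "0x" + "ff" * 32


async def _demo() -> None:
    client = TwoLayerClient()
    first = await client.get_block_hash(42)   # misses both layers, hits the fake RPC
    second = await client.get_block_hash(42)  # served from the in-memory layer
    assert first == second


asyncio.run(_demo())
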
46 changes: 36 additions & 10 deletions async_substrate_interface/sync_substrate.py
@@ -2052,8 +2052,21 @@ def rpc_request(
else:
raise SubstrateRequestException(result[payload_id][0])

def get_block_hash(self, block_id: Optional[int]) -> str:
"""
Retrieves the block hash for a given block number, or the chaintip hash if None
"""
if block_id is None:
return self.get_chain_head()
else:
if (block_hash := self.runtime_cache.blocks.get(block_id)) is not None:
return block_hash
block_hash = self._get_block_hash(block_id)
self.runtime_cache.add_item(block_hash=block_hash, block=block_id)
return block_hash

@functools.lru_cache(maxsize=SUBSTRATE_CACHE_METHOD_SIZE)
def get_block_hash(self, block_id: int) -> str:
def _get_block_hash(self, block_id: int) -> str:
return self.rpc_request("chain_getBlockHash", [block_id])["result"]

def get_chain_head(self) -> str:
@@ -3380,15 +3393,27 @@ def get_metadata_event(
return self._get_metadata_event(module_name, event_name, runtime)

def get_block_number(self, block_hash: Optional[str] = None) -> int:
"""Async version of `substrateinterface.base.get_block_number` method."""
response = self.rpc_request("chain_getHeader", [block_hash])

if response["result"]:
return int(response["result"]["number"], 16)
"""
Retrieves the block number for a given block hash or chaintip.
"""
if block_hash is None:
return self._get_block_number(None)
else:
raise SubstrateRequestException(
f"Unable to determine block number for {block_hash}"
)
if (
block_number := self.runtime_cache.blocks_reverse.get(block_hash)
) is not None:
return block_number
block_number = self._cached_get_block_number(block_hash=block_hash)
self.runtime_cache.add_item(block_hash=block_hash, block=block_number)
return block_number

@functools.lru_cache(maxsize=SUBSTRATE_CACHE_METHOD_SIZE)
def _cached_get_block_number(self, block_hash: Optional[str]) -> int:
return self._get_block_number(block_hash=block_hash)

def _get_block_number(self, block_hash: Optional[str]) -> int:
response = self.rpc_request("chain_getHeader", [block_hash])
return int(response["result"]["number"], 16)

def close(self):
"""
@@ -3404,6 +3429,7 @@ def close(self):
self.get_block_runtime_info.cache_clear()
self.get_block_runtime_version_for.cache_clear()
self.supports_rpc_method.cache_clear()
self.get_block_hash.cache_clear()
self._get_block_hash.cache_clear()
self._cached_get_block_number.cache_clear()

encode_scale = SubstrateMixin._encode_scale
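
Note on the sync_substrate.py changes above: the synchronous interface keeps the same split, but its cached layer is a functools.lru_cache-wrapped helper, and close() now clears it alongside the other cached methods. The sketch below mirrors only the lru_cache plus cache_clear() mechanics; ToySyncClient and its fake RPC helper are hypothetical names.

import functools
from typing import Optional


class ToySyncClient:
    def get_block_hash(self, block_id: Optional[int]) -> str:
        # Chaintip requests skip the cache, mirroring get_block_hash(None) above.
        if block_id is None:
            return self.get_chain_head()
        return self._get_block_hash(block_id)

    @functools.lru_cache(maxsize=512)
    def _get_block_hash(self, block_id: int) -> str:
        # Stand-in for the chain_getBlockHash RPC round trip.
        return f"0x{block_id:064x}"

    def get_chain_head(self) -> str:
        return "0x" + "ff" * 32

    def close(self) -> None:
        # lru_cache exposes cache_clear() on the wrapped helper; clearing it here
        # matches the additions to close() in the hunk above.
        self._get_block_hash.cache_clear()


client = ToySyncClient()
assert client.get_block_hash(7) == client.get_block_hash(7)  # second call is a cache hit
client.close()
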
35 changes: 25 additions & 10 deletions async_substrate_interface/types.py
@@ -2,7 +2,7 @@
import logging
import os
from abc import ABC
from collections import defaultdict, deque
from collections import defaultdict, deque, OrderedDict
from collections.abc import Iterable
from contextlib import suppress
from dataclasses import dataclass
@@ -48,6 +48,8 @@ class RuntimeCache:
def __init__(self, known_versions: Optional[Sequence[tuple[int, int]]] = None):
# {block: block_hash, ...}
self.blocks: LRUCache = LRUCache(max_size=SUBSTRATE_CACHE_METHOD_SIZE)
# {block_hash: block, ...}
self.blocks_reverse: LRUCache = LRUCache(max_size=SUBSTRATE_CACHE_METHOD_SIZE)
# {block_hash: specVersion, ...}
self.block_hashes: LRUCache = LRUCache(max_size=SUBSTRATE_CACHE_METHOD_SIZE)
# {specVersion: Runtime, ...}
@@ -87,21 +89,23 @@ def add_known_versions(self, known_versions: Sequence[tuple[int, int]]):

def add_item(
self,
runtime: "Runtime",
runtime: Optional["Runtime"] = None,
block: Optional[int] = None,
block_hash: Optional[str] = None,
runtime_version: Optional[int] = None,
) -> None:
"""
Adds a Runtime object to the cache mapped to its version, block number, and/or block hash.
"""
self.last_used = runtime
if runtime is not None:
self.last_used = runtime
if runtime_version is not None:
self.versions.set(runtime_version, runtime)
if block is not None and block_hash is not None:
self.blocks.set(block, block_hash)
self.blocks_reverse.set(block_hash, block)
if block_hash is not None and runtime_version is not None:
self.block_hashes.set(block_hash, runtime_version)
if runtime_version is not None:
self.versions.set(runtime_version, runtime)

def retrieve(
self,
@@ -114,16 +118,24 @@ def retrieve(
Retrieval happens in this order. If no Runtime is found mapped to any of your supplied keys, returns `None`.
"""
# No reason to do this lookup if the runtime version is already supplied in this call
if block is not None and runtime_version is None and self._known_version_blocks:
# _known_version_blocks excludes the last item (see note in `add_known_versions`)
idx = bisect.bisect_right(self._known_version_blocks, block) - 1
if idx >= 0:
runtime_version = self.known_versions[idx][1]
if runtime_version is None and self._known_version_blocks:
if block is not None:
block_ = block
elif block_hash is not None:
block_ = self.blocks_reverse.get(block_hash)
else:
block_ = None
if block_ is not None:
# _known_version_blocks excludes the last item (see note in `add_known_versions`)
idx = bisect.bisect_right(self._known_version_blocks, block_) - 1
if idx >= 0:
runtime_version = self.known_versions[idx][1]

runtime = None
if block is not None:
if block_hash is not None:
self.blocks.set(block, block_hash)
self.blocks_reverse.set(block_hash, block)
if runtime_version is not None:
self.block_hashes.set(block_hash, runtime_version)
with suppress(AttributeError):
@@ -158,6 +170,9 @@ async def load_from_disk(self, chain_endpoint: str):
Expand Down Expand Up @@ -158,6 +170,9 @@ async def load_from_disk(self, chain_endpoint: str):
else:
logger.debug("Found runtime mappings in disk cache")
self.blocks.cache = block_mapping
self.blocks_reverse.cache = OrderedDict(
{v: k for k, v in block_mapping.items()}
)
self.block_hashes.cache = block_hash_mapping
for x, y in runtime_version_mapping.items():
self.versions.cache[x] = Runtime.deserialize(y)
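
Note on the types.py changes above: retrieve() can now resolve a block hash to a block number through the new blocks_reverse map before running the bisect lookup over known runtime versions. The sketch below shows the bisect step in isolation. It assumes known_versions is a sequence of (first_block, spec_version) pairs sorted by block, it does not reproduce the real code's exclusion of the last known version block (noted in add_known_versions), and the spec version numbers are invented for the example.

import bisect
from typing import Optional, Sequence


def spec_version_for(
    known_versions: Sequence[tuple[int, int]], block: int
) -> Optional[int]:
    # known_versions: (first_block, spec_version) pairs sorted by first_block.
    first_blocks = [first_block for first_block, _ in known_versions]
    # bisect_right gives the insertion point, so idx is the last entry whose
    # first_block is <= block; -1 means the block predates every known version.
    idx = bisect.bisect_right(first_blocks, block) - 1
    return known_versions[idx][1] if idx >= 0 else None


# Example: spec 264 starts at block 0, spec 265 at block 5_000_000.
known = [(0, 264), (5_000_000, 265)]
assert spec_version_for(known, 4_999_999) == 264
assert spec_version_for(known, 5_000_000) == 265
assert spec_version_for(known, -1) is None
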
8 changes: 5 additions & 3 deletions async_substrate_interface/utils/cache.py
@@ -108,7 +108,9 @@ async def __call__(self, chain, other_self, func, args, kwargs) -> Optional[Any]
await self._db.commit()
return result

async def load_runtime_cache(self, chain: str) -> tuple[dict, dict, dict]:
async def load_runtime_cache(
self, chain: str
) -> tuple[OrderedDict[int, str], OrderedDict[str, int], OrderedDict[int, dict]]:
async with self._lock:
if not self._db:
_ensure_dir()
@@ -125,7 +127,7 @@ async def load_runtime_cache(self, chain: str) -> tuple[dict, dict, dict]:
async with self._lock:
local_chain = await self._create_if_not_exists(chain, table)
if local_chain:
return {}, {}, {}
return block_mapping, block_hash_mapping, version_mapping
for table_name, mapping in tables.items():
try:
async with self._lock:
@@ -143,7 +145,7 @@
mapping[key] = runtime
except (pickle.PickleError, sqlite3.Error) as e:
logger.exception("Cache error", exc_info=e)
return {}, {}, {}
return block_mapping, block_hash_mapping, version_mapping
return block_mapping, block_hash_mapping, version_mapping

async def dump_runtime_cache(
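
Note on the utils/cache.py changes above: load_runtime_cache now returns whatever mappings it has loaded, typed as OrderedDicts, instead of fresh empty dicts, and RuntimeCache.load_from_disk builds blocks_reverse by inverting the loaded block mapping. A tiny illustration of that inversion, assuming each block hash maps back to exactly one block:

from collections import OrderedDict

# block -> block_hash, as loaded from the disk cache
block_mapping: OrderedDict[int, str] = OrderedDict({100: "0xaa", 101: "0xbb"})

# block_hash -> block, derived the same way RuntimeCache.load_from_disk does it;
# this assumes every hash appears exactly once in the mapping.
blocks_reverse = OrderedDict({v: k for k, v in block_mapping.items()})

assert blocks_reverse["0xbb"] == 101
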
4 changes: 2 additions & 2 deletions tests/unit_tests/asyncio_/test_env_vars.py
@@ -10,7 +10,7 @@ def test_env_vars(monkeypatch):
assert asi.get_block_runtime_info._max_size == 9
assert asi.get_parent_block_hash._max_size == 10
assert asi.get_block_runtime_version_for._max_size == 10
assert asi.get_block_hash._max_size == 10
assert asi._cached_get_block_hash._max_size == 10


def test_defaults():
@@ -20,4 +20,4 @@ def test_defaults():
assert asi.get_block_runtime_info._max_size == 16
assert asi.get_parent_block_hash._max_size == 512
assert asi.get_block_runtime_version_for._max_size == 512
assert asi.get_block_hash._max_size == 512
assert asi._cached_get_block_hash._max_size == 512
70 changes: 70 additions & 0 deletions tests/unit_tests/asyncio_/test_substrate_interface.py
@@ -175,3 +175,73 @@ async def test_memory_leak():
f"Loop {i}: diff={total_diff / 1024:.2f} KiB, current={current / 1024:.2f} KiB, "
f"peak={peak / 1024:.2f} KiB"
)


class TestGetBlockHash:
@pytest.fixture
def substrate(self):
s = AsyncSubstrateInterface("ws://localhost", _mock=True)
s.runtime_cache = MagicMock()
s._cached_get_block_hash = AsyncMock(return_value="0xCACHED")
s.get_chain_head = AsyncMock(return_value="0xHEAD")
return s

@pytest.mark.asyncio
async def test_none_block_id_returns_chain_head(self, substrate):
result = await substrate.get_block_hash(None)
assert result == "0xHEAD"
substrate.get_chain_head.assert_awaited_once()
substrate._cached_get_block_hash.assert_not_awaited()

@pytest.mark.asyncio
async def test_cache_hit_returns_cached_hash(self, substrate):
substrate.runtime_cache.blocks.get.return_value = "0xFROMCACHE"
result = await substrate.get_block_hash(42)
assert result == "0xFROMCACHE"
substrate.runtime_cache.blocks.get.assert_called_once_with(42)
substrate._cached_get_block_hash.assert_not_awaited()

@pytest.mark.asyncio
async def test_cache_miss_fetches_and_stores(self, substrate):
substrate.runtime_cache.blocks.get.return_value = None
result = await substrate.get_block_hash(42)
assert result == "0xCACHED"
substrate._cached_get_block_hash.assert_awaited_once_with(42)
substrate.runtime_cache.add_item.assert_called_once_with(
block_hash="0xCACHED", block=42
)


class TestGetBlockNumber:
@pytest.fixture
def substrate(self):
s = AsyncSubstrateInterface("ws://localhost", _mock=True)
s.runtime_cache = MagicMock()
s._cached_get_block_number = AsyncMock(return_value=100)
s._get_block_number = AsyncMock(return_value=99)
return s

@pytest.mark.asyncio
async def test_none_block_hash_calls_get_block_number_directly(self, substrate):
result = await substrate.get_block_number(None)
assert result == 99
substrate._get_block_number.assert_awaited_once_with(None)
substrate._cached_get_block_number.assert_not_awaited()

@pytest.mark.asyncio
async def test_cache_hit_returns_cached_number(self, substrate):
substrate.runtime_cache.blocks_reverse.get.return_value = 42
result = await substrate.get_block_number("0xABC")
assert result == 42
substrate.runtime_cache.blocks_reverse.get.assert_called_once_with("0xABC")
substrate._cached_get_block_number.assert_not_awaited()

@pytest.mark.asyncio
async def test_cache_miss_fetches_and_stores(self, substrate):
substrate.runtime_cache.blocks_reverse.get.return_value = None
result = await substrate.get_block_number("0xABC")
assert result == 100
substrate._cached_get_block_number.assert_awaited_once_with("0xABC")
substrate.runtime_cache.add_item.assert_called_once_with(
block_hash="0xABC", block=100
)
4 changes: 2 additions & 2 deletions tests/unit_tests/sync/test_env_vars.py
@@ -10,7 +10,7 @@ def test_env_vars(monkeypatch):
assert asi.get_block_runtime_info.cache_parameters()["maxsize"] == 9
assert asi.get_parent_block_hash.cache_parameters()["maxsize"] == 10
assert asi.get_block_runtime_version_for.cache_parameters()["maxsize"] == 10
assert asi.get_block_hash.cache_parameters()["maxsize"] == 10
assert asi._get_block_hash.cache_parameters()["maxsize"] == 10


def test_defaults():
@@ -20,4 +20,4 @@ def test_defaults():
assert asi.get_block_runtime_info.cache_parameters()["maxsize"] == 16
assert asi.get_parent_block_hash.cache_parameters()["maxsize"] == 512
assert asi.get_block_runtime_version_for.cache_parameters()["maxsize"] == 512
assert asi.get_block_hash.cache_parameters()["maxsize"] == 512
assert asi._get_block_hash.cache_parameters()["maxsize"] == 512