async_substrate_interface/async_substrate.py (39 changes: 20 additions & 19 deletions)
@@ -31,7 +31,6 @@
 from scalecodec.types import (
     GenericCall,
     GenericExtrinsic,
-    GenericRuntimeCallDefinition,
     ss58_encode,
     MultiAccountId,
 )
@@ -74,12 +73,10 @@
     _bt_decode_to_dict_or_list,
     legacy_scale_decode,
     convert_account_ids,
+    decode_query_map_async,
 )
 from async_substrate_interface.utils.storage import StorageKey
 from async_substrate_interface.type_registry import _TYPE_REGISTRY
-from async_substrate_interface.utils.decoding import (
-    decode_query_map,
-)
 
 ResultHandler = Callable[[dict, Any], Awaitable[tuple[dict, bool]]]
 
@@ -1421,7 +1418,9 @@ async def decode_scale(
         if runtime is None:
             runtime = await self.init_runtime(block_hash=block_hash)
         if runtime.metadata_v15 is not None and force_legacy is False:
-            obj = decode_by_type_string(type_string, runtime.registry, scale_bytes)
+            obj = await asyncio.to_thread(
+                decode_by_type_string, type_string, runtime.registry, scale_bytes
+            )
             if self.decode_ss58:
                 try:
                     type_str_int = int(type_string.split("::")[1])
@@ -3880,18 +3879,20 @@ async def query_map(
                     params=[result_keys, block_hash],
                     runtime=runtime,
                 )
+                changes = []
                 for result_group in response["result"]:
-                    result = decode_query_map(
-                        result_group["changes"],
-                        prefix,
-                        runtime,
-                        param_types,
-                        params,
-                        value_type,
-                        key_hashers,
-                        ignore_decoding_errors,
-                        self.decode_ss58,
-                    )
+                    changes.extend(result_group["changes"])
+                result = await decode_query_map_async(
+                    changes,
+                    prefix,
+                    runtime,
+                    param_types,
+                    params,
+                    value_type,
+                    key_hashers,
+                    ignore_decoding_errors,
+                    self.decode_ss58,
+                )
         else:
             # storage item and value scale type are not included here because this is batch-decoded in rust
             page_batches = [
@@ -3909,8 +3910,8 @@
             results: RequestResults = await self._make_rpc_request(
                 payloads, runtime=runtime
             )
-            for result in results.values():
-                res = result[0]
+            for result_ in results.values():
+                res = result_[0]
                 if "error" in res:
                     err_msg = res["error"]["message"]
                     if (
@@ -3928,7 +3929,7 @@
                 else:
                     for result_group in res["result"]:
                         changes.extend(result_group["changes"])
-            result = decode_query_map(
+            result = await decode_query_map_async(
                 changes,
                 prefix,
                 runtime,
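Note: both files follow the same pattern: the Rust-backed SCALE decoding call is CPU-bound, so it is wrapped in asyncio.to_thread to keep the event loop free while a worker thread decodes. A minimal, self-contained sketch of that pattern (slow_decode is a hypothetical stand-in for a blocking decoder such as decode_by_type_string):

import asyncio
import time

def slow_decode(payload: bytes) -> str:
    # Hypothetical stand-in for a blocking, CPU-bound decoder.
    time.sleep(0.5)
    return payload.hex()

async def decode_offloaded(payload: bytes) -> str:
    # Offload the blocking call so other coroutines keep running.
    return await asyncio.to_thread(slow_decode, payload)

async def heartbeat():
    for _ in range(3):
        await asyncio.sleep(0.2)
        print("event loop still responsive")

async def main():
    decoded, _ = await asyncio.gather(decode_offloaded(b"\x2a"), heartbeat())
    print(decoded)

asyncio.run(main())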
async_substrate_interface/utils/decoding.py (132 changes: 123 additions & 9 deletions)
@@ -1,3 +1,4 @@
+import asyncio
 from typing import Union, TYPE_CHECKING, Any
 
 from bt_decode import AxonInfo, PrometheusInfo, decode_list
@@ -72,16 +73,34 @@ def _decode_scale_list_with_runtime(
     return obj
 
 
-def decode_query_map(
+async def _async_decode_scale_list_with_runtime(
+    type_strings: list[str],
+    scale_bytes_list: list[bytes],
+    runtime: "Runtime",
+    return_scale_obj: bool = False,
+):
+    if runtime.metadata_v15 is not None:
+        obj = await asyncio.to_thread(
+            decode_list, type_strings, runtime.registry, scale_bytes_list
+        )
+    else:
+        obj = [
+            legacy_scale_decode(x, y, runtime)
+            for (x, y) in zip(type_strings, scale_bytes_list)
+        ]
+    if return_scale_obj:
+        return [ScaleObj(x) for x in obj]
+    else:
+        return obj
+
+
+def _decode_query_map_pre(
     result_group_changes: list,
     prefix,
-    runtime: "Runtime",
     param_types,
     params,
     value_type,
     key_hashers,
-    ignore_decoding_errors,
-    decode_ss58: bool = False,
 ):
     def concat_hash_len(key_hasher: str) -> int:
         """
@@ -98,7 +117,6 @@ def concat_hash_len(key_hasher: str) -> int:
 
     hex_to_bytes_ = hex_to_bytes
 
-    result = []
     # Determine type string
     key_type_string_ = []
     for n in range(len(params), len(param_types)):
@@ -116,11 +134,25 @@ def concat_hash_len(key_hasher: str) -> int:
         pre_decoded_values.append(
             hex_to_bytes_(item[1]) if item[1] is not None else b""
         )
-    all_decoded = _decode_scale_list_with_runtime(
-        pre_decoded_key_types + pre_decoded_value_types,
-        pre_decoded_keys + pre_decoded_values,
-        runtime,
+    return (
+        pre_decoded_key_types,
+        pre_decoded_value_types,
+        pre_decoded_keys,
+        pre_decoded_values,
     )
+
+
+def _decode_query_map_post(
+    pre_decoded_key_types,
+    pre_decoded_value_types,
+    all_decoded,
+    runtime: "Runtime",
+    param_types,
+    params,
+    ignore_decoding_errors,
+    decode_ss58: bool = False,
+):
+    result = []
     middl_index = len(all_decoded) // 2
     decoded_keys = all_decoded[:middl_index]
     decoded_values = all_decoded[middl_index:]
@@ -167,6 +199,88 @@ def concat_hash_len(key_hasher: str) -> int:
     return result
 
 
+async def decode_query_map_async(
+    result_group_changes: list,
+    prefix,
+    runtime: "Runtime",
+    param_types,
+    params,
+    value_type,
+    key_hashers,
+    ignore_decoding_errors,
+    decode_ss58: bool = False,
+):
+    (
+        pre_decoded_key_types,
+        pre_decoded_value_types,
+        pre_decoded_keys,
+        pre_decoded_values,
+    ) = _decode_query_map_pre(
+        result_group_changes,
+        prefix,
+        param_types,
+        params,
+        value_type,
+        key_hashers,
+    )
+    all_decoded = await _async_decode_scale_list_with_runtime(
+        pre_decoded_key_types + pre_decoded_value_types,
+        pre_decoded_keys + pre_decoded_values,
+        runtime,
+    )
+    return _decode_query_map_post(
+        pre_decoded_key_types,
+        pre_decoded_value_types,
+        all_decoded,
+        runtime,
+        param_types,
+        params,
+        ignore_decoding_errors,
+        decode_ss58=decode_ss58,
+    )
+
+
+def decode_query_map(
+    result_group_changes: list,
+    prefix,
+    runtime: "Runtime",
+    param_types,
+    params,
+    value_type,
+    key_hashers,
+    ignore_decoding_errors,
+    decode_ss58: bool = False,
+):
+    (
+        pre_decoded_key_types,
+        pre_decoded_value_types,
+        pre_decoded_keys,
+        pre_decoded_values,
+    ) = _decode_query_map_pre(
+        result_group_changes,
+        prefix,
+        param_types,
+        params,
+        value_type,
+        key_hashers,
+    )
+    all_decoded = _decode_scale_list_with_runtime(
+        pre_decoded_key_types + pre_decoded_value_types,
+        pre_decoded_keys + pre_decoded_values,
+        runtime,
+    )
+    return _decode_query_map_post(
+        pre_decoded_key_types,
+        pre_decoded_value_types,
+        all_decoded,
+        runtime,
+        param_types,
+        params,
+        ignore_decoding_errors,
+        decode_ss58=decode_ss58,
+    )
+
+
 def legacy_scale_decode(
     type_string: str, scale_bytes: Union[str, bytes, ScaleBytes], runtime: "Runtime"
 ):
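Note: the refactor splits the old decode_query_map into _decode_query_map_pre (collect type strings and raw bytes), one batch-decode step, and _decode_query_map_post (pair decoded keys with values), so the sync and async entry points share everything except how the decode step runs. A rough sketch of that shape, with hypothetical prepare/decode_batch/finish helpers standing in for the real pre, Rust decode, and post functions:

import asyncio

def prepare(changes: list[tuple[str, str]]) -> list[bytes]:
    # Stand-in for _decode_query_map_pre: turn hex change values into raw bytes.
    return [bytes.fromhex(value) for _, value in changes]

def decode_batch(blobs: list[bytes]) -> list[int]:
    # Stand-in for the batch decoder (bt_decode.decode_list in the real code).
    return [int.from_bytes(blob, "little") for blob in blobs]

def finish(decoded: list[int]) -> dict[int, int]:
    # Stand-in for _decode_query_map_post: assemble the final result.
    return dict(enumerate(decoded))

def decode_sync(changes):
    return finish(decode_batch(prepare(changes)))

async def decode_async(changes):
    # Identical pre/post code; only the decode step moves to a thread.
    blobs = prepare(changes)
    decoded = await asyncio.to_thread(decode_batch, blobs)
    return finish(decoded)

print(decode_sync([("0xkey", "2a00")]))
print(asyncio.run(decode_async([("0xkey", "2a00")])))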
tests/benchmarks/benchmark_to_thread_decoding.py (114 changes: 114 additions & 0 deletions)
@@ -0,0 +1,114 @@
"""
Results:

93 items

original (not threading) decoding:
median 3.9731219584937207
mean 3.810443129093619
stdev 0.9819147187144933

to_thread decoding:
median 2.7345559374953154
mean 2.7784998625924344
stdev 0.11112115146834547

"""

import asyncio

from scalecodec import ss58_encode

from async_substrate_interface.async_substrate import (
AsyncSubstrateInterface,
AsyncQueryMapResult,
)
from tests.helpers.settings import LATENT_LITE_ENTRYPOINT


async def benchmark_to_thread_decoding():
async def _query_alpha(hk_: str, sem: asyncio.Semaphore) -> list:
try:
async with sem:
results = []
qm: AsyncQueryMapResult = await substrate.query_map(
"SubtensorModule",
"Alpha",
params=[hk_],
block_hash=block_hash,
fully_exhaust=False,
page_size=100,
)
async for result in qm:
results.append(result)
return results
except Exception as e:
raise type(e)(f"[hotkey={hk_}] {e}") from e

loop = asyncio.get_running_loop()
async with AsyncSubstrateInterface(
LATENT_LITE_ENTRYPOINT, ss58_format=42, chain_name="Bittensor"
) as substrate:
block_hash = (
"0xb0f4a6fb95279f035f145600590e6d5508edea986c2e703e16b6bfbe08f29dbd"
)
start = loop.time()
total_hotkey_alpha_q, total_hotkey_shares_q = await asyncio.gather(
substrate.query_map(
"SubtensorModule",
"TotalHotkeyAlpha",
block_hash=block_hash,
page_size=100,
fully_exhaust=False,
params=[],
),
substrate.query_map(
"SubtensorModule",
"TotalHotkeyShares",
block_hash=block_hash,
fully_exhaust=False,
page_size=100,
params=[],
),
)
hotkeys = set()
tasks: list[asyncio.Task] = []
sema4 = asyncio.Semaphore(100)
for (hk, netuid), alpha in total_hotkey_alpha_q.records:
hotkey = ss58_encode(bytes(hk[0]), 42)
if alpha.value > 0:
if hotkey not in hotkeys:
hotkeys.add(hotkey)
tasks.append(
loop.create_task(_query_alpha(hotkey, sema4), name=hotkey)
)
for (hk, netuid), alpha_bits in total_hotkey_shares_q.records:
hotkey = ss58_encode(bytes(hk[0]), 42)
alpha_bits_value = alpha_bits.value["bits"]
if alpha_bits_value > 0:
if hotkey not in hotkeys:
hotkeys.add(hotkey)
tasks.append(
loop.create_task(_query_alpha(hotkey, sema4), name=hotkey)
)
await asyncio.gather(*tasks)
end = loop.time()
return len(tasks), end - start


if __name__ == "__main__":
results = []
for _ in range(10):
len_tasks, time = asyncio.run(benchmark_to_thread_decoding())
results.append((len_tasks, time))

for len_tasks, time in results:
if len_tasks != 910:
print(len_tasks, time)
time_results = [x[1] for x in results]
import statistics

median = statistics.median(time_results)
mean = statistics.mean(time_results)
stdev = statistics.stdev(time_results)
print(median, mean, stdev)
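Note: the benchmark's methodology is ten full runs of asyncio.run(...), each timed with loop.time(), with median/mean/stdev reported via the statistics module; the docstring numbers (in seconds) show the to_thread variant is both faster and far less variable. A stripped-down sketch of that harness shape, with workload as a hypothetical stand-in for the ~910-task query_map fan-out:

import asyncio
import statistics
import time

async def workload() -> None:
    # Hypothetical stand-in for the concurrent query_map calls above.
    await asyncio.sleep(0.01)

timings = []
for _ in range(10):
    start = time.perf_counter()
    asyncio.run(workload())
    timings.append(time.perf_counter() - start)

print(statistics.median(timings), statistics.mean(timings), statistics.stdev(timings))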