Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 13 additions & 0 deletions async_substrate_interface/async_substrate.py
Original file line number Diff line number Diff line change
Expand Up @@ -526,6 +526,18 @@ async def retrieve_next_page(self, start_key) -> list:
self.last_key = result.last_key
return result.records

async def retrieve_all_records(self) -> list[Any]:
"""
Retrieves all records from all subsequent pages for the AsyncQueryMapResult,
returning them as a list.

Side effect:
The self.records list will be populated fully after running this method.
"""
async for _ in self:
pass
return self.records

def __aiter__(self):
return self

Expand Down Expand Up @@ -558,6 +570,7 @@ async def __anext__(self):
self.loading_complete = True
raise StopAsyncIteration

self.records.extend(next_page)
# Update the buffer with the newly fetched records
self._buffer = iter(next_page)
return next(self._buffer)
Expand Down
13 changes: 13 additions & 0 deletions async_substrate_interface/sync_substrate.py
Original file line number Diff line number Diff line change
Expand Up @@ -482,6 +482,18 @@ def retrieve_next_page(self, start_key) -> list:
self.last_key = result.last_key
return result.records

def retrieve_all_records(self) -> list[Any]:
"""
Retrieves all records from all subsequent pages for the QueryMapResult,
returning them as a list.

Side effect:
The self.records list will be populated fully after running this method.
"""
for _ in self:
pass
return self.records

def __iter__(self):
return self

Expand Down Expand Up @@ -511,6 +523,7 @@ def __next__(self):
self.loading_complete = True
raise StopIteration

self.records.extend(next_page)
# Update the buffer with the newly fetched records
self._buffer = iter(next_page)
return next(self._buffer)
Expand Down
42 changes: 42 additions & 0 deletions tests/unit_tests/asyncio_/test_substrate_interface.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
from websockets.protocol import State

from async_substrate_interface.async_substrate import (
AsyncQueryMapResult,
AsyncSubstrateInterface,
get_async_substrate_interface,
)
Expand Down Expand Up @@ -177,6 +178,47 @@ async def test_memory_leak():
)


@pytest.mark.asyncio
async def test_async_query_map_result_retrieve_all_records():
"""Test that retrieve_all_records fetches all pages and returns the full record list."""
page1 = [("key1", "val1"), ("key2", "val2")]
page2 = [("key3", "val3"), ("key4", "val4")]
page3 = [("key5", "val5")] # partial page signals loading_complete

mock_substrate = MagicMock()

qm = AsyncQueryMapResult(
records=list(page1),
page_size=2,
substrate=mock_substrate,
module="TestModule",
storage_function="TestStorage",
last_key="key2",
)

# Build mock pages: first call returns page2 (full page), second returns page3 (partial)
page2_result = AsyncQueryMapResult(
records=list(page2),
page_size=2,
substrate=mock_substrate,
last_key="key4",
)
page3_result = AsyncQueryMapResult(
records=list(page3),
page_size=2,
substrate=mock_substrate,
last_key="key5",
)
mock_substrate.query_map = AsyncMock(side_effect=[page2_result, page3_result])

result = await qm.retrieve_all_records()

assert result == page1 + page2 + page3
assert qm.records == page1 + page2 + page3
assert qm.loading_complete is True
assert mock_substrate.query_map.call_count == 2


class TestGetBlockHash:
@pytest.fixture
def substrate(self):
Expand Down
42 changes: 41 additions & 1 deletion tests/unit_tests/sync/test_substrate_interface.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import tracemalloc
from unittest.mock import MagicMock

from async_substrate_interface.sync_substrate import SubstrateInterface
from async_substrate_interface.sync_substrate import SubstrateInterface, QueryMapResult
from async_substrate_interface.types import ScaleObj

from tests.helpers.settings import ARCHIVE_ENTRYPOINT, LATENT_LITE_ENTRYPOINT
Expand Down Expand Up @@ -124,6 +124,46 @@ def test_memory_leak():
)


def test_async_query_map_result_retrieve_all_records():
"""Test that retrieve_all_records fetches all pages and returns the full record list."""
page1 = [("key1", "val1"), ("key2", "val2")]
page2 = [("key3", "val3"), ("key4", "val4")]
page3 = [("key5", "val5")] # partial page signals loading_complete

mock_substrate = MagicMock()

qm = QueryMapResult(
records=list(page1),
page_size=2,
substrate=mock_substrate,
module="TestModule",
storage_function="TestStorage",
last_key="key2",
)

# Build mock pages: first call returns page2 (full page), second returns page3 (partial)
page2_result = QueryMapResult(
records=list(page2),
page_size=2,
substrate=mock_substrate,
last_key="key4",
)
page3_result = QueryMapResult(
records=list(page3),
page_size=2,
substrate=mock_substrate,
last_key="key5",
)
mock_substrate.query_map = MagicMock(side_effect=[page2_result, page3_result])

result = qm.retrieve_all_records()

assert result == page1 + page2 + page3
assert qm.records == page1 + page2 + page3
assert qm.loading_complete is True
assert mock_substrate.query_map.call_count == 2


class TestGetBlockHash:
def _make_substrate(self):
s = SubstrateInterface("ws://localhost", _mock=True)
Expand Down
Loading