diff --git a/.github/workflows/e2e-subtensor-tests.yaml b/.github/workflows/e2e-subtensor-tests.yaml index b83308ad12..cd603b5111 100644 --- a/.github/workflows/e2e-subtensor-tests.yaml +++ b/.github/workflows/e2e-subtensor-tests.yaml @@ -127,8 +127,7 @@ jobs: echo "Reading labels ..." if [[ "${GITHUB_EVENT_NAME}" == "pull_request" ]]; then - # Use GitHub CLI to read labels (works for forks too) - labels=$(gh pr view ${{ github.event.pull_request.number }} -R ${{ github.repository }} --json labels --jq '.labels[].name' || echo "") + labels=$(jq -r '.pull_request.labels[]?.name' "$GITHUB_EVENT_PATH" | tr '\n' ' ' || echo "") echo "Found labels: $labels" else labels="" diff --git a/bittensor/core/async_subtensor.py b/bittensor/core/async_subtensor.py index 3207e908e3..92e1518508 100644 --- a/bittensor/core/async_subtensor.py +++ b/bittensor/core/async_subtensor.py @@ -965,10 +965,14 @@ async def query_runtime_api( block_hash = await self.determine_block_hash(block, block_hash, reuse_block) if not block_hash and reuse_block: block_hash = self.substrate.last_block_hash - result = await self.substrate.runtime_call( - runtime_api, method, params, block_hash - ) - return result.value + return ( + await self.substrate.runtime_call( + api=runtime_api, + method=method, + params=params, + block_hash=block_hash, + ) + ).value async def query_subtensor( self, diff --git a/bittensor/core/subtensor.py b/bittensor/core/subtensor.py index 804d160bce..a17628a277 100644 --- a/bittensor/core/subtensor.py +++ b/bittensor/core/subtensor.py @@ -816,9 +816,13 @@ def query_runtime_api( """ block_hash = self.determine_block_hash(block) - result = self.substrate.runtime_call(runtime_api, method, params, block_hash) - return result.value + return self.substrate.runtime_call( + api=runtime_api, + method=method, + params=params, + block_hash=block_hash, + ).value def query_subtensor( self, diff --git a/bittensor/utils/retry.py b/bittensor/utils/retry.py new file mode 100644 index 0000000000..7da60ba98f --- /dev/null +++ b/bittensor/utils/retry.py @@ -0,0 +1,301 @@ +"""Retry utilities for handling transient failures with exponential backoff. + +This module provides optional retry wrappers for both synchronous and asynchronous +functions. Retry behavior is controlled via environment variables and is disabled +by default. + +Environment Variables: + BT_RETRY_ENABLED: Enable retry behavior ("true", "1", "yes", "on") + BT_RETRY_MAX_ATTEMPTS: Maximum retry attempts (default: 3) + BT_RETRY_BASE_DELAY: Base delay in seconds (default: 1.0) + BT_RETRY_MAX_DELAY: Maximum delay in seconds (default: 60.0) + BT_RETRY_BACKOFF_FACTOR: Exponential backoff multiplier (default: 2.0) + +Note: + This utility may be used internally by the SDK at outbound network boundaries + (e.g., Dendrite and Subtensor) and is also provided as an optional helper for + users who wish to implement consistent retry behavior. 
+ +For more information on retry strategies, see: + https://aws.amazon.com/blogs/architecture/exponential-backoff-and-jitter/ +""" + +import asyncio +import inspect +import os +import time +import random +import logging +from typing import Type, Tuple, Optional, Callable, Any, Union + +logger = logging.getLogger("bittensor.utils.retry") + + +# Helpers for runtime environment variable access +def _retry_enabled() -> bool: + return os.environ.get("BT_RETRY_ENABLED", "False").lower() in ( + "true", + "1", + "yes", + "on", + ) + + +def _retry_max_attempts() -> int: + """Get the maximum number of retry attempts from the environment, with validation.""" + default = 3 + raw = os.environ.get("BT_RETRY_MAX_ATTEMPTS") + if raw is None or raw == "": + return default + try: + value = int(raw) + if value <= 0: + logger.warning( + "Invalid value for BT_RETRY_MAX_ATTEMPTS=%r (must be positive); falling back to default %d", + raw, + default, + ) + return default + return value + except (TypeError, ValueError): + logger.warning( + "Invalid value for BT_RETRY_MAX_ATTEMPTS=%r; falling back to default %d", + raw, + default, + ) + return default + + +def _retry_base_delay() -> float: + """Get the base delay (in seconds) for retries from the environment, with validation.""" + default = 1.0 + raw = os.environ.get("BT_RETRY_BASE_DELAY") + if raw is None or raw == "": + return default + try: + value = float(raw) + if value < 0: + logger.warning( + "Invalid value for BT_RETRY_BASE_DELAY=%r (must be non-negative); falling back to default %.2f", + raw, + default, + ) + return default + return value + except (TypeError, ValueError): + logger.warning( + "Invalid value for BT_RETRY_BASE_DELAY=%r; falling back to default %.2f", + raw, + default, + ) + return default + + +def _retry_max_delay() -> float: + """Get the maximum delay (in seconds) for retries from the environment, with validation.""" + default = 60.0 + raw = os.environ.get("BT_RETRY_MAX_DELAY") + if raw is None or raw == "": + return default + try: + value = float(raw) + if value < 0: + logger.warning( + "Invalid value for BT_RETRY_MAX_DELAY=%r (must be non-negative); falling back to default %.2f", + raw, + default, + ) + return default + return value + except (TypeError, ValueError): + logger.warning( + "Invalid value for BT_RETRY_MAX_DELAY=%r; falling back to default %.2f", + raw, + default, + ) + return default + + +_RETRY_BACKOFF_FACTOR = 2.0 + + +def _retry_backoff_factor() -> float: + """Get the backoff factor for exponential backoff from the environment, with validation.""" + default = _RETRY_BACKOFF_FACTOR + raw = os.environ.get("BT_RETRY_BACKOFF_FACTOR") + if raw is None or raw == "": + return default + try: + value = float(raw) + if value <= 0: + logger.warning( + "Invalid value for BT_RETRY_BACKOFF_FACTOR=%r (must be positive); falling back to default %.2f", + raw, + default, + ) + return default + return value + except (TypeError, ValueError): + logger.warning( + "Invalid value for BT_RETRY_BACKOFF_FACTOR=%r (must be positive); falling back to default %.2f", + raw, + default, + ) + return default + + +def _get_backoff_time(attempt: int, base_delay: float, max_delay: float) -> float: + """Calculates backoff time with exponential backoff and jitter.""" + delay = min(max_delay, base_delay * (_retry_backoff_factor() ** attempt)) + # Add jitter while ensuring the final backoff does not exceed max_delay + return min(max_delay, delay * (0.5 + random.random())) + + +def retry_call( + func: Callable, + *args, + retry_exceptions: Union[Type[Exception], 
Tuple[Type[Exception], ...]] = ( + OSError, + TimeoutError, + ), + max_attempts: Optional[int] = None, + base_delay: Optional[float] = None, + max_delay: Optional[float] = None, + **kwargs, +) -> Any: + """Synchronous retry wrapper with optional exponential backoff. + + Retries are only enabled when BT_RETRY_ENABLED is set to a truthy value. + When disabled, the function executes exactly once. + + Args: + func: The callable to be executed and potentially retried. + *args: Positional arguments forwarded to func. + retry_exceptions: Exception type(s) that trigger a retry. Any exception + not matching these types is raised immediately. Defaults to + (OSError, TimeoutError). + max_attempts: Maximum number of attempts. If None, uses + BT_RETRY_MAX_ATTEMPTS environment variable (default: 3). + base_delay: Base delay in seconds for exponential backoff. If None, + uses BT_RETRY_BASE_DELAY environment variable (default: 1.0). + max_delay: Maximum delay in seconds between attempts. If None, uses + BT_RETRY_MAX_DELAY environment variable (default: 60.0). + **kwargs: Keyword arguments forwarded to func. + + Returns: + The return value from the first successful func execution. + + Raises: + TypeError: If func is an async function. Use async_retry_call instead. + Exception: Any exception raised by func when retries are disabled, or + when the exception type doesn't match retry_exceptions, or after + all retry attempts are exhausted. + """ + # Validate that func is not async + if inspect.iscoroutinefunction(func): + raise TypeError( + f"retry_call() cannot be used with async functions. " + f"Use async_retry_call() instead for {func.__name__}." + ) + + if not _retry_enabled(): + return func(*args, **kwargs) + + # Resolve configuration + _max_attempts = max_attempts if max_attempts is not None else _retry_max_attempts() + _base_delay = base_delay if base_delay is not None else _retry_base_delay() + _max_delay = max_delay if max_delay is not None else _retry_max_delay() + + for attempt in range(1, _max_attempts + 1): + try: + return func(*args, **kwargs) + except retry_exceptions as e: + if attempt == _max_attempts: + logger.debug( + f"Retry exhausted after {_max_attempts} attempts. Last error: {e}" + ) + raise + + backoff = _get_backoff_time(attempt - 1, _base_delay, _max_delay) + logger.debug( + f"Retry attempt {attempt}/{_max_attempts} failed with {e}. Retrying in {backoff:.2f}s..." + ) + time.sleep(backoff) + + # This should never be reached due to the logic above + assert False, "Unreachable code" + + +async def async_retry_call( + func: Callable, + *args, + retry_exceptions: Union[Type[Exception], Tuple[Type[Exception], ...]] = ( + OSError, + TimeoutError, + ), + max_attempts: Optional[int] = None, + base_delay: Optional[float] = None, + max_delay: Optional[float] = None, + **kwargs, +) -> Any: + """Asynchronous retry wrapper with optional exponential backoff. + + Retries are only enabled when BT_RETRY_ENABLED is set to a truthy value. + When disabled, the function executes exactly once. + + Args: + func: The async callable to be executed and potentially retried. + *args: Positional arguments forwarded to func on each attempt. + retry_exceptions: Exception type(s) that trigger a retry. Any exception + not matching these types is raised immediately. Defaults to + (OSError, TimeoutError). + max_attempts: Maximum number of attempts. If None, uses + BT_RETRY_MAX_ATTEMPTS environment variable (default: 3). + base_delay: Base delay in seconds for exponential backoff. 
If None, + uses BT_RETRY_BASE_DELAY environment variable (default: 1.0). + max_delay: Maximum delay in seconds between attempts. If None, uses + BT_RETRY_MAX_DELAY environment variable (default: 60.0). + **kwargs: Keyword arguments forwarded to func on each attempt. + + Returns: + The result from the first successful func execution. + + Raises: + TypeError: If func is not an async function. Use retry_call instead. + Exception: Any exception raised by func when retries are disabled, or + when the exception type doesn't match retry_exceptions, or after + all retry attempts are exhausted. + """ + # Validate that func is async + if not inspect.iscoroutinefunction(func): + raise TypeError( + f"async_retry_call() requires an async function. " + f"Use retry_call() instead for {func.__name__}." + ) + + if not _retry_enabled(): + return await func(*args, **kwargs) + + # Resolve configuration + _max_attempts = max_attempts if max_attempts is not None else _retry_max_attempts() + _base_delay = base_delay if base_delay is not None else _retry_base_delay() + _max_delay = max_delay if max_delay is not None else _retry_max_delay() + + for attempt in range(1, _max_attempts + 1): + try: + return await func(*args, **kwargs) + except retry_exceptions as e: + if attempt == _max_attempts: + logger.debug( + f"Retry exhausted after {_max_attempts} attempts. Last error: {e}" + ) + raise + + backoff = _get_backoff_time(attempt - 1, _base_delay, _max_delay) + logger.debug( + f"Retry attempt {attempt}/{_max_attempts} failed with {e}. Retrying in {backoff:.2f}s..." + ) + await asyncio.sleep(backoff) + + # This should never be reached due to the logic above + assert False, "Unreachable code" diff --git a/pyproject.toml b/pyproject.toml index 173aa491d6..df898fca1e 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -94,3 +94,14 @@ classifiers = [ [tool.setuptools] package-dir = {"bittensor" = "bittensor"} script-files = ["bittensor/utils/certifi.sh"] + +[tool.pytest.ini_options] +testpaths = ["tests"] +python_files = ["test_*.py"] +python_classes = ["Test*"] +python_functions = ["test_*"] +filterwarnings = [ + "ignore::DeprecationWarning:pkg_resources.*:", +] +asyncio_default_fixture_loop_scope = "session" +addopts = "-s" diff --git a/tests/unit_tests/test_async_subtensor.py b/tests/unit_tests/test_async_subtensor.py index 0f83110729..f371f3b946 100644 --- a/tests/unit_tests/test_async_subtensor.py +++ b/tests/unit_tests/test_async_subtensor.py @@ -570,10 +570,10 @@ async def test_query_runtime_api(subtensor, mocker): # Asserts mocked_runtime_call.assert_called_once_with( - fake_runtime_api, - fake_method, - fake_params, - fake_block_hash, + api=fake_runtime_api, + method=fake_method, + params=fake_params, + block_hash=fake_block_hash, ) assert result == mocked_runtime_call.return_value.value @@ -1068,13 +1068,11 @@ async def test_get_neuron_for_pubkey_and_subnet_success(subtensor, mocker): ) subtensor.substrate.runtime_call.assert_awaited_once() subtensor.substrate.runtime_call.assert_called_once_with( - "NeuronInfoRuntimeApi", - "get_neuron", - [fake_netuid, fake_uid.value], - None, + api="NeuronInfoRuntimeApi", + method="get_neuron", + params=[fake_netuid, fake_uid.value], + block_hash=None, ) - mocked_neuron_info.assert_called_once_with(fake_result) - assert result == "fake_neuron_info" @pytest.mark.asyncio @@ -1146,10 +1144,10 @@ async def test_get_neuron_for_pubkey_and_subnet_rpc_result_empty(subtensor, mock reuse_block_hash=False, ) subtensor.substrate.runtime_call.assert_called_once_with( - 
"NeuronInfoRuntimeApi", - "get_neuron", - [fake_netuid, fake_uid], - None, + api="NeuronInfoRuntimeApi", + method="get_neuron", + params=[fake_netuid, fake_uid], + block_hash=None, ) mocked_get_null_neuron.assert_called_once() assert result == "null_neuron" @@ -1350,10 +1348,10 @@ async def test_get_delegated_with_empty_result(subtensor, mocker): # Asserts mocked_runtime_call.assert_called_once_with( - "DelegateInfoRuntimeApi", - "get_delegated", - [fake_coldkey_ss58], - None, + api="DelegateInfoRuntimeApi", + method="get_delegated", + params=[fake_coldkey_ss58], + block_hash=None, ) assert result == [] diff --git a/tests/unit_tests/utils/test_retry.py b/tests/unit_tests/utils/test_retry.py new file mode 100644 index 0000000000..30a8d673ab --- /dev/null +++ b/tests/unit_tests/utils/test_retry.py @@ -0,0 +1,110 @@ +import pytest +from unittest.mock import Mock, patch, AsyncMock +from bittensor.utils.retry import retry_call, async_retry_call + +# Create custom exception for testing +class NetworkError(Exception): + pass + +class NonRetryableError(Exception): + pass + +@pytest.fixture +def mock_sleep(): + with patch("time.sleep") as m: + yield m + +@pytest.fixture +def mock_async_sleep(): + with patch("asyncio.sleep", new_callable=AsyncMock) as m: + yield m + +@pytest.fixture +def enable_retries(): + # Patch environment variables + with patch.dict("os.environ", {"BT_RETRY_ENABLED": "True"}): + yield + +@pytest.fixture +def disable_retries(): + # Patch environment variables + with patch.dict("os.environ", {"BT_RETRY_ENABLED": "False"}): + yield + +# --- Sync Tests --- + +def test_sync_retry_success(enable_retries): + mock_func = Mock(return_value="success") + result = retry_call(mock_func, retry_exceptions=(NetworkError,), max_attempts=3) + assert result == "success" + assert mock_func.call_count == 1 + +def test_sync_retry_eventual_success(enable_retries, mock_sleep): + mock_func = Mock(side_effect=[NetworkError("Fail 1"), NetworkError("Fail 2"), "success"]) + result = retry_call(mock_func, retry_exceptions=(NetworkError,), max_attempts=3) + assert result == "success" + assert mock_func.call_count == 3 + +def test_sync_retry_exhaustion(enable_retries, mock_sleep): + mock_func = Mock(side_effect=NetworkError("Persistent Fail")) + with pytest.raises(NetworkError, match="Persistent Fail"): + retry_call(mock_func, retry_exceptions=(NetworkError,), max_attempts=3) + assert mock_func.call_count == 3 + +def test_sync_no_retry_on_wrong_exception(enable_retries): + mock_func = Mock(side_effect=NonRetryableError("Fatal")) + with pytest.raises(NonRetryableError, match="Fatal"): + retry_call(mock_func, retry_exceptions=(NetworkError,), max_attempts=3) + assert mock_func.call_count == 1 + +def test_sync_disabled_retries_executes_once(disable_retries): + mock_func = Mock(side_effect=NetworkError("Fail")) + with pytest.raises(NetworkError, match="Fail"): + retry_call(mock_func, retry_exceptions=(NetworkError,), max_attempts=3) + assert mock_func.call_count == 1 + +def test_sync_default_retry_exceptions_do_not_retry_non_network_error(enable_retries): + mock_func = Mock(side_effect=ValueError("bad input")) + with pytest.raises(ValueError, match="bad input"): + # Should raise immediately because ValueError is not in (OSError, TimeoutError) + retry_call(mock_func) + assert mock_func.call_count == 1 + + +# --- Async Tests --- + +@pytest.mark.asyncio +async def test_async_retry_success(enable_retries): + mock_func = AsyncMock(return_value="success") + result = await async_retry_call(mock_func, 
retry_exceptions=(NetworkError,), max_attempts=3) + assert result == "success" + assert mock_func.call_count == 1 + +@pytest.mark.asyncio +async def test_async_retry_eventual_success(enable_retries, mock_async_sleep): + mock_func = AsyncMock(side_effect=[NetworkError("Fail 1"), NetworkError("Fail 2"), "success"]) + result = await async_retry_call(mock_func, retry_exceptions=(NetworkError,), max_attempts=3) + assert result == "success" + assert mock_func.call_count == 3 + +@pytest.mark.asyncio +async def test_async_retry_exhaustion(enable_retries, mock_async_sleep): + mock_func = AsyncMock(side_effect=NetworkError("Persistent Fail")) + with pytest.raises(NetworkError, match="Persistent Fail"): + await async_retry_call(mock_func, retry_exceptions=(NetworkError,), max_attempts=3) + assert mock_func.call_count == 3 + +@pytest.mark.asyncio +async def test_async_no_retry_on_wrong_exception(enable_retries): + mock_func = AsyncMock(side_effect=NonRetryableError("Fatal")) + with pytest.raises(NonRetryableError, match="Fatal"): + await async_retry_call(mock_func, retry_exceptions=(NetworkError,), max_attempts=3) + assert mock_func.call_count == 1 + +@pytest.mark.asyncio +async def test_async_disabled_retries_executes_once(disable_retries): + mock_func = AsyncMock(side_effect=NetworkError("Fail")) + with pytest.raises(NetworkError, match="Fail"): + await async_retry_call(mock_func, retry_exceptions=(NetworkError,), max_attempts=3) + assert mock_func.call_count == 1 +
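
For reference, a minimal usage sketch of the retry helpers introduced in bittensor/utils/retry.py above. The fetch_block_hash functions below are hypothetical placeholders for any network-bound call; the environment variable names and the retry_call / async_retry_call signatures are taken from the module as added in this diff. Because the module reads its configuration at call time, the environment variables can be set or changed after import.

import asyncio
import os

from bittensor.utils.retry import retry_call, async_retry_call

# Retries are opt-in: without BT_RETRY_ENABLED the wrapped callable runs exactly once.
os.environ["BT_RETRY_ENABLED"] = "true"
os.environ["BT_RETRY_MAX_ATTEMPTS"] = "5"   # overrides the default of 3 attempts
os.environ["BT_RETRY_BASE_DELAY"] = "0.5"   # seconds; scaled by BT_RETRY_BACKOFF_FACTOR per attempt


def fetch_block_hash(block: int) -> str:
    # Hypothetical synchronous call that may raise OSError or TimeoutError.
    return f"0x{block:064x}"


async def fetch_block_hash_async(block: int) -> str:
    # Hypothetical asynchronous counterpart.
    return f"0x{block:064x}"


# Synchronous wrapper: positional and keyword arguments after `func` are forwarded
# to it on every attempt; only OSError/TimeoutError are retried unless
# retry_exceptions is overridden.
block_hash = retry_call(fetch_block_hash, 100, retry_exceptions=(OSError, TimeoutError))

# Asynchronous wrapper: must be awaited and requires an async callable.
# Per-call keyword arguments (max_attempts here) take precedence over the env vars.
block_hash = asyncio.run(async_retry_call(fetch_block_hash_async, 100, max_attempts=2))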