Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion .github/workflows/ci.yml
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@ jobs:
python-version: ${{ matrix.python-version }}
- name: Install Hatch
run: |
python -m pip install --upgrade hatch
python -m pip install hatch==1.15.0
- name: static analysis
run: hatch fmt --check
- name: type checking
Expand Down
20 changes: 16 additions & 4 deletions src/aws_durable_execution_sdk_python/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -464,23 +464,35 @@ class JitterStrategy(StrEnum):

Jitter is meant to be used to spread operations across time.

Based on AWS Architecture Blog: https://aws.amazon.com/blogs/architecture/exponential-backoff-and-jitter/

members:
:NONE: No jitter; use the exact calculated delay
:FULL: Full jitter; random delay between 0 and calculated delay
:HALF: Half jitter; random delay between 0.5x and 1.0x of the calculated delay
:HALF: Equal jitter; random delay between 0.5x and 1.0x of the calculated delay
"""

NONE = "NONE"
FULL = "FULL"
HALF = "HALF"

def compute_jitter(self, delay) -> float:
def apply_jitter(self, delay: float) -> float:
"""Apply jitter to a delay value and return the final delay.

Args:
delay: The base delay value to apply jitter to

Returns:
The final delay after applying jitter strategy
"""
match self:
case JitterStrategy.NONE:
return 0
return delay
case JitterStrategy.HALF:
return delay * (random.random() * 0.5 + 0.5) # noqa: S311
# Equal jitter: delay/2 + random(0, delay/2)
return delay / 2 + random.random() * (delay / 2) # noqa: S311
case _: # default is FULL
# Full jitter: random(0, delay)
return random.random() * delay # noqa: S311


Expand Down
42 changes: 28 additions & 14 deletions src/aws_durable_execution_sdk_python/retries.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,9 @@

Numeric = int | float

# Default pattern that matches all error messages
_DEFAULT_RETRYABLE_ERROR_PATTERN = re.compile(r".*")


@dataclass
class RetryDecision:
Expand Down Expand Up @@ -47,10 +50,8 @@ class RetryStrategyConfig:
) # 5 minutes
backoff_rate: Numeric = 2.0
jitter_strategy: JitterStrategy = field(default=JitterStrategy.FULL)
retryable_errors: list[str | re.Pattern] = field(
default_factory=lambda: [re.compile(r".*")]
)
retryable_error_types: list[type[Exception]] = field(default_factory=list)
retryable_errors: list[str | re.Pattern] | None = None
retryable_error_types: list[type[Exception]] | None = None

@property
def initial_delay_seconds(self) -> int:
Expand All @@ -64,42 +65,55 @@ def max_delay_seconds(self) -> int:


def create_retry_strategy(
config: RetryStrategyConfig,
config: RetryStrategyConfig | None = None,
) -> Callable[[Exception, int], RetryDecision]:
if config is None:
config = RetryStrategyConfig()

# Apply default retryableErrors only if user didn't specify either filter
should_use_default_errors: bool = (
config.retryable_errors is None and config.retryable_error_types is None
)

retryable_errors: list[str | re.Pattern] = (
config.retryable_errors
if config.retryable_errors is not None
else ([_DEFAULT_RETRYABLE_ERROR_PATTERN] if should_use_default_errors else [])
)
retryable_error_types: list[type[Exception]] = config.retryable_error_types or []

def retry_strategy(error: Exception, attempts_made: int) -> RetryDecision:
# Check if we've exceeded max attempts
if attempts_made >= config.max_attempts:
return RetryDecision.no_retry()

# Check if error is retryable based on error message
is_retryable_error_message = any(
is_retryable_error_message: bool = any(
pattern.search(str(error))
if isinstance(pattern, re.Pattern)
else pattern in str(error)
for pattern in config.retryable_errors
for pattern in retryable_errors
)

# Check if error is retryable based on error type
is_retryable_error_type = any(
isinstance(error, error_type) for error_type in config.retryable_error_types
is_retryable_error_type: bool = any(
isinstance(error, error_type) for error_type in retryable_error_types
)

if not is_retryable_error_message and not is_retryable_error_type:
return RetryDecision.no_retry()

# Calculate delay with exponential backoff
delay = min(
base_delay: float = min(
config.initial_delay_seconds * (config.backoff_rate ** (attempts_made - 1)),
config.max_delay_seconds,
)
delay_with_jitter = delay + config.jitter_strategy.compute_jitter(delay)
delay_with_jitter = math.ceil(delay_with_jitter)
final_delay = max(1, delay_with_jitter)
# Apply jitter to get final delay
delay_with_jitter: float = config.jitter_strategy.apply_jitter(base_delay)
# Round up and ensure minimum of 1 second
final_delay: int = max(1, math.ceil(delay_with_jitter))

return RetryDecision.retry(Duration(seconds=round(final_delay)))
return RetryDecision.retry(Duration(seconds=final_delay))

return retry_strategy

Expand Down
12 changes: 6 additions & 6 deletions src/aws_durable_execution_sdk_python/waits.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

from __future__ import annotations

import math
from dataclasses import dataclass, field
from typing import TYPE_CHECKING, Generic

Expand Down Expand Up @@ -81,17 +82,16 @@ def wait_strategy(result: T, attempts_made: int) -> WaitDecision:
return WaitDecision.no_wait()

# Calculate delay with exponential backoff
base_delay = min(
base_delay: float = min(
config.initial_delay_seconds * (config.backoff_rate ** (attempts_made - 1)),
config.max_delay_seconds,
)

# Apply jitter (add jitter to base delay)
jitter = config.jitter_strategy.compute_jitter(base_delay)
delay_with_jitter = base_delay + jitter
# Apply jitter to get final delay
delay_with_jitter: float = config.jitter_strategy.apply_jitter(base_delay)

# Ensure delay is an integer >= 1
final_delay = max(1, round(delay_with_jitter))
# Round up and ensure minimum of 1 second
final_delay: int = max(1, math.ceil(delay_with_jitter))

return WaitDecision.wait(Duration(seconds=final_delay))

Expand Down
Loading
Loading