Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions .env.example
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
AWS_ENDPOINT_URL=http://localhost:4566
AWS_ACCESS_KEY_ID=test
AWS_SECRET_ACCESS_KEY=test
AWS_DEFAULT_REGION=us-east-1
16 changes: 15 additions & 1 deletion .github/workflows/code-check.yml
Original file line number Diff line number Diff line change
Expand Up @@ -18,14 +18,28 @@ permissions:

env:
AWS_DEFAULT_REGION: "us-east-1"
AWS_ENDPOINT_URL: "http://localhost:4566"
AWS_ACCESS_KEY_ID: "test"
AWS_SECRET_ACCESS_KEY: "test" # pragma: allowlist secret

jobs:
test:
runs-on: ${{ matrix.os }}
strategy:
matrix:
os: [ubuntu-latest, macos-latest] # eventually add `windows-latest`
os: [ubuntu-latest] # eventually add `windows-latest` and `macos-latest`
python-version: ["3.10", "3.11", "3.12"]
services:
ministack:
image: ministackorg/ministack:1.3.53
ports:
- 4566:4566
env:
AWS_DEFAULT_REGION: us-east-1
GATEWAY_PORT: 4566
MINISTACK_ACCOUNT_ID: "000000000000"
MINISTACK_REGION: us-east-1
LOG_LEVEL: INFO
steps:
- uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2
with:
Expand Down
43 changes: 43 additions & 0 deletions docker-compose.yml
Original file line number Diff line number Diff line change
@@ -0,0 +1,43 @@
services:
ministack:
image: ministackorg/ministack:1.3.53
container_name: taskiq_sqs_ministack
ports:
- "4566:4566"
environment:
AWS_DEFAULT_REGION: us-east-1
GATEWAY_PORT: 4566
MINISTACK_ACCOUNT_ID: "000000000000"
MINISTACK_REGION: us-east-1
LOG_LEVEL: INFO
PERSIST_STATE: "1"
healthcheck:
test: ["CMD", "curl", "-f", "http://localhost:4566/_ministack/health"]
interval: 2s
timeout: 5s
retries: 30
start_period: 5s
volumes:
- ministack-data:/tmp/ministack-state
networks:
- taskiq-sqs-network

redis:
image: bitnamilegacy/redis:7.4.2
environment:
ALLOW_EMPTY_PASSWORD: "yes" # pragma: allowlist secret
healthcheck:
test: ["CMD", "redis-cli", "ping"]
interval: 5s
timeout: 5s
retries: 3
start_period: 10s
ports:
- 6379:6379

networks:
taskiq-sqs-network:
driver: bridge

volumes:
ministack-data:
46 changes: 46 additions & 0 deletions examples/example_broker.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,46 @@
"""
Run worker:
taskiq worker examples.example_broker:broker

Run broker to send a task:
python examples/example_broker.py
"""

import asyncio

import boto3
import dotenv
from taskiq_redis import RedisAsyncResultBackend

from taskiq_sqs import SQSBroker


dotenv.load_dotenv()

QUEUE_NAME = "my-queue"
QUEUE_URL = f"http://localhost:4566/000000000000/{QUEUE_NAME}"


boto3.client("sqs").create_queue(QueueName=QUEUE_NAME)

broker = SQSBroker(QUEUE_URL, sqs_region_override="us-east-1").with_result_backend(
RedisAsyncResultBackend(redis_url="redis://localhost:6379")
)


@broker.task()
async def i_love_aws() -> None:
"""I hope my cloud bill doesn't get too high!"""
await asyncio.sleep(5.5)
print("Hello there!")


async def main() -> None:
await broker.startup()
task = await i_love_aws.kiq()
print(await task.wait_result())
await broker.shutdown()


if __name__ == "__main__":
asyncio.run(main())
11 changes: 9 additions & 2 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,6 @@ dev = [
test = [
"pytest>=9.0.3",
"pytest-asyncio>=0.23.8",
"requests>=2.34.2",
]
lint = [
"bandit>=1.9.4",
Expand All @@ -60,9 +59,12 @@ lint = [
types = [
"mypy>=2.1.0",
"mypy-boto3-sqs>=1.34.101",
"types-requests>=2.33.0.20260518",
"boto3-stubs[essential]>=1.34.84",
]
examples = [
"python-dotenv>=1.2.2",
"taskiq-redis>=1.2.2",
]


[build-system]
Expand Down Expand Up @@ -138,6 +140,11 @@ ignore = [
"S101", # assert usage
"SLF001", # private member accessed
]
"examples/*" = [
"T201", # print
"D",
"INP001",
]

[tool.ruff.lint.pydocstyle]
convention = "google"
Expand Down
29 changes: 12 additions & 17 deletions src/taskiq_sqs/broker.py
Original file line number Diff line number Diff line change
@@ -1,28 +1,24 @@
from __future__ import (
annotations, # Needed for conditional type import support
)

import asyncio
import logging
from collections import defaultdict
from collections.abc import AsyncGenerator, Callable, Mapping
from datetime import datetime, timezone
from typing import TYPE_CHECKING

import boto3
from asyncer import asyncify
from botocore.exceptions import ClientError
from taskiq import AsyncBroker
from taskiq.abc.result_backend import AsyncResultBackend
from taskiq.acks import AckableMessage
from taskiq.message import BrokerMessage

from taskiq_sqs.aws import get_container_credentials


if TYPE_CHECKING:
from collections.abc import AsyncGenerator, Callable, Mapping

from mypy_boto3_sqs.service_resource import Queue, SQSServiceResource
from taskiq.abc.result_backend import AsyncResultBackend
from taskiq.acks import AckableMessage
from taskiq.message import BrokerMessage


logger = logging.getLogger(__name__)

Expand Down Expand Up @@ -51,7 +47,7 @@ def __init__( # noqa: D107
super().__init__(result_backend, task_id_generator)

if not sqs_queue_url or not sqs_queue_url.startswith("http"):
raise SQSBrokerError("A valid SQS Queue URL is required")
raise SQSBrokerError("A valid SQS queue url is required")

# NOTE: This bypasses the normal order of operations for boto3 auth and
# goes straight to using the ECS role creds from the metadata
Expand All @@ -62,6 +58,7 @@ def __init__( # noqa: D107
self.sqs_queue_url = sqs_queue_url
self._sqs: SQSServiceResource | None = None
self._sqs_queue: Queue | None = None
self._creds_expiration: datetime | None = None

if max_number_of_messages > 10: # noqa: PLR2004
raise SQSBrokerError("MaxNumberOfMessages can be no greater than 10")
Expand All @@ -70,12 +67,10 @@ def __init__( # noqa: D107
self.max_number_of_messages = max(max_number_of_messages, 1)

@property
def _sqs_credentials_expired(self) -> datetime | bool:
return self._creds_expiration and self._creds_expiration < datetime.now(
tz=timezone.utc,
)
def _sqs_credentials_expired(self) -> datetime | bool | None:
return self._creds_expiration and self._creds_expiration < datetime.now(tz=timezone.utc)

async def _sqs_client(self) -> SQSServiceResource:
async def _sqs_client(self) -> "SQSServiceResource":
if self._sqs and not self._sqs_credentials_expired:
return self._sqs

Expand All @@ -95,7 +90,7 @@ async def _sqs_client(self) -> SQSServiceResource:
aws_session_token=creds.get("Token"),
)

async def _get_queue(self) -> Queue:
async def _get_queue(self) -> "Queue":
if self._sqs_queue and not self._sqs_credentials_expired:
return self._sqs_queue

Expand All @@ -105,7 +100,7 @@ async def _get_queue(self) -> Queue:
)

if not self._sqs_queue:
exc_message = "SQS Queue not found"
exc_message = "SQS queue not found"
raise Exception(exc_message) # noqa: TRY002

return self._sqs_queue
Expand Down
16 changes: 7 additions & 9 deletions tests/test_broker.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,10 @@
from taskiq_sqs import SQSBroker


def test_init() -> None:
broker = SQSBroker("https://sqs.us-west-2.amazonaws.com/123456789012/queue-name")
assert (
broker.sqs_queue_url
== "https://sqs.us-west-2.amazonaws.com/123456789012/queue-name"
)
assert broker.force_ecs_container_credentials is False
assert broker.sqs_region_override is None
assert broker._sqs_queue is None
class TestInitParameters:
async def test_initialization_logic(self) -> None:
broker = SQSBroker("http://localhost:4566/000000000000/my-queue")
assert broker.sqs_queue_url == "http://localhost:4566/000000000000/my-queue"
assert broker.force_ecs_container_credentials is False
assert broker.sqs_region_override is None
assert broker._sqs_queue is None
42 changes: 42 additions & 0 deletions uv.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.