Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -129,6 +129,7 @@ ignore = [
# Conflicted rules
"D203", # with D211
"D212", # with D213
"COM812", # with formatter
]

[tool.ruff.lint.per-file-ignores]
Expand All @@ -146,6 +147,8 @@ ignore = [
"S608",

"RUF",

"PLR2004", # magic numbers in tests
]
"tests/test_linting.py" = [
"S603", # subprocess usage
Expand Down
10 changes: 10 additions & 0 deletions src/taskiq_pg/_internal/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
from taskiq_pg._internal.broker import BasePostgresBroker
from taskiq_pg._internal.result_backend import BasePostgresResultBackend
from taskiq_pg._internal.schedule_source import BasePostgresScheduleSource


# Public re-exports: driver-agnostic base classes shared by the
# asyncpg / aiopg / psqlpy implementations.
__all__ = [
    "BasePostgresBroker",
    "BasePostgresResultBackend",
    "BasePostgresScheduleSource",
]
49 changes: 49 additions & 0 deletions src/taskiq_pg/_internal/schedule_source.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,49 @@
from __future__ import annotations

import typing as tp

from taskiq import ScheduleSource


if tp.TYPE_CHECKING:
from taskiq.abc.broker import AsyncBroker


class BasePostgresScheduleSource(ScheduleSource):
    """
    Driver-agnostic base for PostgreSQL schedule sources.

    Holds the connection settings shared by the concrete implementations
    (asyncpg, aiopg, psqlpy); subclasses provide the actual pool handling.
    """

    def __init__(
        self,
        broker: AsyncBroker,
        dsn: str | tp.Callable[[], str] = "postgresql://postgres:postgres@localhost:5432/postgres",
        table_name: str = "taskiq_schedules",
        **connect_kwargs: tp.Any,
    ) -> None:
        """
        Initialize the PostgreSQL scheduler source.

        Sets up a scheduler source that stores scheduled tasks in a PostgreSQL database.
        This scheduler source manages task schedules, allowing for persistent storage and retrieval of scheduled tasks
        across application restarts.

        Args:
            broker: The TaskIQ broker instance to use for finding and managing tasks.
            dsn: PostgreSQL connection string, or a zero-argument callable returning
                one (useful when credentials are rotated at runtime).
            table_name: Name of the table to store scheduled tasks. Will be created automatically if it doesn't exist.
            **connect_kwargs: Additional keyword arguments passed to the database connection pool.

        """
        self._broker: tp.Final = broker
        self._dsn: tp.Final = dsn
        self._table_name: tp.Final = table_name
        self._connect_kwargs: tp.Final = connect_kwargs

    @property
    def dsn(self) -> str:
        """
        Get the DSN string.

        The callable form is re-evaluated on every access, so rotated
        credentials are picked up without recreating the source.
        """
        if callable(self._dsn):
            return self._dsn()
        return self._dsn
2 changes: 2 additions & 0 deletions src/taskiq_pg/aiopg/__init__.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
from taskiq_pg.aiopg.result_backend import AiopgResultBackend
from taskiq_pg.aiopg.schedule_source import AiopgScheduleSource


# Public API of the aiopg backend.
__all__ = [
    "AiopgResultBackend",
    "AiopgScheduleSource",
]
Empty file added src/taskiq_pg/aiopg/broker.py
Empty file.
28 changes: 28 additions & 0 deletions src/taskiq_pg/aiopg/queries.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,3 +29,31 @@
DELETE_RESULT_QUERY = """
DELETE FROM {} WHERE task_id = %s
"""

CREATE_SCHEDULES_TABLE_QUERY = """
CREATE TABLE IF NOT EXISTS {} (
id UUID PRIMARY KEY,
task_name VARCHAR(100) NOT NULL,
schedule JSONB NOT NULL,
created_at TIMESTAMP WITH TIME ZONE DEFAULT NOW(),
updated_at TIMESTAMP WITH TIME ZONE DEFAULT NOW()
);
"""

INSERT_SCHEDULE_QUERY = """
INSERT INTO {} (id, task_name, schedule)
VALUES (%s, %s, %s)
ON CONFLICT (id) DO UPDATE
SET task_name = EXCLUDED.task_name,
schedule = EXCLUDED.schedule,
updated_at = NOW();
"""

SELECT_SCHEDULES_QUERY = """
SELECT id, task_name, schedule
FROM {};
"""

DELETE_ALL_SCHEDULES_QUERY = """
DELETE FROM {};
"""
124 changes: 124 additions & 0 deletions src/taskiq_pg/aiopg/schedule_source.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,124 @@
import uuid
from logging import getLogger

from aiopg import Pool, create_pool
from pydantic import ValidationError
from taskiq import ScheduledTask

from taskiq_pg import exceptions
from taskiq_pg._internal import BasePostgresScheduleSource
from taskiq_pg.aiopg.queries import (
CREATE_SCHEDULES_TABLE_QUERY,
DELETE_ALL_SCHEDULES_QUERY,
INSERT_SCHEDULE_QUERY,
SELECT_SCHEDULES_QUERY,
)


logger = getLogger("taskiq_pg.aiopg_schedule_source")


class AiopgScheduleSource(BasePostgresScheduleSource):
    """Schedule source that uses aiopg to store schedules in PostgreSQL."""

    # Created in startup(); absent until then — hence the getattr() guard
    # in shutdown().
    _database_pool: Pool

    async def _update_schedules_on_startup(self, schedules: list[ScheduledTask]) -> None:
        """Update schedules in the database on startup: truncate table and insert new ones."""
        async with self._database_pool.acquire() as connection, connection.cursor() as cursor:
            await cursor.execute(DELETE_ALL_SCHEDULES_QUERY.format(self._table_name))
            for schedule in schedules:
                await cursor.execute(
                    INSERT_SCHEDULE_QUERY.format(self._table_name),
                    [
                        schedule.schedule_id,
                        schedule.task_name,
                        # Everything except the key columns goes into the
                        # JSONB `schedule` payload.
                        schedule.model_dump_json(
                            exclude={"schedule_id", "task_name"},
                        ),
                    ],
                )

    def _get_schedules_from_broker_tasks(self) -> list[ScheduledTask]:
        """Extract schedules from the broker's registered tasks."""
        scheduled_tasks_for_creation: list[ScheduledTask] = []
        for task_name, task in self._broker.get_all_tasks().items():
            if "schedule" not in task.labels:
                logger.debug("Task %s has no schedule, skipping", task_name)
                continue
            if not isinstance(task.labels["schedule"], list):
                logger.warning(
                    "Schedule for task %s is not a list, skipping",
                    task_name,
                )
                continue
            for schedule in task.labels["schedule"]:
                try:
                    new_schedule = ScheduledTask.model_validate(
                        {
                            "task_name": task_name,
                            "labels": schedule.get("labels", {}),
                            "args": schedule.get("args", []),
                            "kwargs": schedule.get("kwargs", {}),
                            # Each label entry becomes its own schedule row.
                            "schedule_id": str(uuid.uuid4()),
                            "cron": schedule.get("cron", None),
                            "cron_offset": schedule.get("cron_offset", None),
                            "time": schedule.get("time", None),
                        },
                    )
                    scheduled_tasks_for_creation.append(new_schedule)
                except ValidationError:
                    # One malformed entry must not prevent the valid
                    # schedules from being loaded.
                    logger.exception(
                        "Schedule for task %s is not valid, skipping",
                        task_name,
                    )
                    continue
        return scheduled_tasks_for_creation

    async def startup(self) -> None:
        """
        Initialize the schedule source.

        Construct new connection pool, create new table for schedules if not exists
        and fill table with schedules from task labels.

        Raises:
            exceptions.DatabaseConnectionError: If the pool cannot be created
                or any of the initialization queries fail.
        """
        try:
            self._database_pool = await create_pool(
                dsn=self.dsn,
                **self._connect_kwargs,
            )
            async with self._database_pool.acquire() as connection, connection.cursor() as cursor:
                await cursor.execute(CREATE_SCHEDULES_TABLE_QUERY.format(self._table_name))
            scheduled_tasks_for_creation = self._get_schedules_from_broker_tasks()
            await self._update_schedules_on_startup(scheduled_tasks_for_creation)
        except Exception as error:
            raise exceptions.DatabaseConnectionError(str(error)) from error

    async def shutdown(self) -> None:
        """Close the connection pool and wait for its connections to terminate."""
        pool = getattr(self, "_database_pool", None)
        if pool is not None:
            pool.close()
            # aiopg's close() only marks the pool as closing; wait_closed()
            # must be awaited for the connections to actually be released.
            await pool.wait_closed()

    async def get_schedules(self) -> list["ScheduledTask"]:
        """Fetch schedules from the database."""
        async with self._database_pool.acquire() as connection, connection.cursor() as cursor:
            await cursor.execute(
                SELECT_SCHEDULES_QUERY.format(self._table_name),
            )
            schedules, rows = [], await cursor.fetchall()
            # Each row: (UUID id, str task_name, dict schedule-JSONB payload).
            for schedule_id, task_name, schedule in rows:
                schedules.append(
                    ScheduledTask.model_validate(
                        {
                            "schedule_id": str(schedule_id),
                            "task_name": task_name,
                            "labels": schedule["labels"],
                            "args": schedule["args"],
                            "kwargs": schedule["kwargs"],
                            "cron": schedule["cron"],
                            "cron_offset": schedule["cron_offset"],
                            "time": schedule["time"],
                        },
                    ),
                )
            return schedules
51 changes: 5 additions & 46 deletions src/taskiq_pg/asyncpg/schedule_source.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,12 @@
import json
import typing as tp
import uuid
from logging import getLogger

import asyncpg
from pydantic import ValidationError
from taskiq import ScheduledTask, ScheduleSource
from taskiq.abc.broker import AsyncBroker
from taskiq import ScheduledTask

from taskiq_pg._internal import BasePostgresScheduleSource
from taskiq_pg.asyncpg.queries import (
CREATE_SCHEDULES_TABLE_QUERY,
DELETE_ALL_SCHEDULES_QUERY,
Expand All @@ -19,57 +18,16 @@
logger = getLogger("taskiq_pg.asyncpg_schedule_source")


class AsyncpgScheduleSource(ScheduleSource):
class AsyncpgScheduleSource(BasePostgresScheduleSource):
"""Schedule source that uses asyncpg to store schedules in PostgreSQL."""

_database_pool: "asyncpg.Pool[asyncpg.Record]"

def __init__(
self,
broker: AsyncBroker,
dsn: str | tp.Callable[[], str] = "postgresql://postgres:postgres@localhost:5432/postgres",
table_name: str = "taskiq_schedules",
**connect_kwargs: tp.Any,
) -> None:
"""
Initialize the PostgreSQL scheduler source.

Sets up a scheduler source that stores scheduled tasks in a PostgreSQL database.
This scheduler source manages task schedules, allowing for persistent storage and retrieval of scheduled tasks
across application restarts.

Args:
dsn: PostgreSQL connection string
table_name: Name of the table to store scheduled tasks. Will be created automatically if it doesn't exist.
broker: The TaskIQ broker instance to use for finding and managing tasks.
Required if startup_schedule is provided.
**connect_kwargs: Additional keyword arguments passed to the database connection pool.

"""
self._broker: tp.Final = broker
self._dsn: tp.Final = dsn
self._table_name: tp.Final = table_name
self._connect_kwargs: tp.Final = connect_kwargs

@property
def dsn(self) -> str | None:
"""
Get the DSN string.

Returns the DSN string or None if not set.
"""
if callable(self._dsn):
return self._dsn()
return self._dsn

async def _update_schedules_on_startup(self, schedules: list[ScheduledTask]) -> None:
"""Update schedules in the database on startup: trancate table and insert new ones."""
"""Update schedules in the database on startup: truncate table and insert new ones."""
async with self._database_pool.acquire() as connection, connection.transaction():
await connection.execute(DELETE_ALL_SCHEDULES_QUERY.format(self._table_name))
for schedule in schedules:
schedule.model_dump_json(
exclude={"schedule_id", "task_name"},
)
await self._database_pool.execute(
INSERT_SCHEDULE_QUERY.format(self._table_name),
str(schedule.schedule_id),
Expand All @@ -91,6 +49,7 @@ def _get_schedules_from_broker_tasks(self) -> list[ScheduledTask]:
"Schedule for task %s is not a list, skipping",
task_name,
)
continue
for schedule in task.labels["schedule"]:
try:
new_schedule = ScheduledTask.model_validate(
Expand Down
2 changes: 2 additions & 0 deletions src/taskiq_pg/psqlpy/__init__.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,10 @@
from taskiq_pg.psqlpy.broker import PSQLPyBroker
from taskiq_pg.psqlpy.result_backend import PSQLPyResultBackend
from taskiq_pg.psqlpy.schedule_source import PSQLPyScheduleSource


# Public API of the psqlpy backend.
__all__ = [
    "PSQLPyBroker",
    "PSQLPyResultBackend",
    "PSQLPyScheduleSource",
]
28 changes: 28 additions & 0 deletions src/taskiq_pg/psqlpy/queries.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,3 +51,31 @@
CLAIM_MESSAGE_QUERY = "UPDATE {} SET status = 'processing' WHERE id = $1 AND status = 'pending' RETURNING *"

DELETE_MESSAGE_QUERY = "DELETE FROM {} WHERE id = $1"

CREATE_SCHEDULES_TABLE_QUERY = """
CREATE TABLE IF NOT EXISTS {} (
id UUID PRIMARY KEY,
task_name VARCHAR(100) NOT NULL,
schedule JSONB NOT NULL,
created_at TIMESTAMP WITH TIME ZONE DEFAULT NOW(),
updated_at TIMESTAMP WITH TIME ZONE DEFAULT NOW()
);
"""

INSERT_SCHEDULE_QUERY = """
INSERT INTO {} (id, task_name, schedule)
VALUES ($1, $2, $3)
ON CONFLICT (id) DO UPDATE
SET task_name = EXCLUDED.task_name,
schedule = EXCLUDED.schedule,
updated_at = NOW();
"""

SELECT_SCHEDULES_QUERY = """
SELECT id, task_name, schedule
FROM {};
"""

DELETE_ALL_SCHEDULES_QUERY = """
DELETE FROM {};
"""
Loading