From 91ce7defc7559ac0f05dc3c4bde03c68114e7546 Mon Sep 17 00:00:00 2001 From: Anfimov Dima Date: Fri, 17 Oct 2025 18:23:31 +0200 Subject: [PATCH] feat: add optional methods for schedule sources --- Makefile | 4 + docs/tutorial/schedule_source.md | 79 +++++++++++++++++++ pyproject.toml | 8 +- src/taskiq_pg/_internal/broker.py | 2 +- src/taskiq_pg/_internal/schedule_source.py | 48 ++++++++++++ src/taskiq_pg/aiopg/queries.py | 4 + src/taskiq_pg/aiopg/schedule_source.py | 80 ++++++++++---------- src/taskiq_pg/asyncpg/queries.py | 4 + src/taskiq_pg/asyncpg/schedule_source.py | 76 +++++++++---------- src/taskiq_pg/psqlpy/queries.py | 4 + src/taskiq_pg/psqlpy/schedule_source.py | 81 ++++++++++---------- src/taskiq_pg/psycopg/queries.py | 4 + src/taskiq_pg/psycopg/schedule_source.py | 88 +++++++++++----------- tests/integration/test_schedule_source.py | 86 ++++++++++++++++++++- uv.lock | 29 +++++++ 15 files changed, 433 insertions(+), 164 deletions(-) diff --git a/Makefile b/Makefile index 6e95d9d..74328c1 100644 --- a/Makefile +++ b/Makefile @@ -29,6 +29,10 @@ init: ## Install all project dependencies with extras @$(MAKE) check_venv @uv sync --all-extras +.PHONY: run_docs +run_docs: ## Run documentation server + @uv run mkdocs serve --livereload + .PHONY: run_infra run_infra: ## Run rabbitmq in docker for integration tests @docker compose -f docker-compose.yml up -d diff --git a/docs/tutorial/schedule_source.md b/docs/tutorial/schedule_source.md index 7df5efd..b1131e6 100644 --- a/docs/tutorial/schedule_source.md +++ b/docs/tutorial/schedule_source.md @@ -2,6 +2,85 @@ title: Schedule Source --- +## Basic usage + +The easiest way to schedule task with this library is to add `schedule` label to task. Schedule source will automatically +parse this label and add new schedule to database on start of scheduler. + +You can define your scheduled task like this: + + ```python + import asyncio + from taskiq import TaskiqScheduler + from taskiq_pg.asyncpg import AsyncpgBroker, AsyncpgScheduleSource + + + dsn = "postgres://taskiq_postgres:look_in_vault@localhost:5432/taskiq_postgres" + broker = AsyncpgBroker(dsn) + scheduler = TaskiqScheduler( + broker=broker, + sources=[AsyncpgScheduleSource( + dsn=dsn, + broker=broker, + )], + ) + + + @broker.task( + task_name="solve_all_problems", + schedule=[ + { + "cron": "*/1 * * * *", # type: str, either cron or time should be specified. + "cron_offset": None, # type: str | None, can be omitted. For example "Europe/Berlin". + "time": None, # type: datetime | None, either cron or time should be specified. + "args": [], # type list[Any] | None, can be omitted. + "kwargs": {}, # type: dict[str, Any] | None, can be omitted. + "labels": {}, # type: dict[str, Any] | None, can be omitted. + }, + ], + ) + async def best_task_ever() -> None: + """Solve all problems in the world.""" + await asyncio.sleep(2) + print("All problems are solved!") + ``` + + +## Adding schedule in runtime + +You can also add schedules in runtime using `add_schedule` method of the schedule source: + + + ```python + import asyncio + from taskiq import TaskiqScheduler, ScheduledTask + from taskiq_pg.asyncpg import AsyncpgBroker, AsyncpgScheduleSource + + + dsn = "postgres://taskiq_postgres:look_in_vault@localhost:5432/taskiq_postgres" + broker = AsyncpgBroker(dsn) + schedule_source = AsyncpgScheduleSource( + dsn=dsn, + broker=broker, + ) + scheduler = TaskiqScheduler( + broker=broker, + sources=[schedule_source], + ) + + + @broker.task( + task_name="solve_all_problems", + ) + async def best_task_ever() -> None: + """Solve all problems in the world.""" + await asyncio.sleep(2) + print("All problems are solved!") + + # Call this function somewhere in your code to add new schedule + async def add_new_schedule() -> None: + await schedule_source.add_schedule(ScheduledTask(...)) + ``` ## Using multiple schedules diff --git a/pyproject.toml b/pyproject.toml index e67e8b4..61dc52a 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -69,6 +69,7 @@ lint = [ "asyncpg-stubs>=0.30.2", ] test = [ + "polyfactory>=2.22.2", "pytest>=8.4.2", "pytest-asyncio>=1.1.0", "pytest-cov>=7.0.0", @@ -125,10 +126,8 @@ select = ["ALL"] ignore = [ # TODO: enable this rules "TRY301", - "PLR0913", - "D401", "ANN401", - "PERF203", + # "PERF203", # boolean args @@ -174,6 +173,9 @@ ignore = [ "INP001", ] +[tool.ruff.lint.pydocstyle] +convention = "google" + [tool.ruff.lint.isort] known-local-folder = ["taskiq_pg"] lines-after-imports = 2 diff --git a/src/taskiq_pg/_internal/broker.py b/src/taskiq_pg/_internal/broker.py index 1e4297a..ed1afab 100644 --- a/src/taskiq_pg/_internal/broker.py +++ b/src/taskiq_pg/_internal/broker.py @@ -16,7 +16,7 @@ class BasePostgresBroker(AsyncBroker, abc.ABC): """Base class for Postgres brokers.""" - def __init__( + def __init__( # noqa: PLR0913 self, dsn: str | tp.Callable[[], str] = "postgresql://postgres:postgres@localhost:5432/postgres", result_backend: AsyncResultBackend[_T] | None = None, diff --git a/src/taskiq_pg/_internal/schedule_source.py b/src/taskiq_pg/_internal/schedule_source.py index 02ab4c6..d1c2b61 100644 --- a/src/taskiq_pg/_internal/schedule_source.py +++ b/src/taskiq_pg/_internal/schedule_source.py @@ -1,14 +1,21 @@ from __future__ import annotations import typing as tp +import uuid +from logging import getLogger +from pydantic import ValidationError from taskiq import ScheduleSource +from taskiq.scheduler.scheduled_task import ScheduledTask if tp.TYPE_CHECKING: from taskiq.abc.broker import AsyncBroker +logger = getLogger("taskiq_pg") + + class BasePostgresScheduleSource(ScheduleSource): def __init__( self, @@ -47,3 +54,44 @@ def dsn(self) -> str | None: if callable(self._dsn): return self._dsn() return self._dsn + + def extract_scheduled_tasks_from_broker(self) -> list[ScheduledTask]: + """ + Extract schedules from tasks that were registered in broker. + + Returns: + A list of ScheduledTask instances extracted from the task's labels. + """ + scheduled_tasks_for_creation: list[ScheduledTask] = [] + for task_name, task in self._broker.get_all_tasks().items(): + if "schedule" not in task.labels: + logger.debug("Task %s has no schedule, skipping", task_name) + continue + if not isinstance(task.labels["schedule"], list): + logger.warning( + "Schedule for task %s is not a list, skipping", + task_name, + ) + continue + for schedule in task.labels["schedule"]: + try: + new_schedule = ScheduledTask.model_validate( + { + "task_name": task_name, + "labels": schedule.get("labels", {}), + "args": schedule.get("args", []), + "kwargs": schedule.get("kwargs", {}), + "schedule_id": str(uuid.uuid4()), + "cron": schedule.get("cron", None), + "cron_offset": schedule.get("cron_offset", None), + "time": schedule.get("time", None), + }, + ) + scheduled_tasks_for_creation.append(new_schedule) + except ValidationError: # noqa: PERF203 + logger.exception( + "Schedule for task %s is not valid, skipping", + task_name, + ) + continue + return scheduled_tasks_for_creation diff --git a/src/taskiq_pg/aiopg/queries.py b/src/taskiq_pg/aiopg/queries.py index 8868808..f94277b 100644 --- a/src/taskiq_pg/aiopg/queries.py +++ b/src/taskiq_pg/aiopg/queries.py @@ -57,3 +57,7 @@ DELETE_ALL_SCHEDULES_QUERY = """ DELETE FROM {}; """ + +DELETE_SCHEDULE_QUERY = """ +DELETE FROM {} WHERE id = %s; +""" diff --git a/src/taskiq_pg/aiopg/schedule_source.py b/src/taskiq_pg/aiopg/schedule_source.py index f8b4a1a..5e99494 100644 --- a/src/taskiq_pg/aiopg/schedule_source.py +++ b/src/taskiq_pg/aiopg/schedule_source.py @@ -1,8 +1,6 @@ -import uuid from logging import getLogger from aiopg import Pool, create_pool -from pydantic import ValidationError from taskiq import ScheduledTask from taskiq_pg import exceptions @@ -10,6 +8,7 @@ from taskiq_pg.aiopg.queries import ( CREATE_SCHEDULES_TABLE_QUERY, DELETE_ALL_SCHEDULES_QUERY, + DELETE_SCHEDULE_QUERY, INSERT_SCHEDULE_QUERY, SELECT_SCHEDULES_QUERY, ) @@ -39,42 +38,6 @@ async def _update_schedules_on_startup(self, schedules: list[ScheduledTask]) -> ], ) - def _get_schedules_from_broker_tasks(self) -> list[ScheduledTask]: - """Extract schedules from the broker's registered tasks.""" - scheduled_tasks_for_creation: list[ScheduledTask] = [] - for task_name, task in self._broker.get_all_tasks().items(): - if "schedule" not in task.labels: - logger.debug("Task %s has no schedule, skipping", task_name) - continue - if not isinstance(task.labels["schedule"], list): - logger.warning( - "Schedule for task %s is not a list, skipping", - task_name, - ) - continue - for schedule in task.labels["schedule"]: - try: - new_schedule = ScheduledTask.model_validate( - { - "task_name": task_name, - "labels": schedule.get("labels", {}), - "args": schedule.get("args", []), - "kwargs": schedule.get("kwargs", {}), - "schedule_id": str(uuid.uuid4()), - "cron": schedule.get("cron", None), - "cron_offset": schedule.get("cron_offset", None), - "time": schedule.get("time", None), - }, - ) - scheduled_tasks_for_creation.append(new_schedule) - except ValidationError: - logger.exception( - "Schedule for task %s is not valid, skipping", - task_name, - ) - continue - return scheduled_tasks_for_creation - async def startup(self) -> None: """ Initialize the schedule source. @@ -89,7 +52,7 @@ async def startup(self) -> None: ) async with self._database_pool.acquire() as connection, connection.cursor() as cursor: await cursor.execute(CREATE_SCHEDULES_TABLE_QUERY.format(self._table_name)) - scheduled_tasks_for_creation = self._get_schedules_from_broker_tasks() + scheduled_tasks_for_creation = self.extract_scheduled_tasks_from_broker() await self._update_schedules_on_startup(scheduled_tasks_for_creation) except Exception as error: raise exceptions.DatabaseConnectionError(str(error)) from error @@ -122,3 +85,42 @@ async def get_schedules(self) -> list["ScheduledTask"]: ), ) return schedules + + async def add_schedule(self, schedule: "ScheduledTask") -> None: + """ + Add a new schedule. + + Args: + schedule: schedule to add. + """ + async with self._database_pool.acquire() as connection, connection.cursor() as cursor: + await cursor.execute( + INSERT_SCHEDULE_QUERY.format(self._table_name), + [ + schedule.schedule_id, + schedule.task_name, + schedule.model_dump_json( + exclude={"schedule_id", "task_name"}, + ), + ], + ) + + async def delete_schedule(self, schedule_id: str) -> None: + """ + Method to delete schedule by id. + + This is useful for schedule cancelation. + + Args: + schedule_id: id of schedule to delete. + """ + async with self._database_pool.acquire() as connection, connection.cursor() as cursor: + await cursor.execute( + DELETE_SCHEDULE_QUERY.format(self._table_name), + [schedule_id], + ) + + async def post_send(self, task: ScheduledTask) -> None: + """Delete a task after it's completed.""" + if task.time is not None: + await self.delete_schedule(task.schedule_id) diff --git a/src/taskiq_pg/asyncpg/queries.py b/src/taskiq_pg/asyncpg/queries.py index d7185a0..acc03cd 100644 --- a/src/taskiq_pg/asyncpg/queries.py +++ b/src/taskiq_pg/asyncpg/queries.py @@ -79,3 +79,7 @@ DELETE_ALL_SCHEDULES_QUERY = """ DELETE FROM {}; """ + +DELETE_SCHEDULE_QUERY = """ +DELETE FROM {} WHERE id = $1; +""" diff --git a/src/taskiq_pg/asyncpg/schedule_source.py b/src/taskiq_pg/asyncpg/schedule_source.py index fb0fe36..eb14d8e 100644 --- a/src/taskiq_pg/asyncpg/schedule_source.py +++ b/src/taskiq_pg/asyncpg/schedule_source.py @@ -1,15 +1,14 @@ import json -import uuid from logging import getLogger import asyncpg -from pydantic import ValidationError from taskiq import ScheduledTask from taskiq_pg._internal import BasePostgresScheduleSource from taskiq_pg.asyncpg.queries import ( CREATE_SCHEDULES_TABLE_QUERY, DELETE_ALL_SCHEDULES_QUERY, + DELETE_SCHEDULE_QUERY, INSERT_SCHEDULE_QUERY, SELECT_SCHEDULES_QUERY, ) @@ -37,42 +36,6 @@ async def _update_schedules_on_startup(self, schedules: list[ScheduledTask]) -> ), ) - def _get_schedules_from_broker_tasks(self) -> list[ScheduledTask]: - """Extract schedules from the broker's registered tasks.""" - scheduled_tasks_for_creation: list[ScheduledTask] = [] - for task_name, task in self._broker.get_all_tasks().items(): - if "schedule" not in task.labels: - logger.debug("Task %s has no schedule, skipping", task_name) - continue - if not isinstance(task.labels["schedule"], list): - logger.warning( - "Schedule for task %s is not a list, skipping", - task_name, - ) - continue - for schedule in task.labels["schedule"]: - try: - new_schedule = ScheduledTask.model_validate( - { - "task_name": task_name, - "labels": schedule.get("labels", {}), - "args": schedule.get("args", []), - "kwargs": schedule.get("kwargs", {}), - "schedule_id": str(uuid.uuid4()), - "cron": schedule.get("cron", None), - "cron_offset": schedule.get("cron_offset", None), - "time": schedule.get("time", None), - }, - ) - scheduled_tasks_for_creation.append(new_schedule) - except ValidationError: - logger.exception( - "Schedule for task %s is not valid, skipping", - task_name, - ) - continue - return scheduled_tasks_for_creation - async def startup(self) -> None: """ Initialize the schedule source. @@ -89,7 +52,7 @@ async def startup(self) -> None: self._table_name, ), ) - scheduled_tasks_for_creation = self._get_schedules_from_broker_tasks() + scheduled_tasks_for_creation = self.extract_scheduled_tasks_from_broker() await self._update_schedules_on_startup(scheduled_tasks_for_creation) async def shutdown(self) -> None: @@ -121,3 +84,38 @@ async def get_schedules(self) -> list["ScheduledTask"]: ), ) return schedules + + async def add_schedule(self, schedule: "ScheduledTask") -> None: + """ + Add a new schedule. + + Args: + schedule: schedule to add. + """ + await self._database_pool.execute( + INSERT_SCHEDULE_QUERY.format(self._table_name), + str(schedule.schedule_id), + schedule.task_name, + schedule.model_dump_json( + exclude={"schedule_id", "task_name"}, + ), + ) + + async def delete_schedule(self, schedule_id: str) -> None: + """ + Method to delete schedule by id. + + This is useful for schedule cancelation. + + Args: + schedule_id: id of schedule to delete. + """ + await self._database_pool.execute( + DELETE_SCHEDULE_QUERY.format(self._table_name), + schedule_id, + ) + + async def post_send(self, task: ScheduledTask) -> None: + """Delete a task after it's completed.""" + if task.time is not None: + await self.delete_schedule(task.schedule_id) diff --git a/src/taskiq_pg/psqlpy/queries.py b/src/taskiq_pg/psqlpy/queries.py index 3555c47..e443f72 100644 --- a/src/taskiq_pg/psqlpy/queries.py +++ b/src/taskiq_pg/psqlpy/queries.py @@ -79,3 +79,7 @@ DELETE_ALL_SCHEDULES_QUERY = """ DELETE FROM {}; """ + +DELETE_SCHEDULE_QUERY = """ +DELETE FROM {} WHERE id = $1; +""" diff --git a/src/taskiq_pg/psqlpy/schedule_source.py b/src/taskiq_pg/psqlpy/schedule_source.py index 7fb9fdc..ce714ca 100644 --- a/src/taskiq_pg/psqlpy/schedule_source.py +++ b/src/taskiq_pg/psqlpy/schedule_source.py @@ -3,13 +3,13 @@ from psqlpy import ConnectionPool from psqlpy.extra_types import JSONB -from pydantic import ValidationError from taskiq import ScheduledTask from taskiq_pg._internal import BasePostgresScheduleSource from taskiq_pg.psqlpy.queries import ( CREATE_SCHEDULES_TABLE_QUERY, DELETE_ALL_SCHEDULES_QUERY, + DELETE_SCHEDULE_QUERY, INSERT_SCHEDULE_QUERY, SELECT_SCHEDULES_QUERY, ) @@ -46,42 +46,6 @@ async def _update_schedules_on_startup(self, schedules: list[ScheduledTask]) -> data_to_insert, ) - def _get_schedules_from_broker_tasks(self) -> list[ScheduledTask]: - """Extract schedules from the broker's registered tasks.""" - scheduled_tasks_for_creation: list[ScheduledTask] = [] - for task_name, task in self._broker.get_all_tasks().items(): - if "schedule" not in task.labels: - logger.debug("Task %s has no schedule, skipping", task_name) - continue - if not isinstance(task.labels["schedule"], list): - logger.warning( - "Schedule for task %s is not a list, skipping", - task_name, - ) - continue - for schedule in task.labels["schedule"]: - try: - new_schedule = ScheduledTask.model_validate( - { - "task_name": task_name, - "labels": schedule.get("labels", {}), - "args": schedule.get("args", []), - "kwargs": schedule.get("kwargs", {}), - "schedule_id": str(uuid.uuid4()), - "cron": schedule.get("cron", None), - "cron_offset": schedule.get("cron_offset", None), - "time": schedule.get("time", None), - }, - ) - scheduled_tasks_for_creation.append(new_schedule) - except ValidationError: - logger.exception( - "Schedule for task %s is not valid, skipping", - task_name, - ) - continue - return scheduled_tasks_for_creation - async def startup(self) -> None: """ Initialize the schedule source. @@ -99,7 +63,7 @@ async def startup(self) -> None: self._table_name, ), ) - scheduled_tasks_for_creation = self._get_schedules_from_broker_tasks() + scheduled_tasks_for_creation = self.extract_scheduled_tasks_from_broker() await self._update_schedules_on_startup(scheduled_tasks_for_creation) async def shutdown(self) -> None: @@ -131,3 +95,44 @@ async def get_schedules(self) -> list["ScheduledTask"]: ), ) return schedules + + async def add_schedule(self, schedule: "ScheduledTask") -> None: + """ + Add a new schedule. + + Args: + schedule: schedule to add. + """ + async with self._database_pool.acquire() as connection: + schedule_dict = schedule.model_dump( + mode="json", + exclude={"schedule_id", "task_name"}, + ) + await connection.execute( + INSERT_SCHEDULE_QUERY.format(self._table_name), + [ + uuid.UUID(schedule.schedule_id), + schedule.task_name, + JSONB(schedule_dict), + ] + ) + + async def delete_schedule(self, schedule_id: str) -> None: + """ + Method to delete schedule by id. + + This is useful for schedule cancelation. + + Args: + schedule_id: id of schedule to delete. + """ + async with self._database_pool.acquire() as connection: + await connection.execute( + DELETE_SCHEDULE_QUERY.format(self._table_name), + [uuid.UUID(schedule_id)], + ) + + async def post_send(self, task: ScheduledTask) -> None: + """Delete a task after it's completed.""" + if task.time is not None: + await self.delete_schedule(task.schedule_id) diff --git a/src/taskiq_pg/psycopg/queries.py b/src/taskiq_pg/psycopg/queries.py index 3aff0a9..9d51a41 100644 --- a/src/taskiq_pg/psycopg/queries.py +++ b/src/taskiq_pg/psycopg/queries.py @@ -79,3 +79,7 @@ DELETE_ALL_SCHEDULES_QUERY = """ DELETE FROM {}; """ + +DELETE_SCHEDULE_QUERY = """ +DELETE FROM {} WHERE id = %s; +""" diff --git a/src/taskiq_pg/psycopg/schedule_source.py b/src/taskiq_pg/psycopg/schedule_source.py index 4a35e90..efc63e9 100644 --- a/src/taskiq_pg/psycopg/schedule_source.py +++ b/src/taskiq_pg/psycopg/schedule_source.py @@ -1,14 +1,15 @@ import uuid from logging import getLogger +from psycopg import sql from psycopg_pool import AsyncConnectionPool -from pydantic import ValidationError from taskiq import ScheduledTask from taskiq_pg._internal import BasePostgresScheduleSource from taskiq_pg.psycopg.queries import ( CREATE_SCHEDULES_TABLE_QUERY, DELETE_ALL_SCHEDULES_QUERY, + DELETE_SCHEDULE_QUERY, INSERT_SCHEDULE_QUERY, SELECT_SCHEDULES_QUERY, ) @@ -25,7 +26,7 @@ class PsycopgScheduleSource(BasePostgresScheduleSource): async def _update_schedules_on_startup(self, schedules: list[ScheduledTask]) -> None: """Update schedules in the database on startup: truncate table and insert new ones.""" async with self._database_pool.connection() as connection, connection.cursor() as cursor: - await cursor.execute(DELETE_ALL_SCHEDULES_QUERY.format(self._table_name)) + await cursor.execute(sql.SQL(DELETE_ALL_SCHEDULES_QUERY).format(sql.Identifier(self._table_name))) data_to_insert: list = [ [ uuid.UUID(schedule.schedule_id), @@ -37,46 +38,10 @@ async def _update_schedules_on_startup(self, schedules: list[ScheduledTask]) -> for schedule in schedules ] await cursor.executemany( - INSERT_SCHEDULE_QUERY.format(self._table_name), + sql.SQL(INSERT_SCHEDULE_QUERY).format(sql.Identifier(self._table_name)), data_to_insert, ) - def _get_schedules_from_broker_tasks(self) -> list[ScheduledTask]: - """Extract schedules from the broker's registered tasks.""" - scheduled_tasks_for_creation: list[ScheduledTask] = [] - for task_name, task in self._broker.get_all_tasks().items(): - if "schedule" not in task.labels: - logger.debug("Task %s has no schedule, skipping", task_name) - continue - if not isinstance(task.labels["schedule"], list): - logger.warning( - "Schedule for task %s is not a list, skipping", - task_name, - ) - continue - for schedule in task.labels["schedule"]: - try: - new_schedule = ScheduledTask.model_validate( - { - "task_name": task_name, - "labels": schedule.get("labels", {}), - "args": schedule.get("args", []), - "kwargs": schedule.get("kwargs", {}), - "schedule_id": str(uuid.uuid4()), - "cron": schedule.get("cron", None), - "cron_offset": schedule.get("cron_offset", None), - "time": schedule.get("time", None), - }, - ) - scheduled_tasks_for_creation.append(new_schedule) - except ValidationError: - logger.exception( - "Schedule for task %s is not valid, skipping", - task_name, - ) - continue - return scheduled_tasks_for_creation - async def startup(self) -> None: """ Initialize the schedule source. @@ -93,9 +58,9 @@ async def startup(self) -> None: async with self._database_pool.connection() as connection, connection.cursor() as cursor: await cursor.execute( - CREATE_SCHEDULES_TABLE_QUERY.format(self._table_name), + sql.SQL(CREATE_SCHEDULES_TABLE_QUERY).format(sql.Identifier(self._table_name)), ) - scheduled_tasks_for_creation = self._get_schedules_from_broker_tasks() + scheduled_tasks_for_creation = self.extract_scheduled_tasks_from_broker() await self._update_schedules_on_startup(scheduled_tasks_for_creation) async def shutdown(self) -> None: @@ -108,7 +73,7 @@ async def get_schedules(self) -> list["ScheduledTask"]: schedules = [] async with self._database_pool.connection() as connection, connection.cursor() as cursor: rows_with_schedules = await cursor.execute( - SELECT_SCHEDULES_QUERY.format(self._table_name), + sql.SQL(SELECT_SCHEDULES_QUERY).format(sql.Identifier(self._table_name)), ) rows = await rows_with_schedules.fetchall() for schedule_id, task_name, schedule in rows: @@ -127,3 +92,42 @@ async def get_schedules(self) -> list["ScheduledTask"]: ), ) return schedules + + async def add_schedule(self, schedule: "ScheduledTask") -> None: + """ + Add a new schedule. + + Args: + schedule: schedule to add. + """ + async with self._database_pool.connection() as connection, connection.cursor() as cursor: + await cursor.execute( + sql.SQL(INSERT_SCHEDULE_QUERY).format(sql.Identifier(self._table_name)), + [ + uuid.UUID(schedule.schedule_id), + schedule.task_name, + schedule.model_dump_json( + exclude={"schedule_id", "task_name"}, + ), + ] + ) + + async def delete_schedule(self, schedule_id: str) -> None: + """ + Method to delete schedule by id. + + This is useful for schedule cancelation. + + Args: + schedule_id: id of schedule to delete. + """ + async with self._database_pool.connection() as connection, connection.cursor() as cursor: + await cursor.execute( + sql.SQL(DELETE_SCHEDULE_QUERY).format(sql.Identifier(self._table_name)), + [schedule_id], + ) + + async def post_send(self, task: ScheduledTask) -> None: + """Delete a task after it's completed.""" + if task.time is not None: + await self.delete_schedule(task.schedule_id) diff --git a/tests/integration/test_schedule_source.py b/tests/integration/test_schedule_source.py index 0179d9f..031400f 100644 --- a/tests/integration/test_schedule_source.py +++ b/tests/integration/test_schedule_source.py @@ -1,11 +1,14 @@ from __future__ import annotations import typing as tp +import uuid from contextlib import asynccontextmanager from datetime import timedelta import pytest +from polyfactory.factories.pydantic_factory import ModelFactory from sqlalchemy_utils.types.enriched_datetime.arrow_datetime import datetime +from taskiq import ScheduledTask from taskiq_pg.aiopg import AiopgScheduleSource from taskiq_pg.asyncpg import AsyncpgScheduleSource @@ -14,11 +17,22 @@ if tp.TYPE_CHECKING: - from taskiq import ScheduledTask - from taskiq_pg._internal import BasePostgresBroker, BasePostgresScheduleSource + +class ScheduledTaskFactory(ModelFactory[ScheduledTask]): + """Factory for ScheduledTask.""" + + __model__ = ScheduledTask + __check_model__ = True + + @classmethod + def schedule_id(cls) -> str: + """Generate unique schedule ID.""" + return uuid.uuid4().hex + + @asynccontextmanager async def _get_schedule_source( schedule_source_class: type[BasePostgresScheduleSource], @@ -127,3 +141,71 @@ async def test_when_labels_contain_schedules__then_get_schedules_returns_schedul assert {item.task_name for item in schedules} == {"tests:one_schedule", "tests:two_schedules"} assert {item.time for item in schedules} == {datetime(2024, 1, 1, 12, 0, 0), None} assert all(item.schedule_id is not None for item in schedules) + + +@pytest.mark.integration +@pytest.mark.parametrize( + "schedule_source_class", + [ + PSQLPyScheduleSource, + AiopgScheduleSource, + AsyncpgScheduleSource, + PsycopgScheduleSource, + ], +) +async def test_when_call_add_schedule__then_schedule_creates( + pg_dsn: str, + broker_with_scheduled_tasks: PSQLPyBroker, + schedule_source_class: type[PSQLPyScheduleSource | AiopgScheduleSource | AsyncpgScheduleSource], +) -> None: + # Given + new_schedule = ScheduledTaskFactory.build(task_name="tests:added_schedule", cron="*/5 * * * *") + async with _get_schedule_source(schedule_source_class, broker_with_scheduled_tasks, pg_dsn) as schedule_source: + await schedule_source.startup() + + # When + await schedule_source.add_schedule(new_schedule) + + # Then + schedules: list[ScheduledTask] = await schedule_source.get_schedules() + assert len(schedules) == 4 + added_schedule = None + for task in schedules: + if task.task_name == "tests:added_schedule": + added_schedule = task + break + assert added_schedule is not None + + + +@pytest.mark.integration +@pytest.mark.parametrize( + "schedule_source_class", + [ + PSQLPyScheduleSource, + AiopgScheduleSource, + AsyncpgScheduleSource, + PsycopgScheduleSource, + ], +) +async def test_when_call_delete_schedule__then_schedule_deleted( + pg_dsn: str, + broker_with_scheduled_tasks: PSQLPyBroker, + schedule_source_class: type[PSQLPyScheduleSource | AiopgScheduleSource | AsyncpgScheduleSource], +) -> None: + # Given + async with _get_schedule_source(schedule_source_class, broker_with_scheduled_tasks, pg_dsn) as schedule_source: + await schedule_source.startup() + schedules: list[ScheduledTask] = await schedule_source.get_schedules() + schedule_id_to_delete = str(schedules[0].schedule_id) + + # When + await schedule_source.delete_schedule(schedule_id_to_delete) + + # Then + schedules: list[ScheduledTask] = await schedule_source.get_schedules() + assert len(schedules) == 2 + assert all( + task.schedule_id != schedule_id_to_delete + for task in schedules + ) diff --git a/uv.lock b/uv.lock index a2f4699..a9d832c 100644 --- a/uv.lock +++ b/uv.lock @@ -392,6 +392,18 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/02/cc/b7e31358aac6ed1ef2bb790a9746ac2c69bcb3c8588b41616914eb106eaf/exceptiongroup-1.2.2-py3-none-any.whl", hash = "sha256:3111b9d131c238bec2f8f516e123e14ba243563fb135d3fe885990585aa7795b", size = 16453, upload-time = "2024-07-12T22:25:58.476Z" }, ] +[[package]] +name = "faker" +version = "37.11.0" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "tzdata" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/c9/4b/ca43f6bbcef63deb8ac01201af306388670a172587169aab3b192f7490f0/faker-37.11.0.tar.gz", hash = "sha256:22969803849ba0618be8eee2dd01d0d9e2cd3b75e6ff1a291fa9abcdb34da5e6", size = 1935301, upload-time = "2025-10-07T14:49:01.481Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/a3/46/8f4097b55e43af39e8e71e1f7aec59ff7398bca54d975c30889bc844719d/faker-37.11.0-py3-none-any.whl", hash = "sha256:1508d2da94dfd1e0087b36f386126d84f8583b3de19ac18e392a2831a6676c57", size = 1975525, upload-time = "2025-10-07T14:48:58.29Z" }, +] + [[package]] name = "ghp-import" version = "2.1.0" @@ -856,6 +868,19 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/54/20/4d324d65cc6d9205fabedc306948156824eb9f0ee1633355a8f7ec5c66bf/pluggy-1.6.0-py3-none-any.whl", hash = "sha256:e920276dd6813095e9377c0bc5566d94c932c33b27a3e3945d8389c374dd4746", size = 20538, upload-time = "2025-05-15T12:30:06.134Z" }, ] +[[package]] +name = "polyfactory" +version = "2.22.2" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "faker" }, + { name = "typing-extensions" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/4e/a6/950d13856d995705df33b92451559fd317207a9c43629ab1771135a0c966/polyfactory-2.22.2.tar.gz", hash = "sha256:a3297aa0b004f2b26341e903795565ae88507c4d86e68b132c2622969028587a", size = 254462, upload-time = "2025-08-15T06:23:21.28Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/e7/fe/d52c90e07c458f38b26f9972a25cb011b2744813f76fcd6121dde64744fa/polyfactory-2.22.2-py3-none-any.whl", hash = "sha256:9bea58ac9a80375b4153cd60820f75e558b863e567e058794d28c6a52b84118a", size = 63715, upload-time = "2025-08-15T06:23:19.664Z" }, +] + [[package]] name = "prek" version = "0.2.8" @@ -1596,6 +1621,7 @@ dev = [ { name = "mkdocs-material" }, { name = "mkdocstrings-python" }, { name = "mypy" }, + { name = "polyfactory" }, { name = "prek" }, { name = "pytest" }, { name = "pytest-asyncio" }, @@ -1618,6 +1644,7 @@ lint = [ { name = "zizmor" }, ] test = [ + { name = "polyfactory" }, { name = "pytest" }, { name = "pytest-asyncio" }, { name = "pytest-cov" }, @@ -1643,6 +1670,7 @@ dev = [ { name = "mkdocs-material", specifier = ">=9.6.22" }, { name = "mkdocstrings-python", specifier = ">=1.18.2" }, { name = "mypy", specifier = ">=1.18.1" }, + { name = "polyfactory", specifier = ">=2.22.2" }, { name = "prek", specifier = ">=0.2.8" }, { name = "pytest", specifier = ">=8.4.2" }, { name = "pytest-asyncio", specifier = ">=1.1.0" }, @@ -1665,6 +1693,7 @@ lint = [ { name = "zizmor", specifier = ">=1.15.2" }, ] test = [ + { name = "polyfactory", specifier = ">=2.22.2" }, { name = "pytest", specifier = ">=8.4.2" }, { name = "pytest-asyncio", specifier = ">=1.1.0" }, { name = "pytest-cov", specifier = ">=7.0.0" },