diff --git a/diracx-db/src/diracx/db/sql/auth/schema.py b/diracx-db/src/diracx/db/sql/auth/schema.py index 95a17f49c..6ea0d7d3c 100644 --- a/diracx-db/src/diracx/db/sql/auth/schema.py +++ b/diracx-db/src/diracx/db/sql/auth/schema.py @@ -8,13 +8,15 @@ String, Uuid, ) -from sqlalchemy.orm import declarative_base +from sqlalchemy.orm import DeclarativeBase from diracx.db.sql.utils import Column, DateNowColumn, EnumColumn, NullColumn USER_CODE_LENGTH = 8 -Base = declarative_base() + +class Base(DeclarativeBase): + pass class FlowStatus(Enum): diff --git a/diracx-db/src/diracx/db/sql/dummy/schema.py b/diracx-db/src/diracx/db/sql/dummy/schema.py index 5379de94d..33debcb89 100644 --- a/diracx-db/src/diracx/db/sql/dummy/schema.py +++ b/diracx-db/src/diracx/db/sql/dummy/schema.py @@ -3,11 +3,13 @@ from __future__ import annotations from sqlalchemy import ForeignKey, Integer, String, Uuid -from sqlalchemy.orm import declarative_base +from sqlalchemy.orm import DeclarativeBase from diracx.db.sql.utils import Column, DateNowColumn -Base = declarative_base() + +class Base(DeclarativeBase): + pass class Owners(Base): diff --git a/diracx-db/src/diracx/db/sql/job/db.py b/diracx-db/src/diracx/db/sql/job/db.py index 01cdb83a1..d260f5a40 100644 --- a/diracx-db/src/diracx/db/sql/job/db.py +++ b/diracx-db/src/diracx/db/sql/job/db.py @@ -3,12 +3,13 @@ __all__ = ["JobDB"] from datetime import datetime, timezone -from typing import TYPE_CHECKING, Any, Iterable +from typing import TYPE_CHECKING, Any, Iterable, cast from sqlalchemy import bindparam, case, delete, insert, select, update if TYPE_CHECKING: from sqlalchemy.sql.elements import BindParameter + from sqlalchemy import Table from diracx.core.exceptions import InvalidQueryError from diracx.core.models import JobCommand, SearchSpec, SortSpec @@ -73,7 +74,7 @@ async def search( async def create_job(self, compressed_original_jdl: str): """Used to insert a new job with original JDL. Returns inserted job id.""" result = await self.conn.execute( - JobJDLs.__table__.insert().values( + cast("Table", JobJDLs.__table__).insert().values( JDL="", JobRequirements="", OriginalJDL=compressed_original_jdl, @@ -89,7 +90,7 @@ async def delete_jobs(self, job_ids: list[int]): async def insert_input_data(self, lfns: dict[int, list[str]]): """Insert input data for jobs.""" await self.conn.execute( - InputData.__table__.insert(), + cast("Table", InputData.__table__).insert(), [ { "JobID": job_id, @@ -103,7 +104,7 @@ async def insert_input_data(self, lfns: dict[int, list[str]]): async def insert_job_attributes(self, jobs_to_update: dict[int, dict]): """Insert the job attributes.""" await self.conn.execute( - Jobs.__table__.insert(), + cast("Table", Jobs.__table__).insert(), [ { "JobID": job_id, @@ -116,7 +117,7 @@ async def insert_job_attributes(self, jobs_to_update: dict[int, dict]): async def update_job_jdls(self, jdls_to_update: dict[int, str]): """Used to update the JDL, typically just after inserting the original JDL, or rescheduling, for example.""" await self.conn.execute( - JobJDLs.__table__.update().where( + cast("Table", JobJDLs.__table__).update().where( JobJDLs.__table__.c.JobID == bindparam("b_JobID") ), [ @@ -171,7 +172,7 @@ async def set_job_attributes(self, job_data): } stmt = ( - Jobs.__table__.update() + cast("Table", Jobs.__table__).update() .values(**case_expressions) .where(Jobs.__table__.c.JobID.in_(job_data.keys())) ) @@ -228,7 +229,7 @@ async def set_properties( required_parameters = list(required_parameters_set)[0] update_parameters = [{"job_id": k, **v} for k, v in properties.items()] - columns = _get_columns(Jobs.__table__, required_parameters) + columns = _get_columns(cast("Table", Jobs.__table__), list(required_parameters)) values: dict[str, BindParameter[Any] | datetime] = { c.name: bindparam(c.name) for c in columns } @@ -267,7 +268,7 @@ async def add_heartbeat_data( } for key, value in dynamic_data.items() ] - await self.conn.execute(HeartBeatLoggingInfo.__table__.insert().values(values)) + await self.conn.execute(cast("Table", HeartBeatLoggingInfo.__table__).insert().values(values)) async def get_job_commands(self, job_ids: Iterable[int]) -> list[JobCommand]: """Get a command to be passed to the job together with the next heartbeat. diff --git a/diracx-db/src/diracx/db/sql/job/schema.py b/diracx-db/src/diracx/db/sql/job/schema.py index bb9f60bf1..9d6e1003b 100644 --- a/diracx-db/src/diracx/db/sql/job/schema.py +++ b/diracx-db/src/diracx/db/sql/job/schema.py @@ -9,11 +9,13 @@ String, Text, ) -from sqlalchemy.orm import declarative_base +from sqlalchemy.orm import DeclarativeBase from ..utils import Column, EnumBackedBool, NullColumn -JobDBBase = declarative_base() + +class JobDBBase(DeclarativeBase): + pass class AccountedFlagEnum(types.TypeDecorator): diff --git a/diracx-db/src/diracx/db/sql/job_logging/db.py b/diracx-db/src/diracx/db/sql/job_logging/db.py index f225ed95c..fd158af27 100644 --- a/diracx-db/src/diracx/db/sql/job_logging/db.py +++ b/diracx-db/src/diracx/db/sql/job_logging/db.py @@ -2,10 +2,13 @@ from collections import defaultdict from datetime import datetime, timezone -from typing import Iterable +from typing import Iterable, cast, TYPE_CHECKING from sqlalchemy import delete, func, select +if TYPE_CHECKING: + from sqlalchemy import Table + from diracx.core.models import JobLoggingRecord, JobStatusReturn from ..utils import BaseSQLDB @@ -56,7 +59,7 @@ async def insert_records( seqnums[record.job_id] = seqnums[record.job_id] + 1 await self.conn.execute( - LoggingInfo.__table__.insert(), + cast("Table", LoggingInfo.__table__).insert(), values, ) diff --git a/diracx-db/src/diracx/db/sql/job_logging/schema.py b/diracx-db/src/diracx/db/sql/job_logging/schema.py index 2366448f2..e438b162b 100644 --- a/diracx-db/src/diracx/db/sql/job_logging/schema.py +++ b/diracx-db/src/diracx/db/sql/job_logging/schema.py @@ -3,11 +3,13 @@ from datetime import UTC, datetime from sqlalchemy import Integer, Numeric, PrimaryKeyConstraint, String, TypeDecorator -from sqlalchemy.orm import declarative_base +from sqlalchemy.orm import DeclarativeBase from ..utils import Column, DateNowColumn -JobLoggingDBBase = declarative_base() + +class JobLoggingDBBase(DeclarativeBase): + pass class MagicEpochDateTime(TypeDecorator): diff --git a/diracx-db/src/diracx/db/sql/pilot_agents/schema.py b/diracx-db/src/diracx/db/sql/pilot_agents/schema.py index bff7c460c..42a6dbf8b 100644 --- a/diracx-db/src/diracx/db/sql/pilot_agents/schema.py +++ b/diracx-db/src/diracx/db/sql/pilot_agents/schema.py @@ -8,11 +8,13 @@ String, Text, ) -from sqlalchemy.orm import declarative_base +from sqlalchemy.orm import DeclarativeBase from ..utils import Column, EnumBackedBool, NullColumn -PilotAgentsDBBase = declarative_base() + +class PilotAgentsDBBase(DeclarativeBase): + pass class PilotAgents(PilotAgentsDBBase): diff --git a/diracx-db/src/diracx/db/sql/sandbox_metadata/schema.py b/diracx-db/src/diracx/db/sql/sandbox_metadata/schema.py index 4cf9a2a7d..049dd7f7c 100644 --- a/diracx-db/src/diracx/db/sql/sandbox_metadata/schema.py +++ b/diracx-db/src/diracx/db/sql/sandbox_metadata/schema.py @@ -9,11 +9,13 @@ String, UniqueConstraint, ) -from sqlalchemy.orm import declarative_base +from sqlalchemy.orm import DeclarativeBase from diracx.db.sql.utils import Column, DateNowColumn -Base = declarative_base() + +class Base(DeclarativeBase): + pass class SBOwners(Base): diff --git a/diracx-db/src/diracx/db/sql/task_queue/schema.py b/diracx-db/src/diracx/db/sql/task_queue/schema.py index 0a3c0f033..88ffdfdc9 100644 --- a/diracx-db/src/diracx/db/sql/task_queue/schema.py +++ b/diracx-db/src/diracx/db/sql/task_queue/schema.py @@ -9,11 +9,13 @@ Integer, String, ) -from sqlalchemy.orm import declarative_base +from sqlalchemy.orm import DeclarativeBase from ..utils import Column -TaskQueueDBBase = declarative_base() + +class TaskQueueDBBase(DeclarativeBase): + pass class TaskQueues(TaskQueueDBBase): diff --git a/diracx-db/src/diracx/db/sql/utils/base.py b/diracx-db/src/diracx/db/sql/utils/base.py index 9b349e14e..52107e85a 100644 --- a/diracx-db/src/diracx/db/sql/utils/base.py +++ b/diracx-db/src/diracx/db/sql/utils/base.py @@ -8,13 +8,16 @@ from collections.abc import AsyncIterator from contextvars import ContextVar from datetime import datetime -from typing import Any, Self, cast +from typing import Any, Self, cast, TYPE_CHECKING from pydantic import TypeAdapter -from sqlalchemy import DateTime, MetaData, func, select +from sqlalchemy import DateTime, MetaData, Table, func, select from sqlalchemy.exc import OperationalError from sqlalchemy.ext.asyncio import AsyncConnection, AsyncEngine, create_async_engine +if TYPE_CHECKING: + from sqlalchemy.orm import DeclarativeBase + from diracx.core.exceptions import InvalidQueryError from diracx.core.extensions import select_from_extension from diracx.core.models import ( @@ -233,7 +236,7 @@ async def ping(self): async def _search( self, - table: Any, + table: type[DeclarativeBase], parameters: list[str] | None, search: list[SearchSpec], sorts: list[SortSpec], @@ -244,12 +247,13 @@ async def _search( ) -> tuple[int, list[dict[str, Any]]]: """Search for elements in a table.""" # Find which columns to select - columns = _get_columns(table.__table__, parameters) + table_obj = cast(Table, table.__table__) + columns = _get_columns(table_obj, parameters) stmt = select(*columns) - stmt = apply_search_filters(table.__table__.columns.__getitem__, stmt, search) - stmt = apply_sort_constraints(table.__table__.columns.__getitem__, stmt, sorts) + stmt = apply_search_filters(table_obj.columns.__getitem__, stmt, search) + stmt = apply_sort_constraints(table_obj.columns.__getitem__, stmt, sorts) if distinct: stmt = stmt.distinct() @@ -273,12 +277,13 @@ async def _search( ] async def _summary( - self, table: Any, group_by: list[str], search: list[SearchSpec] + self, table: type[DeclarativeBase], group_by: list[str], search: list[SearchSpec] ) -> list[dict[str, str | int]]: """Get a summary of the elements of a table.""" - columns = _get_columns(table.__table__, group_by) + table_obj = cast(Table, table.__table__) + columns = _get_columns(table_obj, group_by) - pk_columns = list(table.__table__.primary_key.columns) + pk_columns = list(table_obj.primary_key.columns) if not pk_columns: raise ValueError( "Model has no primary key and no count_column was provided." @@ -286,7 +291,7 @@ async def _summary( count_col = pk_columns[0] stmt = select(*columns, func.count(count_col).label("count")) - stmt = apply_search_filters(table.__table__.columns.__getitem__, stmt, search) + stmt = apply_search_filters(table_obj.columns.__getitem__, stmt, search) stmt = stmt.group_by(*columns) # Execute the query @@ -327,7 +332,7 @@ def find_time_resolution(value): raise InvalidQueryError(f"Cannot parse {value=}") -def _get_columns(table, parameters): +def _get_columns(table: Table, parameters: list[str] | None): columns = [x for x in table.columns] if parameters: if unrecognised_parameters := set(parameters) - set(table.columns.keys()): diff --git a/extensions/gubbins/gubbins-db/src/gubbins/db/sql/jobs/schema.py b/extensions/gubbins/gubbins-db/src/gubbins/db/sql/jobs/schema.py index eee922d4e..10869665e 100644 --- a/extensions/gubbins/gubbins-db/src/gubbins/db/sql/jobs/schema.py +++ b/extensions/gubbins/gubbins-db/src/gubbins/db/sql/jobs/schema.py @@ -1,4 +1,4 @@ -from diracx.db.sql.job.db import JobDBBase +from diracx.db.sql.job.schema import JobDBBase from diracx.db.sql.utils import Column from sqlalchemy import ( ForeignKey, diff --git a/extensions/gubbins/gubbins-db/src/gubbins/db/sql/lollygag/schema.py b/extensions/gubbins/gubbins-db/src/gubbins/db/sql/lollygag/schema.py index 9b80e5133..5e104956e 100644 --- a/extensions/gubbins/gubbins-db/src/gubbins/db/sql/lollygag/schema.py +++ b/extensions/gubbins/gubbins-db/src/gubbins/db/sql/lollygag/schema.py @@ -2,9 +2,11 @@ # in place of the SQLAlchemy one. Have a look at them from diracx.db.sql.utils import Column, DateNowColumn from sqlalchemy import ForeignKey, Integer, String, Uuid -from sqlalchemy.orm import declarative_base +from sqlalchemy.orm import DeclarativeBase -Base = declarative_base() + +class Base(DeclarativeBase): + pass class Owners(Base):