Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 4 additions & 2 deletions diracx-db/src/diracx/db/sql/auth/schema.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,13 +8,15 @@
String,
Uuid,
)
from sqlalchemy.orm import declarative_base
from sqlalchemy.orm import DeclarativeBase

from diracx.db.sql.utils import Column, DateNowColumn, EnumColumn, NullColumn

USER_CODE_LENGTH = 8

Base = declarative_base()

class Base(DeclarativeBase):
pass


class FlowStatus(Enum):
Expand Down
6 changes: 4 additions & 2 deletions diracx-db/src/diracx/db/sql/dummy/schema.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,11 +3,13 @@
from __future__ import annotations

from sqlalchemy import ForeignKey, Integer, String, Uuid
from sqlalchemy.orm import declarative_base
from sqlalchemy.orm import DeclarativeBase

from diracx.db.sql.utils import Column, DateNowColumn

Base = declarative_base()

class Base(DeclarativeBase):
pass


class Owners(Base):
Expand Down
17 changes: 9 additions & 8 deletions diracx-db/src/diracx/db/sql/job/db.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,12 +3,13 @@
__all__ = ["JobDB"]

from datetime import datetime, timezone
from typing import TYPE_CHECKING, Any, Iterable
from typing import TYPE_CHECKING, Any, Iterable, cast

from sqlalchemy import bindparam, case, delete, insert, select, update

if TYPE_CHECKING:
from sqlalchemy.sql.elements import BindParameter
from sqlalchemy import Table

from diracx.core.exceptions import InvalidQueryError
from diracx.core.models import JobCommand, SearchSpec, SortSpec
Expand Down Expand Up @@ -73,7 +74,7 @@ async def search(
async def create_job(self, compressed_original_jdl: str):
"""Used to insert a new job with original JDL. Returns inserted job id."""
result = await self.conn.execute(
JobJDLs.__table__.insert().values(
cast("Table", JobJDLs.__table__).insert().values(
JDL="",
JobRequirements="",
OriginalJDL=compressed_original_jdl,
Expand All @@ -89,7 +90,7 @@ async def delete_jobs(self, job_ids: list[int]):
async def insert_input_data(self, lfns: dict[int, list[str]]):
"""Insert input data for jobs."""
await self.conn.execute(
InputData.__table__.insert(),
cast("Table", InputData.__table__).insert(),
[
{
"JobID": job_id,
Expand All @@ -103,7 +104,7 @@ async def insert_input_data(self, lfns: dict[int, list[str]]):
async def insert_job_attributes(self, jobs_to_update: dict[int, dict]):
"""Insert the job attributes."""
await self.conn.execute(
Jobs.__table__.insert(),
cast("Table", Jobs.__table__).insert(),
[
{
"JobID": job_id,
Expand All @@ -116,7 +117,7 @@ async def insert_job_attributes(self, jobs_to_update: dict[int, dict]):
async def update_job_jdls(self, jdls_to_update: dict[int, str]):
"""Used to update the JDL, typically just after inserting the original JDL, or rescheduling, for example."""
await self.conn.execute(
JobJDLs.__table__.update().where(
cast("Table", JobJDLs.__table__).update().where(
JobJDLs.__table__.c.JobID == bindparam("b_JobID")
),
[
Expand Down Expand Up @@ -171,7 +172,7 @@ async def set_job_attributes(self, job_data):
}

stmt = (
Jobs.__table__.update()
cast("Table", Jobs.__table__).update()
.values(**case_expressions)
.where(Jobs.__table__.c.JobID.in_(job_data.keys()))
)
Expand Down Expand Up @@ -228,7 +229,7 @@ async def set_properties(
required_parameters = list(required_parameters_set)[0]
update_parameters = [{"job_id": k, **v} for k, v in properties.items()]

columns = _get_columns(Jobs.__table__, required_parameters)
columns = _get_columns(cast("Table", Jobs.__table__), list(required_parameters))
values: dict[str, BindParameter[Any] | datetime] = {
c.name: bindparam(c.name) for c in columns
}
Expand Down Expand Up @@ -267,7 +268,7 @@ async def add_heartbeat_data(
}
for key, value in dynamic_data.items()
]
await self.conn.execute(HeartBeatLoggingInfo.__table__.insert().values(values))
await self.conn.execute(cast("Table", HeartBeatLoggingInfo.__table__).insert().values(values))

async def get_job_commands(self, job_ids: Iterable[int]) -> list[JobCommand]:
"""Get a command to be passed to the job together with the next heartbeat.
Expand Down
6 changes: 4 additions & 2 deletions diracx-db/src/diracx/db/sql/job/schema.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,11 +9,13 @@
String,
Text,
)
from sqlalchemy.orm import declarative_base
from sqlalchemy.orm import DeclarativeBase

from ..utils import Column, EnumBackedBool, NullColumn

JobDBBase = declarative_base()

class JobDBBase(DeclarativeBase):
pass


class AccountedFlagEnum(types.TypeDecorator):
Expand Down
7 changes: 5 additions & 2 deletions diracx-db/src/diracx/db/sql/job_logging/db.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,10 +2,13 @@

from collections import defaultdict
from datetime import datetime, timezone
from typing import Iterable
from typing import Iterable, cast, TYPE_CHECKING

from sqlalchemy import delete, func, select

if TYPE_CHECKING:
from sqlalchemy import Table

from diracx.core.models import JobLoggingRecord, JobStatusReturn

from ..utils import BaseSQLDB
Expand Down Expand Up @@ -56,7 +59,7 @@ async def insert_records(
seqnums[record.job_id] = seqnums[record.job_id] + 1

await self.conn.execute(
LoggingInfo.__table__.insert(),
cast("Table", LoggingInfo.__table__).insert(),
values,
)

Expand Down
6 changes: 4 additions & 2 deletions diracx-db/src/diracx/db/sql/job_logging/schema.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,11 +3,13 @@
from datetime import UTC, datetime

from sqlalchemy import Integer, Numeric, PrimaryKeyConstraint, String, TypeDecorator
from sqlalchemy.orm import declarative_base
from sqlalchemy.orm import DeclarativeBase

from ..utils import Column, DateNowColumn

JobLoggingDBBase = declarative_base()

class JobLoggingDBBase(DeclarativeBase):
pass


class MagicEpochDateTime(TypeDecorator):
Expand Down
6 changes: 4 additions & 2 deletions diracx-db/src/diracx/db/sql/pilot_agents/schema.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,11 +8,13 @@
String,
Text,
)
from sqlalchemy.orm import declarative_base
from sqlalchemy.orm import DeclarativeBase

from ..utils import Column, EnumBackedBool, NullColumn

PilotAgentsDBBase = declarative_base()

class PilotAgentsDBBase(DeclarativeBase):
pass


class PilotAgents(PilotAgentsDBBase):
Expand Down
6 changes: 4 additions & 2 deletions diracx-db/src/diracx/db/sql/sandbox_metadata/schema.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,11 +9,13 @@
String,
UniqueConstraint,
)
from sqlalchemy.orm import declarative_base
from sqlalchemy.orm import DeclarativeBase

from diracx.db.sql.utils import Column, DateNowColumn

Base = declarative_base()

class Base(DeclarativeBase):
pass


class SBOwners(Base):
Expand Down
6 changes: 4 additions & 2 deletions diracx-db/src/diracx/db/sql/task_queue/schema.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,11 +9,13 @@
Integer,
String,
)
from sqlalchemy.orm import declarative_base
from sqlalchemy.orm import DeclarativeBase

from ..utils import Column

TaskQueueDBBase = declarative_base()

class TaskQueueDBBase(DeclarativeBase):
pass


class TaskQueues(TaskQueueDBBase):
Expand Down
27 changes: 16 additions & 11 deletions diracx-db/src/diracx/db/sql/utils/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,13 +8,16 @@
from collections.abc import AsyncIterator
from contextvars import ContextVar
from datetime import datetime
from typing import Any, Self, cast
from typing import Any, Self, cast, TYPE_CHECKING

from pydantic import TypeAdapter
from sqlalchemy import DateTime, MetaData, func, select
from sqlalchemy import DateTime, MetaData, Table, func, select
from sqlalchemy.exc import OperationalError
from sqlalchemy.ext.asyncio import AsyncConnection, AsyncEngine, create_async_engine

if TYPE_CHECKING:
from sqlalchemy.orm import DeclarativeBase

from diracx.core.exceptions import InvalidQueryError
from diracx.core.extensions import select_from_extension
from diracx.core.models import (
Expand Down Expand Up @@ -233,7 +236,7 @@ async def ping(self):

async def _search(
self,
table: Any,
table: type[DeclarativeBase],
parameters: list[str] | None,
search: list[SearchSpec],
sorts: list[SortSpec],
Expand All @@ -244,12 +247,13 @@ async def _search(
) -> tuple[int, list[dict[str, Any]]]:
"""Search for elements in a table."""
# Find which columns to select
columns = _get_columns(table.__table__, parameters)
table_obj = cast(Table, table.__table__)
columns = _get_columns(table_obj, parameters)

stmt = select(*columns)

stmt = apply_search_filters(table.__table__.columns.__getitem__, stmt, search)
stmt = apply_sort_constraints(table.__table__.columns.__getitem__, stmt, sorts)
stmt = apply_search_filters(table_obj.columns.__getitem__, stmt, search)
stmt = apply_sort_constraints(table_obj.columns.__getitem__, stmt, sorts)

if distinct:
stmt = stmt.distinct()
Expand All @@ -273,20 +277,21 @@ async def _search(
]

async def _summary(
self, table: Any, group_by: list[str], search: list[SearchSpec]
self, table: type[DeclarativeBase], group_by: list[str], search: list[SearchSpec]
) -> list[dict[str, str | int]]:
"""Get a summary of the elements of a table."""
columns = _get_columns(table.__table__, group_by)
table_obj = cast(Table, table.__table__)
columns = _get_columns(table_obj, group_by)

pk_columns = list(table.__table__.primary_key.columns)
pk_columns = list(table_obj.primary_key.columns)
if not pk_columns:
raise ValueError(
"Model has no primary key and no count_column was provided."
)
count_col = pk_columns[0]

stmt = select(*columns, func.count(count_col).label("count"))
stmt = apply_search_filters(table.__table__.columns.__getitem__, stmt, search)
stmt = apply_search_filters(table_obj.columns.__getitem__, stmt, search)
stmt = stmt.group_by(*columns)

# Execute the query
Expand Down Expand Up @@ -327,7 +332,7 @@ def find_time_resolution(value):
raise InvalidQueryError(f"Cannot parse {value=}")


def _get_columns(table, parameters):
def _get_columns(table: Table, parameters: list[str] | None):
columns = [x for x in table.columns]
if parameters:
if unrecognised_parameters := set(parameters) - set(table.columns.keys()):
Expand Down
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from diracx.db.sql.job.db import JobDBBase
from diracx.db.sql.job.schema import JobDBBase
from diracx.db.sql.utils import Column
from sqlalchemy import (
ForeignKey,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,9 +2,11 @@
# in place of the SQLAlchemy one. Have a look at them
from diracx.db.sql.utils import Column, DateNowColumn
from sqlalchemy import ForeignKey, Integer, String, Uuid
from sqlalchemy.orm import declarative_base
from sqlalchemy.orm import DeclarativeBase

Base = declarative_base()

class Base(DeclarativeBase):
pass


class Owners(Base):
Expand Down
Loading