Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 10 additions & 3 deletions src/api/endpoints/search/agency/query.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,14 @@
from typing import Sequence

from sqlalchemy import select, func, RowMapping, or_
from sqlalchemy import select, func, RowMapping
from sqlalchemy.ext.asyncio import AsyncSession

from src.api.endpoints.search.agency.ctes.with_location_id import WithLocationIdCTEContainer
from src.api.endpoints.search.agency.models.response import AgencySearchResponse
from src.db.helpers.session import session_helper as sh
from src.db.models.impl.agency.enums import JurisdictionType
from src.db.models.impl.agency.sqlalchemy import Agency
from src.db.models.impl.link.agency_location.sqlalchemy import LinkAgencyLocation
from src.db.models.views.dependent_locations import DependentLocationView
from src.db.models.views.location_expanded import LocationExpandedView
from src.db.queries.base.builder import QueryBuilderBase

Expand All @@ -18,11 +18,13 @@ class SearchAgencyQueryBuilder(QueryBuilderBase):
def __init__(
self,
location_id: int | None,
query: str | None
query: str | None,
jurisdiction_type: JurisdictionType | None,
):
super().__init__()
self.location_id = location_id
self.query = query
self.jurisdiction_type = jurisdiction_type

async def run(self, session: AsyncSession) -> list[AgencySearchResponse]:

Expand Down Expand Up @@ -51,6 +53,11 @@ async def run(self, session: AsyncSession) -> list[AgencySearchResponse]:
LocationExpandedView.id == with_location_id_cte_container.location_id
)

if self.jurisdiction_type is not None:
query = query.where(
Agency.jurisdiction_type == self.jurisdiction_type
)

if self.query is not None:
query = query.order_by(
func.similarity(
Expand Down
8 changes: 7 additions & 1 deletion src/api/endpoints/search/routes.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
from src.api.endpoints.search.agency.query import SearchAgencyQueryBuilder
from src.api.endpoints.search.dtos.response import SearchURLResponse
from src.core.core import AsyncCore
from src.db.models.impl.agency.enums import JurisdictionType
from src.security.manager import get_access_info
from src.security.dtos.access_info import AccessInfo

Expand Down Expand Up @@ -35,6 +36,10 @@ async def search_agency(
description="The query to search for",
default=None
),
jurisdiction_type: JurisdictionType | None = Query(
description="The jurisdiction type to search for",
default=None
),
access_info: AccessInfo = Depends(get_access_info),
async_core: AsyncCore = Depends(get_async_core),
) -> list[AgencySearchResponse]:
Expand All @@ -47,6 +52,7 @@ async def search_agency(
return await async_core.adb_client.run_query_builder(
SearchAgencyQueryBuilder(
location_id=location_id,
query=query
query=query,
jurisdiction_type=jurisdiction_type
)
)