From 3026bede1c5ba5532fe280fbdcf92c8583f565a1 Mon Sep 17 00:00:00 2001 From: Max Chis Date: Fri, 26 Sep 2025 08:15:04 -0400 Subject: [PATCH] Add dependent location logic --- ...18-7b955c783e27_add_dependent_locations.py | 56 +++++++++++++++++++ .../endpoints/search/agency/ctes/__init__.py | 0 .../search/agency/ctes/with_location_id.py | 48 ++++++++++++++++ src/api/endpoints/search/agency/query.py | 25 ++++++--- src/db/models/views/dependent_locations.py | 54 ++++++++++++++++++ .../api/search/agency/test_search.py | 12 +++- 6 files changed, 185 insertions(+), 10 deletions(-) create mode 100644 alembic/versions/2025_09_26_0718-7b955c783e27_add_dependent_locations.py create mode 100644 src/api/endpoints/search/agency/ctes/__init__.py create mode 100644 src/api/endpoints/search/agency/ctes/with_location_id.py create mode 100644 src/db/models/views/dependent_locations.py diff --git a/alembic/versions/2025_09_26_0718-7b955c783e27_add_dependent_locations.py b/alembic/versions/2025_09_26_0718-7b955c783e27_add_dependent_locations.py new file mode 100644 index 00000000..e27633fe --- /dev/null +++ b/alembic/versions/2025_09_26_0718-7b955c783e27_add_dependent_locations.py @@ -0,0 +1,56 @@ +"""Add dependent locations + +Revision ID: 7b955c783e27 +Revises: 3687026267fc +Create Date: 2025-09-26 07:18:37.916841 + +""" +from typing import Sequence, Union + +from alembic import op +import sqlalchemy as sa + + +# revision identifiers, used by Alembic. +revision: str = '7b955c783e27' +down_revision: Union[str, None] = '3687026267fc' +branch_labels: Union[str, Sequence[str], None] = None +depends_on: Union[str, Sequence[str], None] = None + + +def upgrade() -> None: + op.execute(""" + create view dependent_locations(parent_location_id, dependent_location_id) as + SELECT + lp.id AS parent_location_id, + ld.id AS dependent_location_id + FROM + locations lp + JOIN locations ld ON ld.state_id = lp.state_id AND ld.type = 'County'::location_type AND lp.type = 'State'::location_type + UNION ALL + SELECT + lp.id AS parent_location_id, + ld.id AS dependent_location_id + FROM + locations lp + JOIN locations ld ON ld.county_id = lp.county_id AND ld.type = 'Locality'::location_type AND lp.type = 'County'::location_type + UNION ALL + SELECT + lp.id AS parent_location_id, + ld.id AS dependent_location_id + FROM + locations lp + JOIN locations ld ON ld.state_id = lp.state_id AND ld.type = 'Locality'::location_type AND lp.type = 'State'::location_type + UNION ALL + SELECT + lp.id AS parent_location_id, + ld.id AS dependent_location_id + FROM + locations lp + JOIN locations ld ON lp.type = 'National'::location_type AND (ld.type = ANY + (ARRAY ['State'::location_type, 'County'::location_type, 'Locality'::location_type])) + """) + + +def downgrade() -> None: + pass diff --git a/src/api/endpoints/search/agency/ctes/__init__.py b/src/api/endpoints/search/agency/ctes/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/src/api/endpoints/search/agency/ctes/with_location_id.py b/src/api/endpoints/search/agency/ctes/with_location_id.py new file mode 100644 index 00000000..345cb245 --- /dev/null +++ b/src/api/endpoints/search/agency/ctes/with_location_id.py @@ -0,0 +1,48 @@ +from sqlalchemy import select, literal, CTE, Column + +from src.db.models.impl.link.agency_location.sqlalchemy import LinkAgencyLocation +from src.db.models.views.dependent_locations import DependentLocationView + + +class WithLocationIdCTEContainer: + + def __init__(self, location_id: int): + + target_locations_cte = ( + select( + literal(location_id).label("location_id") + ) + .union( + select( + DependentLocationView.dependent_location_id + ) + .where( + DependentLocationView.parent_location_id == location_id + ) + ) + .cte("target_locations") + ) + + self._cte = ( + select( + LinkAgencyLocation.agency_id, + LinkAgencyLocation.location_id + ) + .join( + target_locations_cte, + target_locations_cte.c.location_id == LinkAgencyLocation.location_id + ) + .cte("with_location_id") + ) + + @property + def cte(self) -> CTE: + return self._cte + + @property + def agency_id(self) -> Column: + return self._cte.c.agency_id + + @property + def location_id(self) -> Column: + return self._cte.c.location_id \ No newline at end of file diff --git a/src/api/endpoints/search/agency/query.py b/src/api/endpoints/search/agency/query.py index d3bda3ef..6048468a 100644 --- a/src/api/endpoints/search/agency/query.py +++ b/src/api/endpoints/search/agency/query.py @@ -1,12 +1,14 @@ from typing import Sequence -from sqlalchemy import select, func, RowMapping +from sqlalchemy import select, func, RowMapping, or_ from sqlalchemy.ext.asyncio import AsyncSession +from src.api.endpoints.search.agency.ctes.with_location_id import WithLocationIdCTEContainer from src.api.endpoints.search.agency.models.response import AgencySearchResponse from src.db.helpers.session import session_helper as sh from src.db.models.impl.agency.sqlalchemy import Agency from src.db.models.impl.link.agency_location.sqlalchemy import LinkAgencyLocation +from src.db.models.views.dependent_locations import DependentLocationView from src.db.models.views.location_expanded import LocationExpandedView from src.db.queries.base.builder import QueryBuilderBase @@ -30,20 +32,25 @@ async def run(self, session: AsyncSession) -> list[AgencySearchResponse]: Agency.name.label("agency_name"), LocationExpandedView.display_name.label("location_display_name") ) - .join( + ) + if self.location_id is None: + query = query.join( LinkAgencyLocation, LinkAgencyLocation.agency_id == Agency.agency_id - ) - .join( + ).join( LocationExpandedView, LocationExpandedView.id == LinkAgencyLocation.location_id ) - ) - - if self.location_id is not None: - query = query.where( - LocationExpandedView.id == self.location_id + else: + with_location_id_cte_container = WithLocationIdCTEContainer(self.location_id) + query = query.join( + with_location_id_cte_container.cte, + with_location_id_cte_container.agency_id == Agency.agency_id + ).join( + LocationExpandedView, + LocationExpandedView.id == with_location_id_cte_container.location_id ) + if self.query is not None: query = query.order_by( func.similarity( diff --git a/src/db/models/views/dependent_locations.py b/src/db/models/views/dependent_locations.py new file mode 100644 index 00000000..95f3db98 --- /dev/null +++ b/src/db/models/views/dependent_locations.py @@ -0,0 +1,54 @@ +""" +create view dependent_locations(parent_location_id, dependent_location_id) as +SELECT + lp.id AS parent_location_id, + ld.id AS dependent_location_id +FROM + locations lp + JOIN locations ld ON ld.state_id = lp.state_id AND ld.type = 'County'::location_type AND lp.type = 'State'::location_type +UNION ALL +SELECT + lp.id AS parent_location_id, + ld.id AS dependent_location_id +FROM + locations lp + JOIN locations ld ON ld.county_id = lp.county_id AND ld.type = 'Locality'::location_type AND lp.type = 'County'::location_type +UNION ALL +SELECT + lp.id AS parent_location_id, + ld.id AS dependent_location_id +FROM + locations lp + JOIN locations ld ON ld.state_id = lp.state_id AND ld.type = 'Locality'::location_type AND lp.type = 'State'::location_type +UNION ALL +SELECT + lp.id AS parent_location_id, + ld.id AS dependent_location_id +FROM + locations lp + JOIN locations ld ON lp.type = 'National'::location_type AND (ld.type = ANY + (ARRAY ['State'::location_type, 'County'::location_type, 'Locality'::location_type])); +""" +from sqlalchemy import Column, Integer, ForeignKey + +from src.db.models.mixins import ViewMixin +from src.db.models.templates_.base import Base + + +class DependentLocationView(Base, ViewMixin): + + __tablename__ = "dependent_locations" + __table_args__ = ( + {"info": "view"} + ) + + parent_location_id = Column( + Integer, + ForeignKey("locations.id"), + primary_key=True, + ) + dependent_location_id = Column( + Integer, + ForeignKey("locations.id"), + primary_key=True + ) diff --git a/tests/automated/integration/api/search/agency/test_search.py b/tests/automated/integration/api/search/agency/test_search.py index 7b475ace..cc3fee19 100644 --- a/tests/automated/integration/api/search/agency/test_search.py +++ b/tests/automated/integration/api/search/agency/test_search.py @@ -38,6 +38,7 @@ async def test_search_agency( assert responses[1]["agency_id"] == agency_b_id assert responses[2]["agency_id"] == agency_c_id + # Filter based on location ID responses = api_test_helper.request_validator.get_v2( url="/search/agency", params={ @@ -50,4 +51,13 @@ async def test_search_agency( assert responses[0]["agency_id"] == agency_a_id assert responses[1]["agency_id"] == agency_c_id - + # Filter again based on location ID but with Allegheny County + # Confirm pittsburgh agencies are picked up + responses = api_test_helper.request_validator.get_v2( + url="/search/agency", + params={ + "query": "A Agency", + "location_id": allegheny_county.location_id + } + ) + assert len(responses) == 3