diff --git a/alembic/env.py b/alembic/env.py index 9cdc453d..fc1c5e0c 100644 --- a/alembic/env.py +++ b/alembic/env.py @@ -29,9 +29,13 @@ DB_DATABASE_NAME = os.getenv("DB_DATABASE_NAME") DB_USERNAME = os.getenv("DB_USERNAME") DB_PASSWORD = os.getenv("DB_PASSWORD") -config.set_section_option( - "alembic", "sqlalchemy.url", f"postgresql://{DB_USERNAME}:{DB_PASSWORD}@{DB_HOST}:{DB_PORT}/{DB_DATABASE_NAME}" -) + +if not DB_USERNAME: + DB_URL = f"postgresql:///{DB_DATABASE_NAME}" +else: + DB_URL = f"postgresql://{DB_USERNAME}:{DB_PASSWORD}@{DB_HOST}:{DB_PORT}/{DB_DATABASE_NAME}" + +config.set_section_option("alembic", "sqlalchemy.url", DB_URL) target_metadata = Base.metadata diff --git a/alembic/versions/0d3732aa62be_add_scoreset_search_materialized_view.py b/alembic/versions/0d3732aa62be_add_scoreset_search_materialized_view.py new file mode 100644 index 00000000..092c6cb6 --- /dev/null +++ b/alembic/versions/0d3732aa62be_add_scoreset_search_materialized_view.py @@ -0,0 +1,26 @@ +"""add scoreset_fulltext materialized view + +Revision ID: 0d3732aa62be +Revises: ec5d2787bec9 +Create Date: 2024-10-15 14:59:16.297975 + +""" +from alembic import op + +from mavedb.models.score_set_fulltext import _scoreset_fulltext_view + +# revision identifiers, used by Alembic. +revision = '0d3732aa62be' +down_revision = '1d4933b4b6f7' +branch_labels = None +depends_on = None + + +def upgrade(): + op.create_entity(_scoreset_fulltext_view) + op.execute("create index scoreset_fulltext_idx on scoreset_fulltext using gin (text)") + + +def downgrade(): + op.execute("drop index scoreset_fulltext_idx") + op.drop_entity(_scoreset_fulltext_view) diff --git a/poetry.lock b/poetry.lock index 38fcd06a..75b643f4 100644 --- a/poetry.lock +++ b/poetry.lock @@ -1,4 +1,4 @@ -# This file is automatically @generated by Poetry 1.7.0 and should not be changed by hand. +# This file is automatically @generated by Poetry 1.8.4 and should not be changed by hand. [[package]] name = "alembic" @@ -18,6 +18,28 @@ SQLAlchemy = ">=1.3.0" [package.extras] tz = ["python-dateutil"] +[[package]] +name = "alembic-utils" +version = "0.8.1" +description = "A sqlalchemy/alembic extension for migrating procedures and views" +optional = true +python-versions = ">=3.7" +files = [ + {file = "alembic_utils-0.8.1.tar.gz", hash = "sha256:073626217c8d8bdc66d1f66f8866d4f743969ac08502ba3bc15bcd60190460d7"}, +] + +[package.dependencies] +alembic = ">=1.5.7" +flupy = "*" +parse = ">=1.8.4" +sqlalchemy = ">=1.4" +typing_extensions = "*" + +[package.extras] +dev = ["black", "mkdocs", "mypy", "pre-commit", "psycopg2-binary", "pylint", "pytest", "pytest-cov"] +docs = ["mkautodoc", "mkdocs", "pygments", "pymdown-extensions"] +nvim = ["neovim", "python-language-server"] + [[package]] name = "anyio" version = "4.6.0" @@ -1296,6 +1318,22 @@ docs = ["furo (>=2024.8.6)", "sphinx (>=8.0.2)", "sphinx-autodoc-typehints (>=2. 
testing = ["covdefaults (>=2.3)", "coverage (>=7.6.1)", "diff-cover (>=9.2)", "pytest (>=8.3.3)", "pytest-asyncio (>=0.24)", "pytest-cov (>=5)", "pytest-mock (>=3.14)", "pytest-timeout (>=2.3.1)", "virtualenv (>=20.26.4)"] typing = ["typing-extensions (>=4.12.2)"] +[[package]] +name = "flupy" +version = "1.2.1" +description = "Method chaining built on generators" +optional = true +python-versions = "*" +files = [ + {file = "flupy-1.2.1.tar.gz", hash = "sha256:42aab3b4b3eb1984a4616c40d8f049ecdee546eaad9467470731d456dbff7fa4"}, +] + +[package.dependencies] +typing_extensions = ">=4" + +[package.extras] +dev = ["black", "mypy", "pre-commit", "pylint", "pytest", "pytest-benchmark", "pytest-cov"] + [[package]] name = "fqfa" version = "1.3.1" @@ -2371,6 +2409,17 @@ files = [ numpy = {version = ">=1.26.0", markers = "python_version < \"3.13\""} types-pytz = ">=2022.1.1" +[[package]] +name = "parse" +version = "1.20.2" +description = "parse() is the opposite of format()" +optional = true +python-versions = "*" +files = [ + {file = "parse-1.20.2-py2.py3-none-any.whl", hash = "sha256:967095588cb802add9177d0c0b6133b5ba33b1ea9007ca800e526f42a85af558"}, + {file = "parse-1.20.2.tar.gz", hash = "sha256:b41d604d16503c79d81af5165155c0b20f6c8d6c559efa66b4b695c3e5a0a0ce"}, +] + [[package]] name = "parsley" version = "1.3" @@ -4115,9 +4164,9 @@ test = ["big-O", "importlib-resources", "jaraco.functools", "jaraco.itertools", type = ["pytest-mypy"] [extras] -server = ["alembic", "arq", "authlib", "biocommons", "boto3", "cdot", "cryptography", "fastapi", "hgvs", "orcid", "psycopg2", "python-jose", "python-multipart", "requests", "slack-sdk", "starlette", "starlette-context", "uvicorn", "watchtower"] +server = ["alembic", "alembic-utils", "arq", "authlib", "biocommons", "boto3", "cdot", "cryptography", "fastapi", "hgvs", "orcid", "psycopg2", "python-jose", "python-multipart", "requests", "slack-sdk", "starlette", "starlette-context", "uvicorn", "watchtower"] [metadata] lock-version = "2.0" python-versions = "^3.9" -content-hash = "f5a4cedf018200abbbb7eebf9d2a51110454c5dac959d3ab0601bc185e2a351c" +content-hash = "683c9fb24adca5ab47f453e174aea72c8a9cec0a7672ac97d8cc4b94a107deee" diff --git a/pyproject.toml b/pyproject.toml index 98e8a828..5158fe90 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -41,6 +41,7 @@ SQLAlchemy = "~2.0.0" # Optional dependencies for running this application as a server alembic = { version = "~1.7.6", optional = true } +alembic-utils = { version = "0.8.1", optional = true } arq = { version = "~0.25.0", optional = true } authlib = { version = "~1.3.1", optional = true } boto3 = { version = "~1.34.97", optional = true } @@ -83,10 +84,12 @@ requests-mock = "~1.11.0" ruff = "^0.6.8" SQLAlchemy = { extras = ["mypy"], version = "~2.0.0" } - [tool.poetry.extras] -server = ["alembic", "arq", "authlib", "biocommons", "boto3", "cdot", "cryptography", "fastapi", "hgvs", "orcid", "psycopg2", "python-jose", "python-multipart", "requests", "starlette", "starlette-context", "slack-sdk", "uvicorn", "watchtower"] +server = ["alembic", "alembic-utils", "arq", "authlib", "biocommons", "boto3", "cdot", "cryptography", "fastapi", "hgvs", "orcid", "psycopg2", "python-jose", "python-multipart", "requests", "starlette", "starlette-context", "slack-sdk", "uvicorn", "watchtower"] +[tool.black] +extend-exclude = "alembic/versions" +line-length = 120 [tool.mypy] plugins = [ diff --git a/src/mavedb/db/session.py b/src/mavedb/db/session.py index ab75604a..c313d99f 100644 --- a/src/mavedb/db/session.py +++ 
b/src/mavedb/db/session.py @@ -11,7 +11,10 @@ DB_PASSWORD = os.getenv("DB_PASSWORD") # DB_URL = "sqlite:///./sql_app.db" -DB_URL = f"postgresql://{DB_USERNAME}:{DB_PASSWORD}@{DB_HOST}:{DB_PORT}/{DB_DATABASE_NAME}" +if not DB_USERNAME: + DB_URL = f"postgresql:///{DB_DATABASE_NAME}" +else: + DB_URL = f"postgresql://{DB_USERNAME}:{DB_PASSWORD}@{DB_HOST}:{DB_PORT}/{DB_DATABASE_NAME}" engine = create_engine( # For PostgreSQL: diff --git a/src/mavedb/lib/score_sets.py b/src/mavedb/lib/score_sets.py index 8384ecbf..b3f9cd5d 100644 --- a/src/mavedb/lib/score_sets.py +++ b/src/mavedb/lib/score_sets.py @@ -23,25 +23,22 @@ from mavedb.lib.validation.constants.general import null_values_list from mavedb.models.contributor import Contributor from mavedb.models.controlled_keyword import ControlledKeyword -from mavedb.models.doi_identifier import DoiIdentifier -from mavedb.models.ensembl_identifier import EnsemblIdentifier from mavedb.models.ensembl_offset import EnsemblOffset from mavedb.models.experiment import Experiment from mavedb.models.experiment_controlled_keyword import ExperimentControlledKeywordAssociation from mavedb.models.experiment_publication_identifier import ExperimentPublicationIdentifierAssociation from mavedb.models.experiment_set import ExperimentSet from mavedb.models.publication_identifier import PublicationIdentifier -from mavedb.models.refseq_identifier import RefseqIdentifier from mavedb.models.refseq_offset import RefseqOffset from mavedb.models.score_set import ScoreSet from mavedb.models.score_set_publication_identifier import ( ScoreSetPublicationIdentifierAssociation, ) +from mavedb.models.score_set_fulltext import scoreset_fulltext_filter from mavedb.models.target_accession import TargetAccession from mavedb.models.target_gene import TargetGene from mavedb.models.target_sequence import TargetSequence from mavedb.models.taxonomy import Taxonomy -from mavedb.models.uniprot_identifier import UniprotIdentifier from mavedb.models.uniprot_offset import UniprotOffset from mavedb.models.user import User from mavedb.models.variant import Variant @@ -86,74 +83,7 @@ def search_score_sets(db: Session, owner_or_contributor: Optional[User], search: query = query.filter(ScoreSet.published_date.is_(None)) if search.text: - lower_search_text = search.text.lower().strip() - query = query.filter( - or_( - ScoreSet.urn.icontains(lower_search_text), - ScoreSet.title.icontains(lower_search_text), - ScoreSet.short_description.icontains(lower_search_text), - ScoreSet.abstract_text.icontains(lower_search_text), - ScoreSet.target_genes.any(func.lower(TargetGene.name).icontains(lower_search_text)), - ScoreSet.target_genes.any(func.lower(TargetGene.category).icontains(lower_search_text)), - ScoreSet.target_genes.any( - TargetGene.target_sequence.has( - TargetSequence.taxonomy.has(func.lower(Taxonomy.organism_name).icontains(lower_search_text)) - ) - ), - ScoreSet.target_genes.any( - TargetGene.target_sequence.has( - TargetSequence.taxonomy.has(func.lower(Taxonomy.common_name).icontains(lower_search_text)) - ) - ), - ScoreSet.target_genes.any( - TargetGene.target_accession.has(func.lower(TargetAccession.assembly).icontains(lower_search_text)) - ), - # TODO(#94): add LICENSE, plus TAX_ID if numeric - ScoreSet.publication_identifiers.any( - func.lower(PublicationIdentifier.identifier).icontains(lower_search_text) - ), - ScoreSet.publication_identifiers.any( - func.lower(PublicationIdentifier.doi).icontains(lower_search_text) - ), - ScoreSet.publication_identifiers.any( - 
func.lower(PublicationIdentifier.abstract).icontains(lower_search_text) - ), - ScoreSet.publication_identifiers.any( - func.lower(PublicationIdentifier.title).icontains(lower_search_text) - ), - ScoreSet.publication_identifiers.any( - func.lower(PublicationIdentifier.publication_journal).icontains(lower_search_text) - ), - ScoreSet.publication_identifiers.any( - func.jsonb_path_exists( - PublicationIdentifier.authors, - f"""$[*].name ? (@ like_regex "{lower_search_text}" flag "i")""", - ) - ), - ScoreSet.doi_identifiers.any(func.lower(DoiIdentifier.identifier).icontains(lower_search_text)), - ScoreSet.target_genes.any( - TargetGene.uniprot_offset.has( - UniprotOffset.identifier.has( - func.lower(UniprotIdentifier.identifier).icontains(lower_search_text) - ) - ) - ), - ScoreSet.target_genes.any( - TargetGene.refseq_offset.has( - RefseqOffset.identifier.has( - func.lower(RefseqIdentifier.identifier).icontains(lower_search_text) - ) - ) - ), - ScoreSet.target_genes.any( - TargetGene.ensembl_offset.has( - EnsemblOffset.identifier.has( - func.lower(EnsemblIdentifier.identifier).icontains(lower_search_text) - ) - ) - ), - ) - ) + query = scoreset_fulltext_filter(query, search.text) if search.targets: query = query.filter(ScoreSet.target_genes.any(TargetGene.name.in_(search.targets))) diff --git a/src/mavedb/models/score_set_fulltext.py b/src/mavedb/models/score_set_fulltext.py new file mode 100644 index 00000000..06169ab8 --- /dev/null +++ b/src/mavedb/models/score_set_fulltext.py @@ -0,0 +1,78 @@ +import logging + +from sqlalchemy import text +from mavedb.models.score_set import ScoreSet +from alembic_utils.pg_materialized_view import PGMaterializedView # type: ignore + +logger = logging.getLogger(__name__) + +# TODO(#94): add LICENSE, plus TAX_ID if numeric +# TODO(#89): The query below should be generated from SQLAlchemy +# models rather than hand-carved SQL + +_scoreset_fulltext_view = PGMaterializedView( + schema="public", + signature="scoreset_fulltext", + definition=' union ' .join( + [ + f"select id as scoreset_id, to_tsvector({c}) as text from scoresets" + for c in ('urn', 'title', 'short_description', 'abstract_text') + ] + [ + f"select scoreset_id, to_tsvector({c}) as text from target_genes" + for c in ('name', 'category') + ] + [ + f"select scoreset_id, to_tsvector(TX.{c}) as text from target_genes TG join target_sequences TS on \ + (TG.target_sequence_id = TS.id) join taxonomies TX on (TS.taxonomy_id = TX.id)" + for c in ('organism_name', 'common_name') + ] + [ + "select scoreset_id, to_tsvector(TA.assembly) as text from target_genes TG join target_accessions TA on \ + (TG.accession_id = TA.id)" + ] + [ + f"select scoreset_id, to_tsvector(PI.{c}) as text from scoreset_publication_identifiers SPI JOIN \ + publication_identifiers PI ON (SPI.publication_identifier_id = PI.id)" + for c in ('identifier', 'doi', 'abstract', 'title', 'publication_journal') + ] + [ + "select scoreset_id, to_tsvector(jsonb_array_elements(authors)->'name') as text from \ + scoreset_publication_identifiers SPI join publication_identifiers PI on \ + SPI.publication_identifier_id = PI.id", + ] + [ + "select scoreset_id, to_tsvector(DI.identifier) as text from scoreset_doi_identifiers SD join \ + doi_identifiers DI on (SD.doi_identifier_id = DI.id)", + ] + [ + f"select scoreset_id, to_tsvector(XI.identifier) as text from target_genes TG join {x}_offsets XO on \ + (XO.target_gene_id = TG.id) join {x}_identifiers XI on (XI.id = XO.identifier_id)" + for x in ('uniprot', 'refseq', 'ensembl') + ] + ), + with_data=True 
+) + + +def scoreset_fulltext_create(session): + logger.warning("Creating %s", _scoreset_fulltext_view.signature) + session.execute( + _scoreset_fulltext_view.to_sql_statement_create() + ) + session.commit() + logger.warning("Created %s", _scoreset_fulltext_view.signature) + + +def scoreset_fulltext_destroy(session): + logger.warning("Destroying %s", _scoreset_fulltext_view.signature) + session.execute( + _scoreset_fulltext_view.to_sql_statement_drop() + ) + session.commit() + logger.warning("Destroyed %s", _scoreset_fulltext_view.signature) + + +def scoreset_fulltext_refresh(session): + session.execute(text(f'refresh materialized view {_scoreset_fulltext_view.signature}')) + session.commit() + + +def scoreset_fulltext_filter(query, string): + return query.filter(ScoreSet.id.in_( + text(f"select distinct scoreset_id from {_scoreset_fulltext_view.signature} \ + where text @@ websearch_to_tsquery(:text)").params(text=string) + )) diff --git a/src/mavedb/routers/score_sets.py b/src/mavedb/routers/score_sets.py index 1746e703..8aec0500 100644 --- a/src/mavedb/routers/score_sets.py +++ b/src/mavedb/routers/score_sets.py @@ -52,6 +52,7 @@ from mavedb.models.license import License from mavedb.models.mapped_variant import MappedVariant from mavedb.models.score_set import ScoreSet +from mavedb.models.score_set_fulltext import scoreset_fulltext_refresh from mavedb.models.target_accession import TargetAccession from mavedb.models.target_gene import TargetGene from mavedb.models.target_sequence import TargetSequence @@ -105,6 +106,19 @@ async def fetch_score_set_by_urn( return item +async def _refresh_scoreset_fulltext( + worker: ArqRedis, + item_id: Optional[int], +): + job = await worker.enqueue_job( + "refresh_scoreset_fulltext", + correlation_id_for_context(), + item_id, + ) + if job is not None: + save_to_logging_context({"refresh_scoreset_fulltext worker_job_id": job.job_id}) + + router = APIRouter( prefix="/api/v1", tags=["score sets"], @@ -319,6 +333,7 @@ async def create_score_set( item_create: score_set.ScoreSetCreate, db: Session = Depends(deps.get_db), user_data: UserData = Depends(require_current_user_with_email), + worker: ArqRedis = Depends(deps.get_worker), ) -> Any: """ Create a score set. @@ -583,6 +598,9 @@ async def create_score_set( db.refresh(item) save_to_logging_context({"created_resource": item.urn}) + + await _refresh_scoreset_fulltext(worker, item.id) + return item @@ -628,6 +646,7 @@ async def upload_score_set_variant_data( item.processing_state = ProcessingState.processing # await the insertion of this job into the worker queue, not the job itself. + logger.warning("enqueue create_variants_for_score_set %s", worker) job = await worker.enqueue_job( "create_variants_for_score_set", correlation_id_for_context(), @@ -872,6 +891,8 @@ async def update_score_set( # races the score set GET request). item.processing_state = ProcessingState.processing + logger.warning("ENQUEUE create_variants_for_score_set %s", item.id) + # await the insertion of this job into the worker queue, not the job itself. job = await worker.enqueue_job( "create_variants_for_score_set", @@ -919,6 +940,9 @@ async def update_score_set( db.refresh(item) save_to_logging_context({"updated_resource": item.urn}) + + await _refresh_scoreset_fulltext(worker, item.id) + return item @@ -928,6 +952,7 @@ async def delete_score_set( urn: str, db: Session = Depends(deps.get_db), user_data: UserData = Depends(require_current_user), + worker: ArqRedis = Depends(deps.get_worker), ) -> Any: """ Delete a score set. 
@@ -954,6 +979,8 @@ async def delete_score_set( db.delete(item) db.commit() + await _refresh_scoreset_fulltext(worker, item.id) + @router.post( "/score-sets/{urn}/publish", @@ -966,6 +993,7 @@ def publish_score_set( urn: str, db: Session = Depends(deps.get_db), user_data: UserData = Depends(require_current_user), + worker: ArqRedis = Depends(deps.get_worker), ) -> Any: """ Publish a score set. diff --git a/src/mavedb/worker/jobs.py b/src/mavedb/worker/jobs.py index fed261c8..1d1a8717 100644 --- a/src/mavedb/worker/jobs.py +++ b/src/mavedb/worker/jobs.py @@ -30,9 +30,9 @@ from mavedb.models.enums.processing_state import ProcessingState from mavedb.models.mapped_variant import MappedVariant from mavedb.models.score_set import ScoreSet +from mavedb.models.score_set_fulltext import scoreset_fulltext_refresh from mavedb.models.user import User from mavedb.models.variant import Variant - logger = logging.getLogger(__name__) MAPPING_QUEUE_NAME = "vrs_mapping_queue" @@ -653,3 +653,12 @@ async def variant_mapper_manager(ctx: dict, correlation_id: str, updater_id: int db.commit() return {"success": False, "enqueued_job": new_job_id} + + +async def refresh_scoreset_fulltext(ctx: dict, correlation_id: str, item_id: Optional[int]) -> dict: + + logger.info("refresh_scoreset_fulltext %s running", item_id) + scoreset_fulltext_refresh(ctx["db"]) + logger.info("refresh_scoreset_fulltext %s finished", item_id) + + return {"success": True} diff --git a/src/mavedb/worker/settings.py b/src/mavedb/worker/settings.py index d91e48b8..83553105 100644 --- a/src/mavedb/worker/settings.py +++ b/src/mavedb/worker/settings.py @@ -8,13 +8,14 @@ from mavedb.data_providers.services import cdot_rest from mavedb.db.session import SessionLocal from mavedb.lib.logging.canonical import log_job -from mavedb.worker.jobs import create_variants_for_score_set, map_variants_for_score_set, variant_mapper_manager +from mavedb.worker.jobs import create_variants_for_score_set, map_variants_for_score_set, variant_mapper_manager, refresh_scoreset_fulltext # ARQ requires at least one task on startup.
BACKGROUND_FUNCTIONS: list[Callable] = [ create_variants_for_score_set, variant_mapper_manager, map_variants_for_score_set, + refresh_scoreset_fulltext, ] BACKGROUND_CRONJOBS: list[CronJob] = [] diff --git a/tests/conftest.py b/tests/conftest.py index b58a5dd9..0b70fcc0 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -4,6 +4,7 @@ from inspect import getsourcefile from os.path import abspath from unittest.mock import patch +import logging import cdot.hgvs.dataproviders import email_validator @@ -28,11 +29,14 @@ from mavedb.models.user import User from mavedb.server_main import app from mavedb.worker.jobs import create_variants_for_score_set, map_variants_for_score_set, variant_mapper_manager +from mavedb.models.score_set_fulltext import scoreset_fulltext_create, scoreset_fulltext_destroy sys.path.append(".") from tests.helpers.constants import ADMIN_USER, TEST_USER +logger = logging.getLogger(__name__) + # needs the pytest_postgresql plugin installed assert pytest_postgresql.factories @@ -51,10 +55,12 @@ def session(postgresql): session = sessionmaker(autocommit=False, autoflush=False, bind=engine)() Base.metadata.create_all(bind=engine) + scoreset_fulltext_create(session) try: yield session finally: + scoreset_fulltext_destroy(session) session.close() Base.metadata.drop_all(bind=engine) diff --git a/tests/helpers/util.py b/tests/helpers/util.py index cda39c99..f26f79e5 100644 --- a/tests/helpers/util.py +++ b/tests/helpers/util.py @@ -68,10 +68,12 @@ def create_seq_score_set(client, experiment_urn, update=None): score_set_payload.update(update) jsonschema.validate(instance=score_set_payload, schema=ScoreSetCreate.schema()) - response = client.post("/api/v1/score-sets/", json=score_set_payload) - assert ( - response.status_code == 200 - ), f"Could not create sequence based score set (no variants) within experiment {experiment_urn}" + with patch.object(ArqRedis, "enqueue_job", return_value=None) as queue: + response = client.post("/api/v1/score-sets/", json=score_set_payload) + assert ( + response.status_code == 200 + ), f"Could not create sequence based score set (no variants) within experiment {experiment_urn}" + queue.assert_called_once() response_data = response.json() jsonschema.validate(instance=response_data, schema=ScoreSet.schema()) @@ -86,12 +88,14 @@ def create_acc_score_set(client, experiment_urn, update=None): score_set_payload.update(update) jsonschema.validate(instance=score_set_payload, schema=ScoreSetCreate.schema()) - with patch.object(cdot.hgvs.dataproviders.RESTDataProvider, "_get_transcript", return_value=TEST_CDOT_TRANSCRIPT): + with patch.object(cdot.hgvs.dataproviders.RESTDataProvider, "_get_transcript", return_value=TEST_CDOT_TRANSCRIPT), \ + patch.object(ArqRedis, "enqueue_job", return_value=None) as queue: response = client.post("/api/v1/score-sets/", json=score_set_payload) assert ( response.status_code == 200 ), f"Could not create accession based score set (no variants) within experiment {experiment_urn}" + queue.assert_called_once() response_data = response.json() jsonschema.validate(instance=response_data, schema=ScoreSet.schema()) diff --git a/tests/routers/test_score_set.py b/tests/routers/test_score_set.py index 44207f97..d91a9ec9 100644 --- a/tests/routers/test_score_set.py +++ b/tests/routers/test_score_set.py @@ -2,6 +2,7 @@ from copy import deepcopy from datetime import date from unittest.mock import patch +import logging import jsonschema from arq import ArqRedis @@ -10,6 +11,7 @@ from mavedb.models.enums.processing_state import 
ProcessingState from mavedb.models.experiment import Experiment as ExperimentDbModel from mavedb.models.score_set import ScoreSet as ScoreSetDbModel +from mavedb.models.score_set_fulltext import scoreset_fulltext_refresh from mavedb.view_models.orcid import OrcidUser from mavedb.view_models.score_set import ScoreSet, ScoreSetCreate from tests.helpers.constants import ( @@ -30,6 +32,8 @@ ) +logger = logging.getLogger(__name__) + def test_TEST_MINIMAL_SEQ_SCORESET_is_valid(): jsonschema.validate(instance=TEST_MINIMAL_SEQ_SCORESET, schema=ScoreSetCreate.schema()) @@ -39,11 +43,13 @@ def test_TEST_MINIMAL_ACC_SCORESET_is_valid(): def test_create_minimal_score_set(client, setup_router_db): - experiment = create_experiment(client) - score_set_post_payload = deepcopy(TEST_MINIMAL_SEQ_SCORESET) - score_set_post_payload["experimentUrn"] = experiment["urn"] - response = client.post("/api/v1/score-sets/", json=score_set_post_payload) - assert response.status_code == 200 + with patch.object(ArqRedis, "enqueue_job", return_value=None) as queue: + experiment = create_experiment(client) + score_set_post_payload = deepcopy(TEST_MINIMAL_SEQ_SCORESET) + score_set_post_payload["experimentUrn"] = experiment["urn"] + response = client.post("/api/v1/score-sets/", json=score_set_post_payload) + assert response.status_code == 200 + queue.assert_called_once() response_data = response.json() jsonschema.validate(instance=response_data, schema=ScoreSet.schema()) assert isinstance(MAVEDB_TMP_URN_RE.fullmatch(response_data["urn"]), re.Match) @@ -69,11 +75,13 @@ def test_create_score_set_with_contributor(client, setup_router_db): score_set["experimentUrn"] = experiment["urn"] score_set.update({"contributors": [{"orcid_id": TEST_ORCID_ID}]}) - with patch( - "mavedb.lib.orcid.fetch_orcid_user", - lambda orcid_id: OrcidUser(orcid_id=orcid_id, given_name="ORCID", family_name="User"), - ): + def orcid_user_mock(orcid_id): + return OrcidUser(orcid_id=orcid_id, given_name="ORCID", family_name="User") + + with patch.object(ArqRedis, "enqueue_job", return_value=None) as queue, \ + patch("mavedb.lib.orcid.fetch_orcid_user", orcid_user_mock): response = client.post("/api/v1/score-sets/", json=score_set) + queue.assert_called_once() assert response.status_code == 200 response_data = response.json() @@ -124,8 +132,10 @@ def test_create_score_set_with_score_range(client, setup_router_db): } ) - response = client.post("/api/v1/score-sets/", json=score_set) - assert response.status_code == 200 + with patch.object(ArqRedis, "enqueue_job", return_value=None) as queue: + response = client.post("/api/v1/score-sets/", json=score_set) + assert response.status_code == 200 + queue.assert_called_once() response_data = response.json() jsonschema.validate(instance=response_data, schema=ScoreSet.schema()) @@ -979,6 +989,9 @@ def test_search_score_sets_no_match(session, data_provider, client, setup_router update={"title": "Test Score Set"}, ) + # this would be run asynchronously but that is mocked out so we run it manually + scoreset_fulltext_refresh(session) + search_payload = {"text": "fnord"} response = client.post("/api/v1/score-sets/search", json=search_payload) assert response.status_code == 200 @@ -996,6 +1009,9 @@ def test_search_score_sets_match(session, data_provider, client, setup_router_db update={"title": "Test Fnord Score Set"}, ) + # this would be run asynchronously but that is mocked out so we run it manually + scoreset_fulltext_refresh(session) + search_payload = {"text": "fnord"} response = client.post("/api/v1/score-sets/search", 
json=search_payload) assert response.status_code == 200 @@ -1069,9 +1085,11 @@ def test_can_delete_own_private_scoreset(session, data_provider, client, setup_r client, session, data_provider, experiment["urn"], data_files / "scores.csv" ) - response = client.delete(f"/api/v1/score-sets/{score_set['urn']}") + with patch.object(ArqRedis, "enqueue_job", return_value=None) as queue: + response = client.delete(f"/api/v1/score-sets/{score_set['urn']}") assert response.status_code == 200 + queue.assert_called_once() def test_cannot_delete_own_published_scoreset(session, data_provider, client, setup_router_db, data_files): @@ -1106,9 +1124,11 @@ def test_contributor_can_delete_other_users_private_scoreset( TEST_USER["last_name"], ) - response = client.delete(f"/api/v1/score-sets/{score_set['urn']}") + with patch.object(ArqRedis, "enqueue_job", return_value=None) as queue: + response = client.delete(f"/api/v1/score-sets/{score_set['urn']}") assert response.status_code == 200 + queue.assert_called_once() def test_admin_can_delete_other_users_private_scoreset( @@ -1119,10 +1139,11 @@ def test_admin_can_delete_other_users_private_scoreset( client, session, data_provider, experiment["urn"], data_files / "scores.csv" ) - with DependencyOverrider(admin_app_overrides): + with DependencyOverrider(admin_app_overrides), patch.object(ArqRedis, "enqueue_job", return_value=None) as queue: response = client.delete(f"/api/v1/score-sets/{score_set['urn']}") assert response.status_code == 200 + queue.assert_called_once() def test_admin_can_delete_other_users_published_scoreset( @@ -1135,18 +1156,21 @@ def test_admin_can_delete_other_users_published_scoreset( response = client.post(f"/api/v1/score-sets/{score_set['urn']}/publish") response_data = response.json() - with DependencyOverrider(admin_app_overrides): + with DependencyOverrider(admin_app_overrides), patch.object(ArqRedis, "enqueue_job", return_value=None) as queue: del_response = client.delete(f"/api/v1/score-sets/{response_data['urn']}") assert del_response.status_code == 200 + queue.assert_called_once() def test_can_add_score_set_to_own_private_experiment(session, client, setup_router_db): experiment = create_experiment(client) score_set_post_payload = deepcopy(TEST_MINIMAL_SEQ_SCORESET) score_set_post_payload["experimentUrn"] = experiment["urn"] - response = client.post("/api/v1/score-sets/", json=score_set_post_payload) - assert response.status_code == 200 + with patch.object(ArqRedis, "enqueue_job", return_value=None) as queue: + response = client.post("/api/v1/score-sets/", json=score_set_post_payload) + assert response.status_code == 200 + queue.assert_called_once() def test_cannot_add_score_set_to_others_private_experiment(session, client, setup_router_db): @@ -1169,8 +1193,10 @@ def test_can_add_score_set_to_own_public_experiment(session, data_provider, clie pub_score_set_1 = client.post(f"/api/v1/score-sets/{score_set_1['urn']}/publish").json() score_set_2 = deepcopy(TEST_MINIMAL_SEQ_SCORESET) score_set_2["experimentUrn"] = pub_score_set_1["experiment"]["urn"] - response = client.post("/api/v1/score-sets/", json=score_set_2) - assert response.status_code == 200 + with patch.object(ArqRedis, "enqueue_job", return_value=None) as queue: + response = client.post("/api/v1/score-sets/", json=score_set_2) + assert response.status_code == 200 + queue.assert_called_once() def test_can_add_score_set_to_others_public_experiment(session, data_provider, client, setup_router_db, data_files): @@ -1182,8 +1208,10 @@ def 
test_can_add_score_set_to_others_public_experiment(session, data_provider, c change_ownership(session, pub_score_set_1["experiment"]["urn"], ExperimentDbModel) score_set_2 = deepcopy(TEST_MINIMAL_SEQ_SCORESET) score_set_2["experimentUrn"] = pub_score_set_1["experiment"]["urn"] - response = client.post("/api/v1/score-sets/", json=score_set_2) - assert response.status_code == 200 + with patch.object(ArqRedis, "enqueue_job", return_value=None) as queue: + response = client.post("/api/v1/score-sets/", json=score_set_2) + assert response.status_code == 200 + queue.assert_called_once() def test_contributor_can_add_score_set_to_others_private_experiment(session, client, setup_router_db): @@ -1199,8 +1227,10 @@ def test_contributor_can_add_score_set_to_others_private_experiment(session, cli ) score_set_post_payload = deepcopy(TEST_MINIMAL_SEQ_SCORESET) score_set_post_payload["experimentUrn"] = experiment["urn"] - response = client.post("/api/v1/score-sets/", json=score_set_post_payload) - assert response.status_code == 200 + with patch.object(ArqRedis, "enqueue_job", return_value=None) as queue: + response = client.post("/api/v1/score-sets/", json=score_set_post_payload) + assert response.status_code == 200 + queue.assert_called_once() def test_contributor_can_add_score_set_to_others_public_experiment( @@ -1222,5 +1252,7 @@ def test_contributor_can_add_score_set_to_others_public_experiment( ) score_set_post_payload = deepcopy(TEST_MINIMAL_SEQ_SCORESET) score_set_post_payload["experimentUrn"] = published_score_set["experiment"]["urn"] - response = client.post("/api/v1/score-sets/", json=score_set_post_payload) - assert response.status_code == 200 + with patch.object(ArqRedis, "enqueue_job", return_value=None) as queue: + response = client.post("/api/v1/score-sets/", json=score_set_post_payload) + assert response.status_code == 200 + queue.assert_called_once() diff --git a/tests/worker/test_jobs.py b/tests/worker/test_jobs.py index 18e0846a..bb1eedb6 100644 --- a/tests/worker/test_jobs.py +++ b/tests/worker/test_jobs.py @@ -30,6 +30,7 @@ create_variants_for_score_set, map_variants_for_score_set, variant_mapper_manager, + refresh_scoreset_fulltext, ) from tests.helpers.constants import ( TEST_CDOT_TRANSCRIPT, @@ -53,8 +54,10 @@ async def setup_records_and_files(async_client, data_files, input_score_set): score_set_payload = deepcopy(input_score_set) score_set_payload["experimentUrn"] = experiment["urn"] jsonschema.validate(instance=score_set_payload, schema=ScoreSetCreate.schema()) - score_set_response = await async_client.post("/api/v1/score-sets/", json=score_set_payload) - assert score_set_response.status_code == 200 + with patch.object(ArqRedis, "enqueue_job", return_value=None) as queue: + score_set_response = await async_client.post("/api/v1/score-sets/", json=score_set_payload) + assert score_set_response.status_code == 200 + queue.assert_called_once() score_set = score_set_response.json() jsonschema.validate(instance=score_set, schema=ScoreSet.schema()) @@ -1457,3 +1460,18 @@ async def failed_mapping_job(): assert len(mapped_variants_for_score_set) == 0 assert score_set.mapping_state == MappingState.failed assert score_set.mapping_errors is not None + + +@pytest.mark.asyncio +@pytest.mark.skip +async def test_refresh_scoreset_fulltext( + setup_worker_db, standalone_worker_context, session, async_client, data_files, arq_worker, arq_redis +): + with ( + patch('mavedb.models.score_set_fulltext.scoreset_fulltext_refresh') as mock + ): + await arq_redis.enqueue_job("refresh_scoreset_fulltext", 
{"db": None}, -1) + await arq_worker.async_run() + await arq_worker.run_check() + mock.assert_run_once() +