diff --git a/.github/workflows/test_app.yml b/.github/workflows/test_app.yml index 8005b33d..02df2d67 100644 --- a/.github/workflows/test_app.yml +++ b/.github/workflows/test_app.yml @@ -3,16 +3,52 @@ name: Test Source Collector App on: pull_request +#jobs: +# build: +# runs-on: ubuntu-latest +# steps: +# - name: Checkout repository +# uses: actions/checkout@v4 +# - name: Run docker-compose +# uses: hoverkraft-tech/compose-action@v2.0.1 +# with: +# compose-file: "docker-compose.yml" +# - name: Execute tests in the running service +# run: | +# docker ps -a && docker exec data-source-identification-app-1 pytest /app/tests/test_automated + jobs: - build: - runs-on: ubuntu-latest - steps: - - name: Checkout repository - uses: actions/checkout@v4 - - name: Run docker-compose - uses: hoverkraft-tech/compose-action@v2.0.1 - with: - compose-file: "docker-compose.yml" - - name: Execute tests in the running service - run: | - docker exec data-source-identification-app-1 pytest /app/tests/test_automated \ No newline at end of file + container-job: + runs-on: ubuntu-latest + container: python:3.12.8 + + services: + postgres: + image: postgres:15 + env: + POSTGRES_PASSWORD: postgres + options: >- + --health-cmd pg_isready + --health-interval 10s + --health-timeout 5s + --health-retries 5 + + steps: + - name: Checkout repository + uses: actions/checkout@v4 + - name: Install dependencies + run: | + python -m pip install --upgrade pip + pip install -r requirements.txt + - name: Run tests + run: | + pytest tests/test_automated + pytest tests/test_alembic + env: + POSTGRES_PASSWORD: postgres + POSTGRES_USER: postgres + POSTGRES_DB: postgres + POSTGRES_HOST: postgres + POSTGRES_PORT: 5432 + GOOGLE_API_KEY: TEST + GOOGLE_CSE_ID: TEST diff --git a/Dockerfile b/Dockerfile index b7e9a5b8..8e64b85d 100644 --- a/Dockerfile +++ b/Dockerfile @@ -14,5 +14,4 @@ RUN pip install --no-cache-dir -r requirements.txt # Expose the application port EXPOSE 80 -# Run FastAPI app with uvicorn -CMD ["uvicorn", "api.main:app", "--host", "0.0.0.0", "--port", "80"] \ No newline at end of file +RUN chmod +x execute.sh \ No newline at end of file diff --git a/ENV.md b/ENV.md index b3d6af1e..a8210fb9 100644 --- a/ENV.md +++ b/ENV.md @@ -15,3 +15,4 @@ Please ensure these are properly defined in a `.env` file in the root directory. |`POSTGRES_HOST` | The host for the test database | `127.0.0.1` | |`POSTGRES_PORT` | The port for the test database | `5432` | |`DS_APP_SECRET_KEY`| The secret key used for decoding JWT tokens produced by the Data Sources App. Must match the secret token that is used in the Data Sources App for encoding. |`abc123`| +|`DEV`| Set to any value to run the application in development mode. |`true`| diff --git a/README.md b/README.md index b37d9cbc..5a39d2bd 100644 --- a/README.md +++ b/README.md @@ -45,6 +45,12 @@ This can be done via the following command: docker compose up -d ``` +Following that, you will need to set up the uvicorn server using the following command: + +```bash +docker exec data-source-identification-app-1 uvicorn api.main:app --host 0.0.0.0 --port 80 +``` + Note that while the container may mention the web app running on `0.0.0.0:8000`, the actual host may be `127.0.0.1:8000`. To access the API documentation, visit `http://{host}:8000/docs`. diff --git a/alembic.ini b/alembic.ini new file mode 100644 index 00000000..7cc1a0d5 --- /dev/null +++ b/alembic.ini @@ -0,0 +1,117 @@ +# A generic, single database configuration. 
+ +[alembic] +# path to migration scripts +# Use forward slashes (/) also on windows to provide an os agnostic path +script_location = collector_db/alembic + +# template used to generate migration file names; The default value is %%(rev)s_%%(slug)s +# Uncomment the line below if you want the files to be prepended with date and time +# see https://alembic.sqlalchemy.org/en/latest/tutorial.html#editing-the-ini-file +# for all available tokens +# file_template = %%(year)d_%%(month).2d_%%(day).2d_%%(hour).2d%%(minute).2d-%%(rev)s_%%(slug)s + +# sys.path path, will be prepended to sys.path if present. +# defaults to the current working directory. +prepend_sys_path = . + +# timezone to use when rendering the date within the migration file +# as well as the filename. +# If specified, requires the python>=3.9 or backports.zoneinfo library. +# Any required deps can installed by adding `alembic[tz]` to the pip requirements +# string value is passed to ZoneInfo() +# leave blank for localtime +# timezone = + +# max length of characters to apply to the "slug" field +# truncate_slug_length = 40 + +# set to 'true' to run the environment during +# the 'revision' command, regardless of autogenerate +# revision_environment = false + +# set to 'true' to allow .pyc and .pyo files without +# a source .py file to be detected as revisions in the +# versions/ directory +# sourceless = false + +# version location specification; This defaults +# to alembic/versions. When using multiple version +# directories, initial revisions must be specified with --version-path. +# The path separator used here should be the separator specified by "version_path_separator" below. +# version_locations = %(here)s/bar:%(here)s/bat:alembic/versions + +# version path separator; As mentioned above, this is the character used to split +# version_locations. The default within new alembic.ini files is "os", which uses os.pathsep. +# If this key is omitted entirely, it falls back to the legacy behavior of splitting on spaces and/or commas. +# Valid values for version_path_separator are: +# +# version_path_separator = : +# version_path_separator = ; +# version_path_separator = space +# version_path_separator = newline +version_path_separator = os # Use os.pathsep. Default configuration used for new projects. + +# set to 'true' to search source files recursively +# in each "version_locations" directory +# new in Alembic version 1.10 +# recursive_version_locations = false + +# the output encoding used when revision files +# are written from script.py.mako +# output_encoding = utf-8 + +sqlalchemy.url = postgresql://test_source_collector_user:HanviliciousHamiltonHilltops@host.docker.internal:5432/source_collector_test_db + + +[post_write_hooks] +# post_write_hooks defines scripts or Python functions that are run +# on newly generated revision scripts. 
See the documentation for further +# detail and examples + +# format using "black" - use the console_scripts runner, against the "black" entrypoint +# hooks = black +# black.type = console_scripts +# black.entrypoint = black +# black.options = -l 79 REVISION_SCRIPT_FILENAME + +# lint with attempts to fix using "ruff" - use the exec runner, execute a binary +# hooks = ruff +# ruff.type = exec +# ruff.executable = %(here)s/.venv/bin/ruff +# ruff.options = --fix REVISION_SCRIPT_FILENAME + +# Logging configuration +[loggers] +keys = root,sqlalchemy,alembic + +[handlers] +keys = console + +[formatters] +keys = generic + +[logger_root] +level = WARNING +handlers = console +qualname = + +[logger_sqlalchemy] +level = WARNING +handlers = +qualname = sqlalchemy.engine + +[logger_alembic] +level = INFO +handlers = +qualname = alembic + +[handler_console] +class = StreamHandler +args = (sys.stderr,) +level = NOTSET +formatter = generic + +[formatter_generic] +format = %(levelname)-5.5s [%(name)s] %(message)s +datefmt = %H:%M:%S diff --git a/api/main.py b/api/main.py index 7c632e8f..c5f76385 100644 --- a/api/main.py +++ b/api/main.py @@ -9,12 +9,14 @@ from collector_db.DatabaseClient import DatabaseClient from core.CoreLogger import CoreLogger from core.SourceCollectorCore import SourceCollectorCore +from util.helper_functions import get_from_env @asynccontextmanager async def lifespan(app: FastAPI): # Initialize shared dependencies db_client = DatabaseClient() + await setup_database(db_client) source_collector_core = SourceCollectorCore( core_logger=CoreLogger( db_client=db_client @@ -34,6 +36,15 @@ async def lifespan(app: FastAPI): pass +async def setup_database(db_client): + # Initialize database if dev environment, otherwise apply migrations + try: + get_from_env("DEV") + db_client.init_db() + except Exception as e: + return + + app = FastAPI( title="Source Collector API", description="API for collecting data sources", diff --git a/apply_migrations.py b/apply_migrations.py new file mode 100644 index 00000000..5be4cd99 --- /dev/null +++ b/apply_migrations.py @@ -0,0 +1,14 @@ +from alembic import command +from alembic.config import Config + +from collector_db.helper_functions import get_postgres_connection_string + +if __name__ == "__main__": + print("Applying migrations...") + alembic_config = Config("alembic.ini") + alembic_config.set_main_option( + "sqlalchemy.url", + get_postgres_connection_string() + ) + command.upgrade(alembic_config, "head") + print("Migrations applied.") \ No newline at end of file diff --git a/collector_db/DTOs/URLInfo.py b/collector_db/DTOs/URLInfo.py index 553a76a9..8abd3e4a 100644 --- a/collector_db/DTOs/URLInfo.py +++ b/collector_db/DTOs/URLInfo.py @@ -1,3 +1,4 @@ +import datetime from typing import Optional from pydantic import BaseModel @@ -11,3 +12,4 @@ class URLInfo(BaseModel): url: str url_metadata: Optional[dict] = None outcome: URLOutcome = URLOutcome.PENDING + updated_at: Optional[datetime.datetime] = None diff --git a/collector_db/DatabaseClient.py b/collector_db/DatabaseClient.py index cce6b954..2d23842f 100644 --- a/collector_db/DatabaseClient.py +++ b/collector_db/DatabaseClient.py @@ -16,7 +16,6 @@ from collector_db.helper_functions import get_postgres_connection_string from collector_db.models import Base, Batch, URL, Log, Duplicate from collector_manager.enums import CollectorType -from core.DTOs.BatchStatusInfo import BatchStatusInfo from core.enums import BatchStatus @@ -32,10 +31,12 @@ def __init__(self, db_url: str = get_postgres_connection_string()): 
url=db_url, echo=ConfigManager.get_sqlalchemy_echo(), ) - Base.metadata.create_all(self.engine) self.session_maker = scoped_session(sessionmaker(bind=self.engine)) self.session = None + def init_db(self): + Base.metadata.create_all(self.engine) + def session_manager(method): @wraps(method) def wrapper(self, *args, **kwargs): @@ -214,13 +215,13 @@ def get_recent_batch_status_info( # Get only the batch_id, collector_type, status, and created_at limit = 100 query = (session.query(Batch) - .order_by(Batch.date_generated.desc()) - .limit(limit) - .offset((page - 1) * limit)) + .order_by(Batch.date_generated.desc())) if collector_type: query = query.filter(Batch.strategy == collector_type.value) if status: query = query.filter(Batch.status == status.value) + query = (query.limit(limit) + .offset((page - 1) * limit)) batches = query.all() return [BatchInfo(**batch.__dict__) for batch in batches] @@ -274,6 +275,11 @@ def delete_old_logs(self, session): Log.created_at < datetime.now() - timedelta(days=1) ).delete() + @session_manager + def update_url(self, session, url_info: URLInfo): + url = session.query(URL).filter_by(id=url_info.id).first() + url.url_metadata = url_info.url_metadata + if __name__ == "__main__": client = DatabaseClient() print("Database client initialized.") diff --git a/collector_db/alembic/README.md b/collector_db/alembic/README.md index ea95ae53..a31e2127 100644 --- a/collector_db/alembic/README.md +++ b/collector_db/alembic/README.md @@ -1,4 +1,30 @@ -Generic single-database configuration. +Alembic is a lightweight Python library that helps manage database migrations. +## Files and Directories + +The following files are present in this directory OR related to it: +- `script.py.mako`: This is a Mako template file which is used to generate new migration scripts. Whatever is here is used to generate new files within `versions/`. This is scriptable so that the structure of each migration file can be controlled, including standard imports to be within each, as well as changes to the structure of the `upgrade()` and `downgrade()` functions +- `env.py`: The main script that sets up the migration environment. +- `alembic.ini`: The `alembic` configuration file. Located in the root of the repository +- `/versions`: The directory which contains the migration scripts +- `apply_migrations.py`: A Python script, located in the root directory, which applies any outstanding migrations to the database +- `execute.sh`: A shell script in the root directory which runs the `apply_migrations.py` script. Called by DigitalOcean when deploying the application. + +## Generating a Migration + +To generate a new migration, run the following command from the root directory: + +```bash +alembic revision --autogenerate -m "Description for migration" +``` + +Then, locate the new revision script in `/versions` and modify the update and downgrade functions as needed + +Once you have generated a new migration, you can upgrade and downgrade the database using the `alembic` command line tool. + +Finally, make sure to commit your changes to the repository. + +## How does Alembic Work? + +As long as new migrations are generated and stored in the `/versions` directory, Alembic will apply them, in the order they were made, to the production database. -- `script.py.mako`: This is a Mako template file which is used to generate new migration scripts. Whatever is here is used to generate new files within `versions/`. 
This is scriptable so that the structure of each migration file can be controlled, including standard imports to be within each, as well as changes to the structure of the `upgrade()` and `downgrade()` functions \ No newline at end of file diff --git a/collector_db/alembic/env.py b/collector_db/alembic/env.py index 36112a3c..e89cf160 100644 --- a/collector_db/alembic/env.py +++ b/collector_db/alembic/env.py @@ -5,6 +5,9 @@ from alembic import context +from collector_db.helper_functions import get_postgres_connection_string +from collector_db.models import Base + # this is the Alembic Config object, which provides # access to the values within the .ini file in use. config = context.config @@ -18,7 +21,7 @@ # for 'autogenerate' support # from myapp import mymodel # target_metadata = mymodel.Base.metadata -target_metadata = None +target_metadata = Base.metadata # other values from the config, defined by the needs of env.py, # can be acquired: @@ -38,7 +41,7 @@ def run_migrations_offline() -> None: script output. """ - url = config.get_main_option("sqlalchemy.url") + url = get_postgres_connection_string() context.configure( url=url, target_metadata=target_metadata, diff --git a/collector_db/alembic/versions/a4750e7ff8e7_add_updated_at_to_url_table.py b/collector_db/alembic/versions/a4750e7ff8e7_add_updated_at_to_url_table.py new file mode 100644 index 00000000..f4084cf2 --- /dev/null +++ b/collector_db/alembic/versions/a4750e7ff8e7_add_updated_at_to_url_table.py @@ -0,0 +1,59 @@ +"""Add updated_at to URL table + +Revision ID: a4750e7ff8e7 +Revises: d11f07224d1f +Create Date: 2025-01-08 10:25:04.031123 + +""" +from typing import Sequence, Union + +from alembic import op +import sqlalchemy as sa + + +# revision identifiers, used by Alembic. +revision: str = 'a4750e7ff8e7' +down_revision: Union[str, None] = 'd11f07224d1f' +branch_labels: Union[str, Sequence[str], None] = None +depends_on: Union[str, Sequence[str], None] = None + + +def upgrade() -> None: + # Add `updated_at` column to the `URL` table + op.add_column( + table_name='urls', + column=sa.Column( + 'updated_at', + sa.DateTime(), + server_default=sa.text('CURRENT_TIMESTAMP'), + nullable=True + ) + ) + + + # Create a function and trigger to update the `updated_at` column + op.execute(""" + CREATE OR REPLACE FUNCTION update_updated_at_column() + RETURNS TRIGGER AS $$ + BEGIN + NEW.updated_at = CURRENT_TIMESTAMP; + RETURN NEW; + END; + $$ LANGUAGE plpgsql; + """) + + op.execute(""" + CREATE TRIGGER set_updated_at + BEFORE UPDATE ON urls + FOR EACH ROW + EXECUTE FUNCTION update_updated_at_column(); + """) + + +def downgrade() -> None: + # Remove `updated_at` column from the `URL` table + op.drop_column('urls', 'updated_at') + + # Drop the trigger and function + op.execute("DROP TRIGGER IF EXISTS set_updated_at ON urls;") + op.execute("DROP FUNCTION IF EXISTS update_updated_at_column;") diff --git a/collector_db/alembic/versions/d11f07224d1f_initial_creation.py b/collector_db/alembic/versions/d11f07224d1f_initial_creation.py new file mode 100644 index 00000000..d3ad9d8c --- /dev/null +++ b/collector_db/alembic/versions/d11f07224d1f_initial_creation.py @@ -0,0 +1,88 @@ +"""Initial creation + +Revision ID: d11f07224d1f +Revises: +Create Date: 2025-01-07 17:41:35.512410 + +""" +from typing import Sequence, Union + +from alembic import op +import sqlalchemy as sa + + +# revision identifiers, used by Alembic. 
+revision: str = 'd11f07224d1f' +down_revision: Union[str, None] = None +branch_labels: Union[str, Sequence[str], None] = None +depends_on: Union[str, Sequence[str], None] = None + + +def upgrade() -> None: + # ### commands auto generated by Alembic - please adjust! ### + op.create_table('batches', + sa.Column('id', sa.Integer(), nullable=False), + sa.Column('strategy', sa.String(), nullable=False), + sa.Column('user_id', sa.Integer(), nullable=False), + sa.Column('status', sa.String(), nullable=False), + sa.Column('total_url_count', sa.Integer(), nullable=False), + sa.Column('original_url_count', sa.Integer(), nullable=False), + sa.Column('duplicate_url_count', sa.Integer(), nullable=False), + sa.Column('date_generated', sa.TIMESTAMP(), server_default=sa.text('now()'), nullable=False), + sa.Column('strategy_success_rate', sa.Float(), nullable=True), + sa.Column('metadata_success_rate', sa.Float(), nullable=True), + sa.Column('agency_match_rate', sa.Float(), nullable=True), + sa.Column('record_type_match_rate', sa.Float(), nullable=True), + sa.Column('record_category_match_rate', sa.Float(), nullable=True), + sa.Column('compute_time', sa.Float(), nullable=True), + sa.Column('parameters', sa.JSON(), nullable=True), + sa.PrimaryKeyConstraint('id') + ) + op.create_table('logs', + sa.Column('id', sa.Integer(), nullable=False), + sa.Column('batch_id', sa.Integer(), nullable=False), + sa.Column('log', sa.Text(), nullable=False), + sa.Column('created_at', sa.TIMESTAMP(), server_default=sa.text('now()'), nullable=False), + sa.ForeignKeyConstraint(['batch_id'], ['batches.id'], ), + sa.PrimaryKeyConstraint('id') + ) + op.create_table('missing', + sa.Column('id', sa.Integer(), nullable=False), + sa.Column('place_id', sa.Integer(), nullable=False), + sa.Column('record_type', sa.String(), nullable=False), + sa.Column('batch_id', sa.Integer(), nullable=True), + sa.Column('strategy_used', sa.Text(), nullable=False), + sa.Column('date_searched', sa.TIMESTAMP(), server_default=sa.text('now()'), nullable=False), + sa.ForeignKeyConstraint(['batch_id'], ['batches.id'], ), + sa.PrimaryKeyConstraint('id') + ) + op.create_table('urls', + sa.Column('id', sa.Integer(), nullable=False), + sa.Column('batch_id', sa.Integer(), nullable=False), + sa.Column('url', sa.Text(), nullable=True), + sa.Column('url_metadata', sa.JSON(), nullable=True), + sa.Column('outcome', sa.String(), nullable=True), + sa.Column('created_at', sa.TIMESTAMP(), server_default=sa.text('now()'), nullable=False), + sa.ForeignKeyConstraint(['batch_id'], ['batches.id'], ), + sa.PrimaryKeyConstraint('id'), + sa.UniqueConstraint('url') + ) + op.create_table('duplicates', + sa.Column('id', sa.Integer(), nullable=False), + sa.Column('batch_id', sa.Integer(), nullable=False), + sa.Column('original_url_id', sa.Integer(), nullable=False), + sa.ForeignKeyConstraint(['batch_id'], ['batches.id'], ), + sa.ForeignKeyConstraint(['original_url_id'], ['urls.id'], ), + sa.PrimaryKeyConstraint('id') + ) + # ### end Alembic commands ### + + +def downgrade() -> None: + # ### commands auto generated by Alembic - please adjust! ### + op.drop_table('duplicates') + op.drop_table('urls') + op.drop_table('missing') + op.drop_table('logs') + op.drop_table('batches') + # ### end Alembic commands ### diff --git a/collector_db/models.py b/collector_db/models.py index d6b06b01..5ceca48a 100644 --- a/collector_db/models.py +++ b/collector_db/models.py @@ -63,6 +63,7 @@ class URL(Base): # The outcome of the URL: submitted, human_labeling, rejected, duplicate, etc. 
outcome = Column(String) created_at = Column(TIMESTAMP, nullable=False, server_default=CURRENT_TIME_SERVER_DEFAULT) + updated_at = Column(TIMESTAMP, nullable=False, server_default=CURRENT_TIME_SERVER_DEFAULT) batch = relationship("Batch", back_populates="urls") duplicates = relationship("Duplicate", back_populates="original_url") diff --git a/collector_manager/CollectorBase.py b/collector_manager/CollectorBase.py index 828a0878..87b0a9d8 100644 --- a/collector_manager/CollectorBase.py +++ b/collector_manager/CollectorBase.py @@ -65,6 +65,14 @@ def handle_error(self, e: Exception) -> None: if self.raise_error: raise e self.log(f"Error: {e}") + self.db_client.update_batch_post_collection( + batch_id=self.batch_id, + batch_status=self.status, + compute_time=self.compute_time, + total_url_count=0, + original_url_count=0, + duplicate_url_count=0 + ) def process(self) -> None: self.log("Processing collector...", allow_abort=False) diff --git a/core/CoreLogger.py b/core/CoreLogger.py index b24ae18b..6ddfd68f 100644 --- a/core/CoreLogger.py +++ b/core/CoreLogger.py @@ -92,6 +92,6 @@ def shutdown(self): Stops the logger gracefully and flushes any remaining logs. """ self.stop_event.set() - if self.flush_future and not self.flush_future.done(): - self.flush_future.result(timeout=10) + # if self.flush_future and not self.flush_future.done(): + self.flush_future.result(timeout=10) self.flush_all() # Flush remaining logs diff --git a/core/ScheduledTaskManager.py b/core/ScheduledTaskManager.py index 81d5bf1c..f664af5c 100644 --- a/core/ScheduledTaskManager.py +++ b/core/ScheduledTaskManager.py @@ -1,4 +1,4 @@ -from datetime import datetime +from datetime import datetime, timedelta from apscheduler.schedulers.background import BackgroundScheduler from apscheduler.triggers.interval import IntervalTrigger @@ -26,7 +26,7 @@ def add_scheduled_tasks(self): self.db_client.delete_old_logs, trigger=IntervalTrigger( days=1, - start_date=datetime.now() + start_date=datetime.now() + timedelta(minutes=10) ) ) diff --git a/docker-compose.yml b/docker-compose.yml index 6f618e24..d813a97f 100644 --- a/docker-compose.yml +++ b/docker-compose.yml @@ -6,6 +6,7 @@ services: build: context: . 
dockerfile: Dockerfile + command: pytest /app/tests/test_automated ports: - "8000:80" environment: @@ -14,7 +15,7 @@ services: - POSTGRES_DB=source_collector_test_db # For local development in non-Linux environment # - POSTGRES_HOST=host.docker.internal -# For GitHub Actions (which use Linux Docker +# For GitHub Actions (which use Linux Docker) - POSTGRES_HOST=172.17.0.1 - POSTGRES_PORT=5432 - GOOGLE_API_KEY=TEST diff --git a/execute.sh b/execute.sh new file mode 100644 index 00000000..6bfd03b1 --- /dev/null +++ b/execute.sh @@ -0,0 +1,4 @@ +#!/bin/sh + +python apply_migrations.py +uvicorn api.main:app --host 0.0.0.0 --port 80 diff --git a/requirements.txt b/requirements.txt index b3f4999f..d3e7f22d 100644 --- a/requirements.txt +++ b/requirements.txt @@ -31,9 +31,10 @@ httpx~=0.28.1 ckanapi~=4.8 psycopg[binary]~=3.1.20 APScheduler~=3.11.0 +alembic~=1.14.0 # Security Manager PyJWT~=2.10.1 # Tests -pytest-timeout~=2.3.1 \ No newline at end of file +pytest-timeout~=2.3.1 diff --git a/source_collectors/ckan/CKANAPIInterface.py b/source_collectors/ckan/CKANAPIInterface.py index a87921fa..52384f4e 100644 --- a/source_collectors/ckan/CKANAPIInterface.py +++ b/source_collectors/ckan/CKANAPIInterface.py @@ -1,6 +1,9 @@ from typing import Optional -from ckanapi import RemoteCKAN +from ckanapi import RemoteCKAN, NotFound + +class CKANAPIError(Exception): + pass # TODO: Maybe return Base Models? @@ -10,17 +13,23 @@ class CKANAPIInterface: """ def __init__(self, base_url: str): + self.base_url = base_url self.remote = RemoteCKAN(base_url, get_only=True) def package_search(self, query: str, rows: int, start: int, **kwargs): return self.remote.action.package_search(q=query, rows=rows, start=start, **kwargs) - # TODO: Add Schema def get_organization(self, organization_id: str): - return self.remote.action.organization_show(id=organization_id, include_datasets=True) - # TODO: Add Schema + try: + return self.remote.action.organization_show(id=organization_id, include_datasets=True) + except NotFound as e: + raise CKANAPIError(f"Organization {organization_id} not found" + f" for url {self.base_url}. Original error: {e}") def get_group_package(self, group_package_id: str, limit: Optional[int]): - return self.remote.action.group_package_show(id=group_package_id, limit=limit) - # TODO: Add Schema + try: + return self.remote.action.group_package_show(id=group_package_id, limit=limit) + except NotFound as e: + raise CKANAPIError(f"Group Package {group_package_id} not found" + f" for url {self.base_url}. Original error: {e}") diff --git a/source_collectors/ckan/DTOs.py b/source_collectors/ckan/DTOs.py index c6c1b683..992bb0b6 100644 --- a/source_collectors/ckan/DTOs.py +++ b/source_collectors/ckan/DTOs.py @@ -1,13 +1,22 @@ +from typing import Optional + from pydantic import BaseModel, Field +url_field = Field(description="The base CKAN URL to search from.") class CKANPackageSearchDTO(BaseModel): - url: str = Field(description="The package of the CKAN instance.") - terms: list[str] = Field(description="The search terms to use.") + url: str = url_field + terms: Optional[list[str]] = Field( + description="The search terms to use to refine the packages returned. 
" + "None will return all packages.", + default=None + ) class GroupAndOrganizationSearchDTO(BaseModel): - url: str = Field(description="The group or organization of the CKAN instance.") - ids: list[str] = Field(description="The ids of the group or organization.") + url: str = url_field + ids: Optional[list[str]] = Field( + description="The ids of the group or organization to get packages from." + ) class CKANInputDTO(BaseModel): package_search: list[CKANPackageSearchDTO] or None = Field( diff --git a/source_collectors/ckan/scrape_ckan_data_portals.py b/source_collectors/ckan/scrape_ckan_data_portals.py index b15421fc..46ef8ccb 100644 --- a/source_collectors/ckan/scrape_ckan_data_portals.py +++ b/source_collectors/ckan/scrape_ckan_data_portals.py @@ -24,6 +24,8 @@ def perform_search( :param search_func: The search function to execute. :param search_terms: The list of urls and search terms. + In the package search template, this is "url", "terms" + In the group and organization search template, this is "url", "ids" :param results: The list of results. :return: Updated list of results. """ diff --git a/tests/manual/source_collectors/test_ckan_collector.py b/tests/manual/source_collectors/test_ckan_collector.py index 870b9780..020b36c3 100644 --- a/tests/manual/source_collectors/test_ckan_collector.py +++ b/tests/manual/source_collectors/test_ckan_collector.py @@ -18,7 +18,7 @@ class CKANSchema(Schema): source_last_updated = fields.String() -def test_ckan_collector(): +def test_ckan_collector_default(): collector = CKANCollector( batch_id=1, dto=CKANInputDTO( @@ -36,3 +36,45 @@ def test_ckan_collector(): collector.run() schema = CKANSchema(many=True) schema.load(collector.data["results"]) + +def test_ckan_collector_custom(): + collector = CKANCollector( + batch_id=1, + dto=CKANInputDTO( + **{ + "package_search": [ + { + "url": "https://catalog.data.gov/", + "terms": [ + "police", + "crime", + "tags:(court courts court-cases criminal-justice-system law-enforcement law-enforcement-agencies)" + ] + } + ], + "group_search": [ + { + "url": "https://catalog.data.gov/", + "ids": [ + "3c648d96-0a29-4deb-aa96-150117119a23", + "92654c61-3a7d-484f-a146-257c0f6c55aa" + ] + } + ], + "organization_search": [ + { + "url": "https://data.houstontx.gov/", + "ids": [ + "https://data.houstontx.gov/" + ] + } + ] + } + ), + logger=MagicMock(spec=CoreLogger), + db_client=MagicMock(spec=DatabaseClient), + raise_error=True + ) + collector.run() + schema = CKANSchema(many=True) + schema.load(collector.data["results"]) \ No newline at end of file diff --git a/tests/test_alembic/__init__.py b/tests/test_alembic/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/tests/test_alembic/test_revisions.py b/tests/test_alembic/test_revisions.py new file mode 100644 index 00000000..50bf15ac --- /dev/null +++ b/tests/test_alembic/test_revisions.py @@ -0,0 +1,151 @@ +import time +from dataclasses import dataclass + +import pytest +from alembic import command +from alembic.config import Config + +from sqlalchemy import create_engine, Inspector, inspect, MetaData, Connection, Engine +from sqlalchemy.orm import Session, sessionmaker, scoped_session + +from collector_db.helper_functions import get_postgres_connection_string +from collector_db.models import Base, URL, Batch + + +@pytest.fixture() +def alembic_config(): + alembic_cfg = Config("alembic.ini") + yield alembic_cfg + +@pytest.fixture() +def db_engine(): + engine = create_engine(get_postgres_connection_string()) + yield engine + engine.dispose() + 
+@pytest.fixture() +def connection(db_engine): + connection = db_engine.connect() + yield connection + connection.close() + +@dataclass +class AlembicRunner: + connection: Connection + alembic_config: Config + inspector: Inspector + metadata: MetaData + session: scoped_session + + def reflect(self): + self.metadata.clear() + self.metadata.reflect(bind=self.connection) + self.inspector = inspect(self.connection) + + def upgrade(self, revision: str): + command.upgrade(self.alembic_config, revision) + + def downgrade(self, revision: str): + command.downgrade(self.alembic_config, revision) + + def stamp(self, revision: str): + command.stamp(self.alembic_config, revision) + +@pytest.fixture() +def alembic_runner(connection, alembic_config) -> AlembicRunner: + alembic_config.attributes["connection"] = connection + alembic_config.set_main_option( + "sqlalchemy.url", + get_postgres_connection_string() + ) + runner = AlembicRunner( + alembic_config=alembic_config, + inspector=inspect(connection), + metadata=MetaData(), + connection=connection, + session=scoped_session(sessionmaker(bind=connection)), + ) + Base.metadata.drop_all(connection) + connection.commit() + runner.stamp("base") + yield runner + runner.upgrade("head") + + + + +def test_base(alembic_runner): + table_names = alembic_runner.inspector.get_table_names() + assert table_names == [ + 'alembic_version', + ] + + alembic_runner.upgrade("d11f07224d1f") + + # Reflect the updated database state + alembic_runner.reflect() + + table_names = alembic_runner.inspector.get_table_names() + assert table_names.sort() == [ + 'batches', + 'logs', + 'missing', + 'urls', + 'duplicates', + 'alembic_version', + ].sort() + +def test_add_url_updated_at(alembic_runner): + alembic_runner.upgrade("d11f07224d1f") + + columns = [col["name"] for col in alembic_runner.inspector.get_columns("urls")] + assert "updated_at" not in columns + + alembic_runner.upgrade("a4750e7ff8e7") + + # Reflect the updated database state + alembic_runner.reflect() + + columns = [col["name"] for col in alembic_runner.inspector.get_columns("urls")] + assert "updated_at" in columns + + with alembic_runner.session() as session: + # Add a batch + batch = Batch( + strategy="test", + user_id=1, + status="complete", + total_url_count=0, + original_url_count=0, + duplicate_url_count=0 + ) + session.add(batch) + + # Add a url + url = URL(batch_id=1, url="https://example.com", url_metadata={}, outcome="success") + session.add(url) + session.flush() + session.commit() + + # alembic_runner.session.refresh() + + + + with alembic_runner.session() as session: + url = session.query(URL).first() + assert url.url == "https://example.com" + updated_1 = url.updated_at + print(updated_1) + + time.sleep(1) + # Update the url + url.url = "https://example.com/new" + session.commit() + url = session.query(URL).first() + assert url.url == "https://example.com/new" + updated_2 = url.updated_at + + # assert updated_1 < updated_2 + + # Create a new URL entry + diff --git a/tests/test_automated/integration/api/helpers/RequestValidator.py b/tests/test_automated/integration/api/helpers/RequestValidator.py index eaaab9a9..59ab4a28 100644 --- a/tests/test_automated/integration/api/helpers/RequestValidator.py +++ b/tests/test_automated/integration/api/helpers/RequestValidator.py @@ -6,12 +6,15 @@ from collector_db.DTOs.BatchInfo import BatchInfo from collector_manager.DTOs.ExampleInputDTO import ExampleInputDTO +from collector_manager.enums import CollectorType from core.DTOs.GetBatchLogsResponse import 
GetBatchLogsResponse from core.DTOs.GetBatchStatusResponse import GetBatchStatusResponse from core.DTOs.GetDuplicatesByBatchResponse import GetDuplicatesByBatchResponse from core.DTOs.GetURLsByBatchResponse import GetURLsByBatchResponse from core.DTOs.LabelStudioExportResponseInfo import LabelStudioExportResponseInfo from core.DTOs.MessageResponse import MessageResponse +from core.enums import BatchStatus +from util.helper_functions import update_if_not_none class ExpectedResponseInfo(BaseModel): @@ -105,9 +108,18 @@ def delete( expected_response=expected_response, **kwargs) - def get_batch_statuses(self) -> GetBatchStatusResponse: + def get_batch_statuses(self, collector_type: Optional[CollectorType] = None, status: Optional[BatchStatus] = None) -> GetBatchStatusResponse: + params = {} + update_if_not_none( + target=params, + source={ + "collector_type": collector_type.value if collector_type else None, + "status": status.value if status else None + } + ) data = self.get( - url=f"/batch" + url=f"/batch", + params=params ) return GetBatchStatusResponse(**data) diff --git a/tests/test_automated/integration/api/test_batch.py b/tests/test_automated/integration/api/test_batch.py index 92b2e72e..61c2a8b2 100644 --- a/tests/test_automated/integration/api/test_batch.py +++ b/tests/test_automated/integration/api/test_batch.py @@ -32,7 +32,7 @@ def test_get_batch_urls(api_test_helper): batch_id = ath.db_data_creator.batch() iui: InsertURLsInfo = ath.db_data_creator.urls(batch_id=batch_id, url_count=101) - response = ath.request_validator.get_batch_urls(batch_id=1, page=1) + response = ath.request_validator.get_batch_urls(batch_id=batch_id, page=1) assert len(response.urls) == 100 # Check that the first url corresponds to the first url inserted assert response.urls[0].url == iui.url_mappings[0].url @@ -41,7 +41,7 @@ def test_get_batch_urls(api_test_helper): # Check that a more limited set of urls exist - response = ath.request_validator.get_batch_urls(batch_id=1, page=2) + response = ath.request_validator.get_batch_urls(batch_id=batch_id, page=2) assert len(response.urls) == 1 # Check that this url corresponds to the last url inserted assert response.urls[0].url == iui.url_mappings[-1].url diff --git a/tests/test_automated/integration/api/test_example_collector.py b/tests/test_automated/integration/api/test_example_collector.py index 8dc0040d..d3d10330 100644 --- a/tests/test_automated/integration/api/test_example_collector.py +++ b/tests/test_automated/integration/api/test_example_collector.py @@ -1,7 +1,9 @@ import time +from unittest.mock import MagicMock from collector_db.DTOs.BatchInfo import BatchInfo from collector_manager.DTOs.ExampleInputDTO import ExampleInputDTO +from collector_manager.ExampleCollector import ExampleCollector from collector_manager.enums import CollectorType from core.DTOs.BatchStatusInfo import BatchStatusInfo from core.DTOs.GetBatchLogsResponse import GetBatchLogsResponse @@ -23,7 +25,7 @@ def test_example_collector(api_test_helper): assert batch_id is not None assert data["message"] == "Started example_collector collector." 
- bsr: GetBatchStatusResponse = ath.request_validator.get_batch_statuses() + bsr: GetBatchStatusResponse = ath.request_validator.get_batch_statuses(status=BatchStatus.IN_PROCESS) assert len(bsr.results) == 1 bsi: BatchStatusInfo = bsr.results[0] @@ -34,7 +36,7 @@ def test_example_collector(api_test_helper): time.sleep(2) - csr: GetBatchStatusResponse = ath.request_validator.get_batch_statuses() + csr: GetBatchStatusResponse = ath.request_validator.get_batch_statuses(collector_type=CollectorType.EXAMPLE, status=BatchStatus.COMPLETE) assert len(csr.results) == 1 bsi: BatchStatusInfo = csr.results[0] @@ -48,7 +50,7 @@ def test_example_collector(api_test_helper): assert bi.total_url_count == 2 assert bi.parameters == dto.model_dump() assert bi.strategy == "example_collector" - assert bi.user_id == 1 + assert bi.user_id is not None # Flush early to ensure logs are written ath.core.collector_manager.logger.flush_all() @@ -58,5 +60,40 @@ def test_example_collector(api_test_helper): assert len(lr.logs) > 0 +def test_example_collector_error(api_test_helper, monkeypatch): + """ + Test that when an error occurs in a collector, the batch is properly update + """ + ath = api_test_helper + + # Patch the collector to raise an exception during run_implementation + mock = MagicMock() + mock.side_effect = Exception("Collector failed!") + monkeypatch.setattr(ExampleCollector, 'run_implementation', mock) + + dto = ExampleInputDTO( + sleep_time=1 + ) + + data = ath.request_validator.example_collector( + dto=dto + ) + batch_id = data["batch_id"] + assert batch_id is not None + assert data["message"] == "Started example_collector collector." + + time.sleep(1) + + bi: BatchInfo = ath.request_validator.get_batch_info(batch_id=batch_id) + + assert bi.status == BatchStatus.ERROR + + + ath.core.core_logger.flush_all() + + gbl: GetBatchLogsResponse = ath.request_validator.get_batch_logs(batch_id=batch_id) + assert gbl.logs[-1].log == "Error: Collector failed!" 
+ + diff --git a/tests/test_automated/integration/collector_db/test_db_client.py b/tests/test_automated/integration/collector_db/test_db_client.py index 42430299..7011de3f 100644 --- a/tests/test_automated/integration/collector_db/test_db_client.py +++ b/tests/test_automated/integration/collector_db/test_db_client.py @@ -1,3 +1,4 @@ +import time from datetime import datetime, timedelta from collector_db.DTOs.BatchInfo import BatchInfo @@ -38,16 +39,12 @@ def test_insert_urls(db_client_test): batch_id=batch_id ) - assert insert_urls_info.url_mappings == [ - URLMapping( - url="https://example.com/1", - url_id=1 - ), - URLMapping( - url="https://example.com/2", - url_id=2 - ) - ] + url_mappings = insert_urls_info.url_mappings + assert len(url_mappings) == 2 + assert url_mappings[0].url == "https://example.com/1" + assert url_mappings[1].url == "https://example.com/2" + + assert insert_urls_info.original_count == 2 assert insert_urls_info.duplicate_count == 1 @@ -80,8 +77,30 @@ def test_delete_old_logs(db_data_creator: DBDataCreator): for i in range(3): log_infos.append(LogInfo(log="test log", batch_id=batch_id, created_at=old_datetime)) db_client.insert_logs(log_infos=log_infos) - assert len(db_client.get_all_logs()) == 3 + logs = db_client.get_logs_by_batch_id(batch_id=batch_id) + assert len(logs) == 3 db_client.delete_old_logs() - logs = db_client.get_all_logs() - assert len(logs) == 0 \ No newline at end of file + logs = db_client.get_logs_by_batch_id(batch_id=batch_id) + assert len(logs) == 0 + +def test_delete_url_updated_at(db_data_creator: DBDataCreator): + batch_id = db_data_creator.batch() + url_id = db_data_creator.urls(batch_id=batch_id, url_count=1).url_mappings[0].url_id + + db_client = db_data_creator.db_client + url_info = db_client.get_urls_by_batch(batch_id=batch_id, page=1)[0] + + old_updated_at = url_info.updated_at + + + db_client.update_url( + url_info=URLInfo( + id=url_id, + url="dg", + url_metadata={"test_metadata": "test_metadata"}, + ) + ) + + url = db_client.get_urls_by_batch(batch_id=batch_id, page=1)[0] + assert url.updated_at > old_updated_at diff --git a/tests/test_automated/integration/conftest.py b/tests/test_automated/integration/conftest.py index 6534284a..5f5471ed 100644 --- a/tests/test_automated/integration/conftest.py +++ b/tests/test_automated/integration/conftest.py @@ -1,5 +1,8 @@ import pytest +from alembic import command +from alembic.config import Config +from sqlalchemy import create_engine from collector_db.helper_functions import get_postgres_connection_string from collector_db.models import Base @@ -8,13 +11,46 @@ from collector_db.DatabaseClient import DatabaseClient from core.SourceCollectorCore import SourceCollectorCore +@pytest.fixture(autouse=True, scope="session") +def setup_and_teardown(): + conn = get_postgres_connection_string() + engine = create_engine(conn) + alembic_cfg = Config("alembic.ini") + alembic_cfg.attributes["connection"] = engine.connect() + alembic_cfg.set_main_option( + "sqlalchemy.url", + get_postgres_connection_string() + ) + command.upgrade(alembic_cfg, "head") + engine.dispose() + yield @pytest.fixture def db_client_test() -> DatabaseClient: - db_client = DatabaseClient(db_url=get_postgres_connection_string()) + # Drop pre-existing table + conn = get_postgres_connection_string() + engine = create_engine(conn) + with engine.connect() as connection: + for table in reversed(Base.metadata.sorted_tables): + connection.execute(table.delete()) + connection.commit() + + # # # Run alembic to set at base + # alembic_cfg = 
Config("alembic.ini") + # alembic_cfg.attributes["connection"] = engine.connect() + # alembic_cfg.set_main_option( + # "sqlalchemy.url", + # get_postgres_connection_string() + # ) + # # command.stamp(alembic_cfg, "base") + # + # + # # Then upgrade to head + # command.upgrade(alembic_cfg, "head") + + db_client = DatabaseClient(db_url=conn) yield db_client db_client.engine.dispose() - Base.metadata.drop_all(db_client.engine) @pytest.fixture def test_core(db_client_test): diff --git a/tests/test_automated/integration/core/test_core_logger.py b/tests/test_automated/integration/core/test_core_logger.py index 9337ee1e..37ec3b49 100644 --- a/tests/test_automated/integration/core/test_core_logger.py +++ b/tests/test_automated/integration/core/test_core_logger.py @@ -33,15 +33,15 @@ def test_multithreaded_integration_with_live_db(db_data_creator: DBDataCreator): db_client = db_data_creator.db_client db_client.delete_all_logs() - for i in range(5): - db_data_creator.batch() + batch_ids = [db_data_creator.batch() for _ in range(5)] db_client = db_data_creator.db_client logger = CoreLogger(flush_interval=1, db_client=db_client, batch_size=10) # Simulate multiple threads logging def worker(thread_id): + batch_id = batch_ids[thread_id-1] for i in range(10): # Each thread logs 10 messages - logger.log(LogInfo(log=f"Thread-{thread_id} Log-{i}", batch_id=thread_id)) + logger.log(LogInfo(log=f"Thread-{thread_id} Log-{i}", batch_id=batch_id)) # Start multiple threads threads = [threading.Thread(target=worker, args=(i+1,)) for i in range(5)] # 5 threads @@ -51,8 +51,8 @@ def worker(thread_id): t.join() # Allow the logger to flush - time.sleep(4) logger.shutdown() + time.sleep(10) # Verify logs in the database logs = db_client.get_all_logs() diff --git a/tests/test_automated/integration/source_collectors/test_example_collector.py b/tests/test_automated/integration/source_collectors/test_example_collector.py index 4be48710..3dfcc6c8 100644 --- a/tests/test_automated/integration/source_collectors/test_example_collector.py +++ b/tests/test_automated/integration/source_collectors/test_example_collector.py @@ -14,7 +14,7 @@ def test_live_example_collector_abort(test_core: SourceCollectorCore): core = test_core db_client = core.db_client - db_client.insert_batch( + batch_id = db_client.insert_batch( BatchInfo( strategy="example", status=BatchStatus.IN_PROCESS, @@ -29,7 +29,7 @@ def test_live_example_collector_abort(test_core: SourceCollectorCore): ) collector = ExampleCollector( - batch_id=1, + batch_id=batch_id, dto=dto, logger=core.core_logger, db_client=db_client, @@ -43,5 +43,5 @@ def test_live_example_collector_abort(test_core: SourceCollectorCore): thread.join() - assert db_client.get_batch_status(1) == BatchStatus.ABORTED + assert db_client.get_batch_status(batch_id) == BatchStatus.ABORTED diff --git a/util/helper_functions.py b/util/helper_functions.py index a59786cf..ccc7d96e 100644 --- a/util/helper_functions.py +++ b/util/helper_functions.py @@ -17,4 +17,9 @@ def get_from_env(key: str): return val def base_model_list_dump(model_list: list[BaseModel]) -> list[dict]: - return [model.model_dump() for model in model_list] \ No newline at end of file + return [model.model_dump() for model in model_list] + +def update_if_not_none(target: dict, source: dict): + for key, value in source.items(): + if value is not None: + target[key] = value \ No newline at end of file
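A minimal usage sketch (not part of the patch) of the new `update_if_not_none` helper added in `util/helper_functions.py`, mirroring how `RequestValidator.get_batch_statuses` assembles its optional query parameters for `GET /batch`; the filter values shown are illustrative only:

```python
def update_if_not_none(target: dict, source: dict):
    # Copy only the keys whose values are not None into the target dict.
    for key, value in source.items():
        if value is not None:
            target[key] = value


# Illustrative: build query params for GET /batch, dropping unset filters.
params: dict = {}
update_if_not_none(
    target=params,
    source={
        "collector_type": "example_collector",  # illustrative filter value
        "status": None,  # unset filter; omitted from params
    },
)
assert params == {"collector_type": "example_collector"}
```

Keeping unset filters out of the dict lets the API apply only the filters the caller actually supplied, which is the behavior the updated `get_recent_batch_status_info` query relies on.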