diff --git a/cdx_toolkit/cli.py b/cdx_toolkit/cli.py index c650d5c..f74610d 100644 --- a/cdx_toolkit/cli.py +++ b/cdx_toolkit/cli.py @@ -12,8 +12,8 @@ from cdx_toolkit.filter_cdx.command import run_filter_cdx from cdx_toolkit.filter_cdx.args import add_filter_cdx_args -from cdx_toolkit.filter_warc.command import run_warcer_by_cdx -from cdx_toolkit.filter_warc.args import add_warcer_by_cdx_args +from cdx_toolkit.filter_warc.command import run_repackage +from cdx_toolkit.filter_warc.args import add_repackage_args LOGGER = logging.getLogger(__name__) @@ -124,12 +124,12 @@ def main(args=None): warc.add_argument('url') warc.set_defaults(func=warcer) - warc_by_cdx = subparsers.add_parser( - 'warc_by_cdx', - help='iterate over capture content based on an CDX index file, creating a warc' + repackage = subparsers.add_parser( + 'repackage', + help='repackage WARC ranges from a CDX/SQL/CSV source into a new WARC' ) - add_warcer_by_cdx_args(warc_by_cdx) - warc_by_cdx.set_defaults(func=run_warcer_by_cdx) + add_repackage_args(repackage) + repackage.set_defaults(func=run_repackage) filter_cdx = subparsers.add_parser('filter_cdx', help='Filter CDX files based on SURT prefixes whitelist') add_filter_cdx_args(filter_cdx) diff --git a/cdx_toolkit/filter_warc/args.py b/cdx_toolkit/filter_warc/args.py index 1fa3a2e..71afa90 100644 --- a/cdx_toolkit/filter_warc/args.py +++ b/cdx_toolkit/filter_warc/args.py @@ -5,12 +5,13 @@ logger = logging.getLogger(__name__) -def add_warcer_by_cdx_args(parser: argparse.ArgumentParser): +def add_repackage_args(parser: argparse.ArgumentParser): + # --- CDX source --- parser.add_argument( '--cdx-path', type=str, default=None, - help='Path to CDX index file (local or remote, e.g. S3). Required if target source is set to `cdx`.', + help='Path to CDX index file (local or remote, e.g. S3). 
Used when --target-source cdx.', ) parser.add_argument( '--cdx-glob', @@ -18,46 +19,95 @@ def add_warcer_by_cdx_args(parser: argparse.ArgumentParser): default=None, help='a glob pattern for read from multiple CDX indices', ) + # --- SQL source (--target-source sql --engine athena|duckdb) --- parser.add_argument( - '--athena-hostnames', + '--engine', type=str, - nargs="+", default=None, - help=('Hostnames to filter for via Athena (whitelist). Use this OR --athena-query/' - '--athena-query-file (mutually exclusive) when target source is `athena`.'), + choices=['athena', 'duckdb'], + help='SQL engine for the columnar index. Required when --target-source sql.', ) parser.add_argument( - '--athena-query', + '--hostnames', type=str, + nargs='+', default=None, - help=('Raw Athena SQL to run instead of the hostname-based query (power users). The query ' - 'must SELECT the columns warc_filename, warc_record_offset, warc_record_length. ' - 'Mutually exclusive with --athena-hostnames and --athena-query-file.'), + help=('Exact hostnames (url_host_name, e.g. www.example.com) to filter for via the SQL ' + 'index. Combine with --domains; mutually exclusive with --query/--query-file. ' + 'Combine with the global --crawl to restrict the scan to specific crawls ' + '(strongly recommended for cost).'), ) parser.add_argument( - '--athena-query-file', + '--domains', type=str, + nargs='+', default=None, - help='Path to a file containing the raw Athena SQL (alternative to --athena-query).', + help=('Registered domains (url_host_registered_domain, e.g. example.com) to filter for via ' + 'the SQL index; also matches subdomains. Combine with --hostnames; mutually exclusive ' + 'with --query/--query-file.'), + ) + parser.add_argument( + '--query', + type=str, + default=None, + help=('Raw SQL to run instead of the hostname-based query (power users). Must SELECT the ' + 'columns warc_filename, warc_record_offset, warc_record_length. Engine-specific ' + 'dialect. 
Mutually exclusive with --hostnames and --query-file.'), + ) + parser.add_argument( + '--query-file', + type=str, + default=None, + help='Path to a file containing the raw SQL (alternative to --query).', ) parser.add_argument( '--athena-database', type=str, default=None, - help='Athena database. Required if target source is set to `athena`.', + help='Athena database (engine=athena). Defaults to `ccindex`.', ) parser.add_argument( '--athena-s3-output', type=str, default=None, - help='Athena S3 output location. Required if target source is set to `athena`.', + help='Athena S3 output location (engine=athena). Required for engine=athena.', + ) + parser.add_argument( + '--duckdb-index-path', + type=str, + default='s3://commoncrawl/cc-index/table/cc-main/warc/', + help='Base S3 path to the CC columnar index parquet (engine=duckdb).', ) parser.add_argument( - '--confirm-athena-cost', + '--confirm-cost', action='store_true', - help=('Skip the Athena cost-confirmation prompt and run even unpartitioned / large-scan ' + help=('Skip the cost-confirmation prompt and run even unpartitioned / large-scan SQL ' 'queries. Athena bills per TB scanned; restrict with --crawl to reduce cost.'), ) + # --- CSV source --- + parser.add_argument( + '--csv-path', + type=str, + default=None, + help='Path to a range-jobs CSV/TSV (local or remote). 
Used when --target-source csv.', + ) + # --- Range-jobs materialization (any source) --- + parser.add_argument( + '--range-jobs-output', + type=str, + default=None, + help='If set, write each generated RangeJob to this CSV (filename,offset,length by default).', + ) + parser.add_argument( + '--no-fetch', + action='store_true', + help='Only generate range jobs (write --range-jobs-output); skip fetching/writing WARCs.', + ) + parser.add_argument( + '--csv-self-contained', + action='store_true', + help='Write full URLs (url,offset,length) to --range-jobs-output instead of relative filenames.', + ) parser.add_argument('--prefix', default='TEST', help='prefix for the output warc filename') parser.add_argument( '--subprefix', @@ -126,8 +176,9 @@ def add_warcer_by_cdx_args(parser: argparse.ArgumentParser): '--target-source', action='store', default='cdx', - help=('Source from that the filter targets are loaded (available options: `cdx`, `athena`; ' - 'defaults to `cdx`). For `athena`, use the global --crawl to restrict the scan to ' - 'specific crawls (strongly recommended; Athena bills per TB scanned).'), + choices=['cdx', 'sql', 'csv'], + help=('Where range jobs come from: `cdx` (index files), `sql` (columnar index via ' + '--engine athena|duckdb), or `csv` (a range-jobs CSV). Defaults to `cdx`. For `sql`, ' + 'use the global --crawl to restrict the scan to specific crawls (recommended for cost).'), ) return parser diff --git a/cdx_toolkit/filter_warc/athena_job_generator.py b/cdx_toolkit/filter_warc/athena_job_generator.py deleted file mode 100644 index c4c4323..0000000 --- a/cdx_toolkit/filter_warc/athena_job_generator.py +++ /dev/null @@ -1,251 +0,0 @@ -import asyncio -import logging -import re -import time -from typing import Any, Iterable, List, Optional - -from cdx_toolkit.filter_warc.data_classes import RangeJob - - -logger = logging.getLogger(__name__) - -# Required output columns that any Athena query (built or raw) must provide. 
-REQUIRED_RESULT_COLUMNS = ('warc_filename', 'warc_record_offset', 'warc_record_length') - -# Athena pricing is ~$5 per TB scanned (used for the post-run cost estimate). -ATHENA_USD_PER_TB = 5.0 - -# Hostnames, TLDs and crawl names only ever contain these characters. Validating -# against this set both prevents SQL injection and catches malformed input early. -_SQL_LITERAL_RE = re.compile(r'^[A-Za-z0-9.\-]+$') - - -def escape_sql_literal(value: str) -> str: - """Validate and quote a value for safe inclusion in an Athena SQL string literal. - - Only hostname/TLD/crawl-name characters (letters, digits, dot, hyphen) are - allowed, so the result cannot break out of the quotes or inject SQL.""" - if not isinstance(value, str) or not _SQL_LITERAL_RE.match(value): - raise ValueError( - f'Invalid value for Athena query literal: {value!r} ' - '(allowed characters: letters, digits, dot, hyphen)' - ) - return "'" + value + "'" - - -def build_athena_query( - url_host_names: List[str], - crawls: Optional[List[str]] = None, - limit: int = 0, - table: str = 'ccindex', -) -> str: - """Build the Athena SQL returning warc_filename/offset/length for the hostnames. - - CommonCrawl provides an index via AWS Athena that we can use to find the file - names, offsets, and byte lengths for WARC filtering. See - https://commoncrawl.org/blog/index-to-warc-files-and-urls-in-columnar-format - - If `crawls` is a non-empty list of crawl names (e.g. 
['CC-MAIN-2025-33']), a - `crawl IN (...)` partition filter is added -- this is the main lever for - reducing Athena scan cost.""" - if not url_host_names: - raise ValueError('build_athena_query requires at least one hostname') - - tlds = sorted({url.split('.')[-1] for url in url_host_names}) - query_tlds = ' OR '.join(f'url_host_tld = {escape_sql_literal(tld)}' for tld in tlds) - query_hostnames = ' OR '.join(f'url_host_name = {escape_sql_literal(h)}' for h in url_host_names) - - where_clauses = [ - "subset = 'warc'", - f'({query_tlds}) -- help the query optimizer', - f'({query_hostnames})', - ] - # TODO wire --from/--to into a fetch_time BETWEEN ... clause here - if crawls: - crawl_in = ', '.join(escape_sql_literal(c) for c in crawls) - where_clauses.append(f'crawl IN ({crawl_in})') - - where_sql = '\n AND '.join(where_clauses) - query_limit = f'\n LIMIT {limit}' if limit > 0 else '' - - return f""" - SELECT - warc_filename, warc_record_offset, warc_record_length - FROM {table} - WHERE {where_sql}{query_limit}""" - - -def validate_result_columns(column_names) -> None: - """Raise a clear error if the query result lacks the required columns.""" - missing = [c for c in REQUIRED_RESULT_COLUMNS if c not in (column_names or [])] - if missing: - raise ValueError( - 'Athena query result is missing required columns: ' + ', '.join(missing) + - '. The query (including a raw --athena-query) must SELECT ' + - ', '.join(REQUIRED_RESULT_COLUMNS) + '.' - ) - - -def join_warc_url(prefix: Optional[str], warc_filename: str) -> str: - """Join a download prefix with a warc_filename robustly. 
- - - if warc_filename is already absolute (contains '://'), return it unchanged - (supports custom queries whose warc_filename is a full s3://-/https:// URL); - - if prefix is empty/None, return warc_filename unchanged; - - otherwise join with exactly one '/' (no double slash, no missing slash).""" - if '://' in warc_filename: - return warc_filename - if not prefix: - return warc_filename - return prefix.rstrip('/') + '/' + warc_filename.lstrip('/') - - -def run_athena_query(client, query: str, database: str, s3_output_location: str, max_wait_time: int = 300) -> str: - """Start an Athena query and block until it completes; return the execution id. - - Raises if the query does not reach the SUCCEEDED state. If the wait is - interrupted (Ctrl-C / cancellation) or times out, the query is cancelled - server-side so we don't keep paying for a scan whose results we'll never read.""" - logger.info('Executing Athena query: %s', query) - - response = client.start_query_execution( - QueryString=query, - QueryExecutionContext={'Database': database}, - ResultConfiguration={'OutputLocation': s3_output_location}, - ) - - query_execution_id = response['QueryExecutionId'] - logger.info('Query execution started. ID: %s', query_execution_id) - - try: - status = _wait_for_query_completion(client, query_execution_id, max_wait_time) - except BaseException: - # Ctrl-C, asyncio cancellation, or timeout: stop the query server-side - # to bound the Athena scan cost, then propagate the original exception. 
- _stop_query(client, query_execution_id) - raise - - if status != 'SUCCEEDED': - raise Exception(f'Query failed with status: {status}') - - return query_execution_id - - -def _stop_query(client, query_execution_id: str) -> None: - """Best-effort cancellation of a running Athena query.""" - logger.warning('Cancelling Athena query %s ...', query_execution_id) - try: - client.stop_query_execution(QueryExecutionId=query_execution_id) - logger.warning('Athena query %s cancelled', query_execution_id) - except Exception as e: # pragma: no cover - best-effort cleanup - logger.warning('Failed to cancel Athena query %s: %r', query_execution_id, e) - - -def report_query_cost(client, query_execution_id: str) -> None: - """Log the bytes scanned and an estimated USD cost for a completed query.""" - try: - response = client.get_query_execution(QueryExecutionId=query_execution_id) - scanned = response['QueryExecution'].get('Statistics', {}).get('DataScannedInBytes') - except Exception as e: # pragma: no cover - best-effort reporting - logger.debug('unable to read Athena query statistics: %r', e) - return - - if scanned is None: - return - - gb = scanned / 1e9 - usd = scanned / 1e12 * ATHENA_USD_PER_TB - logger.info('Athena scanned %.2f GB, estimated cost ~$%.4f', gb, usd) - - -async def get_range_jobs_from_athena( - client, - query: str, - database: str, - s3_output_location: str, - job_queue: asyncio.Queue, - queue_stop_object: Any, - warc_download_prefix: str, - num_fetchers: int, - max_wait_time: int = 300, -) -> int: - """Execute a prepared Athena query and enqueue a RangeJob per result row. - - The query string is built and validated by the caller (see build_athena_query - and cdx_toolkit.filter_warc.command). 
This function only executes it, maps the - results to RangeJob objects, pushes them to the asyncio queue, signals the - fetchers to stop, and logs the scan cost.""" - count = 0 - - query_execution_id = run_athena_query(client, query, database, s3_output_location, max_wait_time) - - for range_job in iter_range_jobs(client, query_execution_id, warc_download_prefix): - await job_queue.put(range_job) - count += 1 - - report_query_cost(client, query_execution_id) - - # Signal fetchers to stop - for _ in range(num_fetchers): - await job_queue.put(queue_stop_object) - - logger.info('Athena query enqueued %d jobs', count) - - return count - - -def _wait_for_query_completion(client, query_execution_id: str, max_wait_time: int) -> str: - """Wait for query to complete and return final status""" - start_time = time.time() - - while time.time() - start_time < max_wait_time: - response = client.get_query_execution(QueryExecutionId=query_execution_id) - - status = response['QueryExecution']['Status']['State'] - logger.info(f'Query status: {status}') - - if status in ['SUCCEEDED', 'FAILED', 'CANCELLED']: - if status == 'FAILED': - error_reason = response['QueryExecution']['Status'].get('StateChangeReason', 'Unknown error') - logger.info(f'Query failed: {error_reason}') - return status - - time.sleep(2) - - raise TimeoutError(f'Query did not complete within {max_wait_time} seconds') - - -def iter_range_jobs(client, query_execution_id: str, warc_download_prefix: str) -> Iterable[RangeJob]: - """Retrieve query results and convert each row to a RangeJob""" - # Get query results - paginator = client.get_paginator('get_query_results') - page_iterator = paginator.paginate(QueryExecutionId=query_execution_id) - column_names = None - - for page in page_iterator: - rows = page['ResultSet']['Rows'] - - # Get column names from first page - if column_names is None and rows: - column_names = [col['VarCharValue'] for col in rows[0]['Data']] - validate_result_columns(column_names) - rows = 
rows[1:] # Skip header row - - # Process data rows - for row in rows: - row_data = [] - for cell in row['Data']: - value = cell.get('VarCharValue', None) - row_data.append(value) - - row = dict(zip(column_names, row_data)) - - warc_url = join_warc_url(warc_download_prefix, row['warc_filename']) - - yield RangeJob(url=warc_url, offset=int(row['warc_record_offset']), length=int(row['warc_record_length'])) - - -def get_databases(client) -> list: - """Get list of available databases""" - response = client.list_databases(CatalogName='AwsDataCatalog') - return [db['Name'] for db in response['DatabaseList']] diff --git a/cdx_toolkit/filter_warc/cdx_utils.py b/cdx_toolkit/filter_warc/cdx_utils.py index 6b0c584..45f32a2 100644 --- a/cdx_toolkit/filter_warc/cdx_utils.py +++ b/cdx_toolkit/filter_warc/cdx_utils.py @@ -27,7 +27,7 @@ def get_index_as_string_from_path( return f.read() -def read_cdx_line(line: str, warc_download_prefix: str) -> Tuple[str, int, int]: +def read_cdx_line(line: str, warc_download_prefix: str) -> Tuple[str, int, int, str]: cols = line.split(' ', maxsplit=2) if len(cols) == 3: @@ -49,10 +49,10 @@ def read_cdx_line(line: str, warc_download_prefix: str) -> Tuple[str, int, int]: warc_url = warc_download_prefix + '/' + filename - return (warc_url, offset, length) + return (warc_url, offset, length, filename) -def iter_cdx_index_from_path(index_path: str, warc_download_prefix: str) -> Iterable[Tuple[str, int, int]]: +def iter_cdx_index_from_path(index_path: str, warc_download_prefix: str) -> Iterable[Tuple[str, int, int, str]]: """ Iterate CDX records from a file path (gzipped; local or remote). 
""" diff --git a/cdx_toolkit/filter_warc/command.py b/cdx_toolkit/filter_warc/command.py index 2328f64..8a2d187 100644 --- a/cdx_toolkit/filter_warc/command.py +++ b/cdx_toolkit/filter_warc/command.py @@ -1,7 +1,5 @@ -from cdx_toolkit.filter_warc.cdx_utils import get_cdx_paths from cdx_toolkit.filter_warc.warc_filter import WARCFilter -from cdx_toolkit.filter_warc.athena_job_generator import build_athena_query -from cdx_toolkit.commoncrawl import normalize_crawl, get_cc_endpoints, match_cc_crawls +from cdx_toolkit.filter_warc.sources import make_source from cdx_toolkit.utils import get_version @@ -14,116 +12,54 @@ logger = logging.getLogger(__name__) -# A built Athena query restricted to at most this many crawls is considered cheap -# enough to run without a cost-confirmation prompt. +# A SQL query restricted to at most this many crawls is considered cheap enough to +# run without a cost-confirmation prompt. LARGE_CRAWL_SET_THRESHOLD = 10 -# Default Common Crawl index mirror used to resolve --crawl to concrete crawl names. -CC_INDEX_MIRROR = 'https://index.commoncrawl.org/' +def confirm_cost(estimate, confirmed) -> None: + """Prompt before running a potentially expensive index scan. -def _endpoint_to_crawl_name(endpoint: str) -> str: - """Turn a collinfo cdx-api endpoint into its crawl name. - - e.g. 'https://index.commoncrawl.org/CC-MAIN-2025-33-index' -> 'CC-MAIN-2025-33'.""" - name = endpoint.rstrip('/').split('/')[-1] - if name.endswith('-index'): - name = name[:-len('-index')] - return name - - -def _resolve_crawl_names(crawl_arg) -> list: - """Resolve a --crawl value to concrete CC-MAIN crawl names for the partition filter. 
- - Reuses the CDX path's helpers so `--crawl` accepts the same forms (comma-separated - names or an integer for the most recent N crawls).""" - crawls = normalize_crawl([crawl_arg]) - raw_index_list = get_cc_endpoints(CC_INDEX_MIRROR) - matched = match_cc_crawls(crawls, raw_index_list) - return [_endpoint_to_crawl_name(ep) for ep in matched] - - -def resolve_athena_query(args): - """Validate the Athena args and return (sql, n_crawls). - - n_crawls is the number of crawls the query is restricted to, or None when the - query is unrestricted (scans all crawls) or its pruning cannot be verified - (raw --athena-query/--athena-query-file).""" - raw_sql = args.athena_query - if args.athena_query_file: - if raw_sql: - raise ValueError('--athena-query and --athena-query-file are mutually exclusive') - with open(args.athena_query_file) as f: - raw_sql = f.read() - - if raw_sql and args.athena_hostnames: - raise ValueError('--athena-query/--athena-query-file are mutually exclusive with --athena-hostnames') - if not raw_sql and not args.athena_hostnames: - raise ValueError('athena target requires either --athena-hostnames or --athena-query/--athena-query-file') - - if not args.athena_database: - raise ValueError('--athena-database is required for target source `athena`') - if not args.athena_s3_output: - raise ValueError('--athena-s3-output is required for target source `athena`') - - if raw_sql: - # Crawl-partition pruning of a raw query cannot be verified -> treat as unbounded. - return raw_sql, None - - # Guided/built path - limit = 0 if args.limit is None else args.limit - if args.crawl: - crawl_names = _resolve_crawl_names(args.crawl) - sql = build_athena_query(args.athena_hostnames, crawls=crawl_names, limit=limit) - return sql, len(crawl_names) - - sql = build_athena_query(args.athena_hostnames, limit=limit) - return sql, None - - -def confirm_athena_cost(n_crawls, confirmed) -> None: - """Prompt before running a potentially expensive Athena query. 
- - A built query restricted to <= LARGE_CRAWL_SET_THRESHOLD crawls runs without a - prompt. Otherwise (no crawl filter, a large crawl set, or unverifiable raw SQL) - we confirm interactively, abort in non-interactive sessions, unless `confirmed` - (--confirm-athena-cost) is set.""" - if confirmed: + `estimate` is a CostEstimate (or None for sources that never bill, e.g. cdx/csv). + A scan bounded to <= LARGE_CRAWL_SET_THRESHOLD crawls runs without a prompt. + Otherwise (no crawl filter, a large crawl set, or unverifiable raw SQL) we confirm + interactively, abort in non-interactive sessions, unless `confirmed` (--confirm-cost).""" + if confirmed or estimate is None: return + + n_crawls = estimate.n_crawls if n_crawls is not None and n_crawls <= LARGE_CRAWL_SET_THRESHOLD: return + engine = estimate.engine if n_crawls is None: - reason = ('This Athena query is not restricted to specific crawls (or uses custom SQL whose ' - 'crawl partition pruning could not be verified) and may scan ALL crawls.') + reason = (f'This {engine} query is not restricted to specific crawls (or uses custom SQL ' + 'whose crawl partition pruning could not be verified) and may scan ALL crawls.') else: - reason = (f'This Athena query is restricted to {n_crawls} crawls (more than ' + reason = (f'This {engine} query is restricted to {n_crawls} crawls (more than ' f'{LARGE_CRAWL_SET_THRESHOLD}) and may scan a large amount of data.') if not sys.stdin.isatty(): raise SystemExit( - reason + ' Refusing to run a potentially expensive Athena scan in non-interactive mode. ' - 'Restrict with --crawl (<=10 crawls), or pass --confirm-athena-cost.' + reason + ' Refusing to run a potentially expensive scan in non-interactive mode. ' + 'Restrict with --crawl (<=10 crawls), or pass --confirm-cost.' ) logger.warning(reason) - answer = input('Athena bills per TB scanned. Proceed? [y/N] ') + answer = input('Index SQL scans can be expensive (Athena bills per TB scanned). Proceed? 
[y/N] ') if answer.strip().lower() not in ('y', 'yes'): raise SystemExit('Aborted by user.') -def run_warcer_by_cdx(args, cmdline): - """Like warcer but fetches WARC records based on one or more CDX index files. - - The CDX files can be filtered using the `filter_cdx` commands based a given URL/SURT list. +def run_repackage(args, cmdline): + """Repackage WARC records from a pluggable range-job source (cdx / sql / csv). Approach: - - Iterate over one or more CDX files to extract capture object (file, offset, length) - - Fetch WARC record based on capture object - - Write to new WARC file including metadata records with index. - - The CDX metadata record is written to the WARC directly before for response records that matches to the CDX. + - Generate RangeJobs (WARC file + byte range) from the selected source. + - Optionally materialize them to a CSV (--range-jobs-output; --no-fetch to skip fetching). + - Fetch each WARC record and write a new WARC, including metadata records. """ - logger.info('Filtering WARC files based on CDX') + logger.info('Repackaging WARC files (target source: %s)', args.target_source) # Start timing start_time = time.time() @@ -142,7 +78,7 @@ def run_warcer_by_cdx(args, cmdline): 'isPartOf': ispartof, 'description': args.description if args.description - else 'warc extraction based on CDX generated with: ' + cmdline, + else 'warc extraction generated with: ' + cmdline, 'format': 'WARC file version 1.0', } if args.creator: @@ -154,33 +90,22 @@ def run_warcer_by_cdx(args, cmdline): log_every_n = args.log_every_n limit = 0 if args.limit is None else args.limit prefix_path = str(args.prefix) - prefix_fs, prefix_fs_path = fsspec.url_to_fs(prefix_path) - # make sure the base dir exists - prefix_fs.makedirs(prefix_fs._parent(prefix_fs_path), exist_ok=True) + # Build the source (validates source/engine/query options) and confirm scan cost + # up front (synchronously) before launching the pipeline. 
+ source = make_source(args, warc_download_prefix=args.warc_download_prefix, record_limit=limit) + confirm_cost(source.estimate_cost(), args.confirm_cost) - # target source handling - athena_query = None - if args.target_source == 'cdx': - cdx_paths = get_cdx_paths( - args.cdx_path, - args.cdx_glob, - ) - elif args.target_source == "athena": - cdx_paths = None - # Build/validate the Athena query up front (synchronously) so we can warn - # about expensive (unpartitioned / large) scans before launching the pipeline. - athena_query, n_crawls = resolve_athena_query(args) - confirm_athena_cost(n_crawls, args.confirm_athena_cost) - else: - raise ValueError(f'Invalid target source specified: {args.target_source} (available: cdx, athena)') + # make sure the output base dir exists (only needed when actually writing WARCs) + if not args.no_fetch: + prefix_fs, prefix_fs_path = fsspec.url_to_fs(prefix_path) + prefix_fs.makedirs(prefix_fs._parent(prefix_fs_path), exist_ok=True) warc_filter = WARCFilter( - target_source=args.target_source, - cdx_paths=cdx_paths, - athena_database=args.athena_database, - athena_s3_output_location=args.athena_s3_output, - athena_query=athena_query, + source=source, + range_jobs_output=args.range_jobs_output, + no_fetch=args.no_fetch, + csv_self_contained=args.csv_self_contained, prefix_path=prefix_path, writer_info=info, writer_subprefix=args.subprefix, @@ -190,7 +115,6 @@ def run_warcer_by_cdx(args, cmdline): warc_download_prefix=args.warc_download_prefix, n_parallel=n_parallel, max_file_size=args.size, - # writer_kwargs=writer_kwargs, ) records_n = warc_filter.filter() diff --git a/cdx_toolkit/filter_warc/data_classes.py b/cdx_toolkit/filter_warc/data_classes.py index 8f36325..40f190e 100644 --- a/cdx_toolkit/filter_warc/data_classes.py +++ b/cdx_toolkit/filter_warc/data_classes.py @@ -2,7 +2,7 @@ from dataclasses import dataclass from cdx_toolkit.filter_warc.s3_utils import is_s3_url, parse_s3_uri, with_retries -from typing import Tuple +from 
typing import Optional, Tuple from cdx_toolkit.myrequests import myrequests_get @@ -47,6 +47,10 @@ class RangeJob: offset: int length: int records_count: int = 1 + # Relative WARC filename (e.g. crawl-data/...warc.gz) as known by the source, + # used when materializing a non-self-contained range-jobs CSV. `url` stays + # authoritative for fetching. + filename: Optional[str] = None def is_s3(self): return is_s3_url(self.url) diff --git a/cdx_toolkit/filter_warc/sources/__init__.py b/cdx_toolkit/filter_warc/sources/__init__.py new file mode 100644 index 0000000..d1208cf --- /dev/null +++ b/cdx_toolkit/filter_warc/sources/__init__.py @@ -0,0 +1,4 @@ +from cdx_toolkit.filter_warc.sources.base import RangeJobSource, CostEstimate +from cdx_toolkit.filter_warc.sources.factory import make_source + +__all__ = ['RangeJobSource', 'CostEstimate', 'make_source'] diff --git a/cdx_toolkit/filter_warc/sources/athena.py b/cdx_toolkit/filter_warc/sources/athena.py new file mode 100644 index 0000000..203048d --- /dev/null +++ b/cdx_toolkit/filter_warc/sources/athena.py @@ -0,0 +1,177 @@ +import logging +import time +from typing import Iterator, Iterable, Optional + +from cdx_toolkit.filter_warc.data_classes import RangeJob +from cdx_toolkit.filter_warc.sources.base import RangeJobSource, CostEstimate +from cdx_toolkit.filter_warc.sources.sql_base import validate_result_columns, join_warc_url + + +logger = logging.getLogger(__name__) + +# Athena pricing is ~$5 per TB scanned (used for the post-run cost estimate). 
+ATHENA_USD_PER_TB = 5.0 + + +class AthenaSource(RangeJobSource): + """RangeJobs from a query against the CC columnar index via AWS Athena.""" + + def __init__( + self, + *, + query: str, + database: str, + s3_output_location: str, + warc_download_prefix: Optional[str], + n_crawls: Optional[int] = None, + region_name: str = 'us-east-1', + max_wait_time: int = 300, + ): + self.query = query + self.database = database + self.s3_output_location = s3_output_location + self.warc_download_prefix = warc_download_prefix + self._n_crawls = n_crawls + self.region_name = region_name + self.max_wait_time = max_wait_time + + def estimate_cost(self) -> CostEstimate: + return CostEstimate(n_crawls=self._n_crawls, engine='athena') + + def _make_client(self): + import boto3 + from botocore.config import Config + + config = Config( + region_name=self.region_name, + read_timeout=60, + retries={'max_attempts': 3, 'mode': 'adaptive'}, + ) + return boto3.client('athena', config=config) + + def iter_range_jobs(self) -> Iterator[RangeJob]: + client = self._make_client() + query_execution_id = run_athena_query( + client, self.query, self.database, self.s3_output_location, self.max_wait_time + ) + try: + for job in iter_range_jobs(client, query_execution_id, self.warc_download_prefix): + yield job + finally: + report_query_cost(client, query_execution_id) + + +def run_athena_query(client, query: str, database: str, s3_output_location: str, max_wait_time: int = 300) -> str: + """Start an Athena query and block until it completes; return the execution id. + + Raises if the query does not reach the SUCCEEDED state. 
If the wait is + interrupted (Ctrl-C / cancellation) or times out, the query is cancelled + server-side so we don't keep paying for a scan whose results we'll never read.""" + logger.info('Executing Athena query: %s', query) + + response = client.start_query_execution( + QueryString=query, + QueryExecutionContext={'Database': database}, + ResultConfiguration={'OutputLocation': s3_output_location}, + ) + + query_execution_id = response['QueryExecutionId'] + logger.info('Query execution started. ID: %s', query_execution_id) + + try: + status = _wait_for_query_completion(client, query_execution_id, max_wait_time) + except BaseException: + # Ctrl-C, asyncio cancellation, or timeout: stop the query server-side + # to bound the Athena scan cost, then propagate the original exception. + _stop_query(client, query_execution_id) + raise + + if status != 'SUCCEEDED': + raise Exception(f'Query failed with status: {status}') + + return query_execution_id + + +def _stop_query(client, query_execution_id: str) -> None: + """Best-effort cancellation of a running Athena query.""" + logger.warning('Cancelling Athena query %s ...', query_execution_id) + try: + client.stop_query_execution(QueryExecutionId=query_execution_id) + logger.warning('Athena query %s cancelled', query_execution_id) + except Exception as e: # pragma: no cover - best-effort cleanup + logger.warning('Failed to cancel Athena query %s: %r', query_execution_id, e) + + +def report_query_cost(client, query_execution_id: str) -> None: + """Log the bytes scanned and an estimated USD cost for a completed query.""" + try: + response = client.get_query_execution(QueryExecutionId=query_execution_id) + scanned = response['QueryExecution'].get('Statistics', {}).get('DataScannedInBytes') + except Exception as e: # pragma: no cover - best-effort reporting + logger.debug('unable to read Athena query statistics: %r', e) + return + + if scanned is None: + return + + gb = scanned / 1e9 + usd = scanned / 1e12 * ATHENA_USD_PER_TB + 
logger.info('Athena scanned %.2f GB, estimated cost ~$%.4f', gb, usd) + + +def _wait_for_query_completion(client, query_execution_id: str, max_wait_time: int) -> str: + """Wait for query to complete and return final status""" + start_time = time.time() + + while time.time() - start_time < max_wait_time: + response = client.get_query_execution(QueryExecutionId=query_execution_id) + + status = response['QueryExecution']['Status']['State'] + logger.info(f'Query status: {status}') + + if status in ['SUCCEEDED', 'FAILED', 'CANCELLED']: + if status == 'FAILED': + error_reason = response['QueryExecution']['Status'].get('StateChangeReason', 'Unknown error') + logger.info(f'Query failed: {error_reason}') + return status + + time.sleep(2) + + raise TimeoutError(f'Query did not complete within {max_wait_time} seconds') + + +def iter_range_jobs(client, query_execution_id: str, warc_download_prefix: Optional[str]) -> Iterable[RangeJob]: + """Retrieve query results and convert each row to a RangeJob""" + paginator = client.get_paginator('get_query_results') + page_iterator = paginator.paginate(QueryExecutionId=query_execution_id) + column_names = None + + for page in page_iterator: + rows = page['ResultSet']['Rows'] + + # Get column names from first page + if column_names is None and rows: + column_names = [col['VarCharValue'] for col in rows[0]['Data']] + validate_result_columns(column_names) + rows = rows[1:] # Skip header row + + # Process data rows + for row in rows: + row_data = [cell.get('VarCharValue', None) for cell in row['Data']] + row = dict(zip(column_names, row_data)) + + warc_filename = row['warc_filename'] + warc_url = join_warc_url(warc_download_prefix, warc_filename) + + yield RangeJob( + url=warc_url, + offset=int(row['warc_record_offset']), + length=int(row['warc_record_length']), + filename=warc_filename, + ) + + +def get_databases(client) -> list: + """Get list of available databases""" + response = client.list_databases(CatalogName='AwsDataCatalog') + 
return [db['Name'] for db in response['DatabaseList']] diff --git a/cdx_toolkit/filter_warc/sources/base.py b/cdx_toolkit/filter_warc/sources/base.py new file mode 100644 index 0000000..f833cce --- /dev/null +++ b/cdx_toolkit/filter_warc/sources/base.py @@ -0,0 +1,35 @@ +from abc import ABC, abstractmethod +from typing import Iterator, NamedTuple, Optional + +from cdx_toolkit.filter_warc.data_classes import RangeJob + + +class CostEstimate(NamedTuple): + """Describes the scan a source is about to run, for the cost-confirmation guard. + + n_crawls is the number of crawls the scan is bounded to, or None when the scan + is unbounded / its pruning cannot be verified (e.g. raw SQL, all-crawls glob).""" + + n_crawls: Optional[int] + engine: str + + +class RangeJobSource(ABC): + """A source of RangeJobs for the repackage pipeline. + + A source owns its own stage-1 resource (Athena client, DuckDB connection, or an + fsspec file handle) and yields RangeJobs synchronously; the pipeline orchestrator + bridges the sync generator into the async fetch/write stages, and owns queueing, + the record limit, counting, and stop-sentinel emission.""" + + def estimate_cost(self) -> Optional[CostEstimate]: + """Return a CostEstimate for the cost guard, or None for sources that never + incur a per-scan charge (cdx files, csv).""" + return None + + @abstractmethod + def iter_range_jobs(self) -> Iterator[RangeJob]: + """Yield RangeJobs. Implementations open their own client/connection lazily + and close it in a finally. 
Each RangeJob carries `url` (authoritative for + fetching) and, where known, the relative `filename`.""" + raise NotImplementedError diff --git a/cdx_toolkit/filter_warc/sources/cdx.py b/cdx_toolkit/filter_warc/sources/cdx.py new file mode 100644 index 0000000..bb9356d --- /dev/null +++ b/cdx_toolkit/filter_warc/sources/cdx.py @@ -0,0 +1,28 @@ +import logging +from typing import Iterator, List + +from cdx_toolkit.filter_warc.cdx_utils import iter_cdx_index_from_path +from cdx_toolkit.filter_warc.data_classes import RangeJob +from cdx_toolkit.filter_warc.sources.base import RangeJobSource + + +logger = logging.getLogger(__name__) + + +class CdxSource(RangeJobSource): + """RangeJobs read from one or more CDX index files (local or remote via fsspec).""" + + def __init__(self, cdx_paths: List[str], warc_download_prefix: str): + self.cdx_paths = cdx_paths + self.warc_download_prefix = warc_download_prefix + + def iter_range_jobs(self) -> Iterator[RangeJob]: + for index_path in self.cdx_paths: + try: + for warc_url, offset, length, filename in iter_cdx_index_from_path( + index_path, self.warc_download_prefix + ): + yield RangeJob(url=warc_url, offset=offset, length=length, filename=filename) + except Exception as e: + # Preserve the previous behaviour of skipping a bad index file. 
+ logger.error('Failed to read CDX index from %s: %s', index_path, e) diff --git a/cdx_toolkit/filter_warc/sources/csv.py b/cdx_toolkit/filter_warc/sources/csv.py new file mode 100644 index 0000000..f9ac5aa --- /dev/null +++ b/cdx_toolkit/filter_warc/sources/csv.py @@ -0,0 +1,90 @@ +import csv +import logging +from typing import Iterator, Optional + +import fsspec + +from cdx_toolkit.filter_warc.data_classes import RangeJob +from cdx_toolkit.filter_warc.sources.base import RangeJobSource +from cdx_toolkit.filter_warc.sources.sql_base import join_warc_url + + +logger = logging.getLogger(__name__) + +FILENAME_FIELDS = ['filename', 'offset', 'length'] +URL_FIELDS = ['url', 'offset', 'length'] + + +class RangeJobCsvWriter: + """Write RangeJobs to a CSV (local or remote via fsspec). + + Default mode writes the relative `filename` column (the consumer prepends the + WARC download prefix); `self_contained` mode writes the full `url` column.""" + + def __init__(self, path: str, self_contained: bool = False): + self.path = path + self.self_contained = self_contained + self._fields = URL_FIELDS if self_contained else FILENAME_FIELDS + self._ctx = fsspec.open(path, 'wt', newline='') + self._fh = self._ctx.__enter__() + self._writer = csv.DictWriter(self._fh, fieldnames=self._fields) + self._writer.writeheader() + + def write(self, job: RangeJob) -> None: + if self.self_contained: + row = {'url': job.url, 'offset': job.offset, 'length': job.length} + else: + if job.filename is None: + raise ValueError( + 'cannot write a non-self-contained range-jobs CSV: RangeJob.filename is ' + 'missing; pass --csv-self-contained to write full URLs instead' + ) + row = {'filename': job.filename, 'offset': job.offset, 'length': job.length} + self._writer.writerow(row) + + def close(self) -> None: + if self._ctx is not None: + self._ctx.__exit__(None, None, None) + self._ctx = None + self._fh = None + + +class CsvSource(RangeJobSource): + """RangeJobs read from a CSV/TSV. 
+ + Mode is auto-detected from the header: a `url` column => self-contained (used + as-is); a `filename` column => the WARC download prefix is prepended. TSV is + detected from a `.tsv`/`.tsv.gz` extension; `.gz` inputs are decompressed.""" + + def __init__(self, path: str, warc_download_prefix: Optional[str]): + self.path = path + self.warc_download_prefix = warc_download_prefix + + def iter_range_jobs(self) -> Iterator[RangeJob]: + path = str(self.path) + delimiter = '\t' if path.endswith(('.tsv', '.tsv.gz')) else ',' + compression = 'gzip' if path.endswith('.gz') else None + + with fsspec.open(self.path, 'rt', newline='', compression=compression) as fh: + reader = csv.DictReader(fh, delimiter=delimiter) + fields = set(reader.fieldnames or []) + if 'url' in fields: + mode_url = True + elif 'filename' in fields: + mode_url = False + else: + raise ValueError( + f'range-jobs CSV {self.path} must have a `url` or `filename` column ' + f'(got header: {reader.fieldnames})' + ) + + for row in reader: + offset = int(row['offset']) + length = int(row['length']) + if mode_url: + url = row['url'] + filename = None + else: + filename = row['filename'] + url = join_warc_url(self.warc_download_prefix, filename) + yield RangeJob(url=url, offset=offset, length=length, filename=filename) diff --git a/cdx_toolkit/filter_warc/sources/duckdb.py b/cdx_toolkit/filter_warc/sources/duckdb.py new file mode 100644 index 0000000..2513bd2 --- /dev/null +++ b/cdx_toolkit/filter_warc/sources/duckdb.py @@ -0,0 +1,113 @@ +import logging +from typing import Iterator, List, Optional + +from cdx_toolkit.filter_warc.data_classes import RangeJob +from cdx_toolkit.filter_warc.sources.base import RangeJobSource, CostEstimate +from cdx_toolkit.filter_warc.sources.sql_base import build_sql, validate_result_columns, join_warc_url + + +logger = logging.getLogger(__name__) + +try: + import duckdb + _HAS_DUCKDB = True +except ImportError: # pragma: no cover - exercised in minimal installs + duckdb = None + 
_HAS_DUCKDB = False + + +def _build_from_clause(index_path: str, crawls: Optional[List[str]]) -> str: + """Build a read_parquet(...) FROM expression over the CC columnar index. + + When crawls are given we glob only those crawl partitions (the pruning lever); + otherwise we glob every crawl (expensive -> the cost guard fires).""" + base = index_path.rstrip('/') + if crawls: + globs = [f"'{base}/crawl={c}/subset=warc/*.parquet'" for c in crawls] + else: + globs = [f"'{base}/crawl=*/subset=warc/*.parquet'"] + return f"read_parquet([{', '.join(globs)}], hive_partitioning=true)" + + +class DuckDbSource(RangeJobSource): + """RangeJobs from a query against the CC columnar index via DuckDB (read_parquet on S3). + + Reads the public CommonCrawl parquet directly; AWS region/credentials come from + the environment (the public bucket is readable with valid credentials).""" + + def __init__( + self, + *, + query: Optional[str] = None, + hostnames: Optional[List[str]] = None, + domains: Optional[List[str]] = None, + crawls: Optional[List[str]] = None, + index_path: str, + warc_download_prefix: Optional[str], + limit: int = 0, + region_name: str = 'us-east-1', + ): + self.raw_query = query + self.hostnames = hostnames + self.domains = domains + self.crawls = crawls + self.index_path = index_path + self.warc_download_prefix = warc_download_prefix + self.limit = limit + self.region_name = region_name + + def estimate_cost(self) -> CostEstimate: + if self.raw_query is not None: + return CostEstimate(n_crawls=None, engine='duckdb') + return CostEstimate(n_crawls=len(self.crawls) if self.crawls else None, engine='duckdb') + + def _build_query(self) -> str: + if self.raw_query is not None: + return self.raw_query + from_clause = _build_from_clause(self.index_path, self.crawls) + # crawl pruning is done in the FROM glob, so no crawl IN (...) 
in the WHERE + return build_sql( + from_clause, self.hostnames, crawls=None, limit=self.limit, + url_host_registered_domains=self.domains, + ) + + def iter_range_jobs(self) -> Iterator[RangeJob]: + if not _HAS_DUCKDB: + raise RuntimeError( + 'DuckDB engine requires optional dependencies. Install cdx_toolkit[duckdb].' + ) + + query = self._build_query() + logger.info('Executing DuckDB query: %s', query) + + con = duckdb.connect() + try: + con.execute('INSTALL httpfs; LOAD httpfs;') + con.execute(f"SET s3_region='{self.region_name}';") + # Be resilient to transient S3 read timeouts on large parquet partitions. + for stmt in ('SET http_timeout=120000;', 'SET http_retries=5;'): + try: + con.execute(stmt) + except Exception as e: # pragma: no cover - setting unsupported on this duckdb + logger.debug('duckdb setting skipped: %s (%r)', stmt, e) + cur = con.execute(query) + + col_names = [d[0] for d in cur.description] + validate_result_columns(col_names) + idx = {name: i for i, name in enumerate(col_names)} + + while True: + rows = cur.fetchmany(1000) + if not rows: + break + for row in rows: + warc_filename = row[idx['warc_filename']] + warc_url = join_warc_url(self.warc_download_prefix, warc_filename) + yield RangeJob( + url=warc_url, + offset=int(row[idx['warc_record_offset']]), + length=int(row[idx['warc_record_length']]), + filename=warc_filename, + ) + finally: + con.close() diff --git a/cdx_toolkit/filter_warc/sources/factory.py b/cdx_toolkit/filter_warc/sources/factory.py new file mode 100644 index 0000000..ba0e098 --- /dev/null +++ b/cdx_toolkit/filter_warc/sources/factory.py @@ -0,0 +1,99 @@ +import logging +from typing import Optional, Tuple + +from cdx_toolkit.filter_warc.cdx_utils import get_cdx_paths +from cdx_toolkit.filter_warc.sources.base import RangeJobSource +from cdx_toolkit.filter_warc.sources.sql_base import build_athena_query, resolve_crawl_names + + +logger = logging.getLogger(__name__) + + +def make_source(args, *, warc_download_prefix: 
Optional[str], record_limit: int) -> RangeJobSource: + """Build the RangeJobSource selected by --target-source (+ --engine for sql). + + Centralises all source/engine validation (engine required iff sql; + hostnames/query/query-file mutual exclusivity; required connection options).""" + target = args.target_source + + if target == 'cdx': + from cdx_toolkit.filter_warc.sources.cdx import CdxSource + cdx_paths = get_cdx_paths(args.cdx_path, args.cdx_glob) + return CdxSource(cdx_paths, warc_download_prefix) + + if target == 'csv': + from cdx_toolkit.filter_warc.sources.csv import CsvSource + if not args.csv_path: + raise ValueError('--csv-path is required for --target-source csv') + return CsvSource(args.csv_path, warc_download_prefix) + + if target == 'sql': + return _make_sql_source(args, warc_download_prefix, record_limit) + + raise ValueError(f'Invalid target source: {target} (available: cdx, sql, csv)') + + +def _resolve_sql_query_spec(args) -> Tuple[Optional[str], Optional[list], Optional[list], Optional[list]]: + """Validate the query-defining flags and return (raw_sql, hostnames, domains, crawls). + + The guided path (--hostnames and/or --domains) and a raw query + (--query/--query-file) are mutually exclusive. 
For the guided path, --crawl is + resolved to concrete crawl names.""" + raw_sql = args.query + if args.query_file: + if raw_sql: + raise ValueError('--query and --query-file are mutually exclusive') + with open(args.query_file) as f: + raw_sql = f.read() + + has_guided = bool(args.hostnames) or bool(args.domains) + if raw_sql and has_guided: + raise ValueError('--query/--query-file are mutually exclusive with --hostnames/--domains') + if not raw_sql and not has_guided: + raise ValueError('the sql target requires --hostnames, --domains, or --query/--query-file') + + if raw_sql: + return raw_sql, None, None, None + + crawls = resolve_crawl_names(args.crawl) if args.crawl else None + return None, args.hostnames, args.domains, crawls + + +def _make_sql_source(args, warc_download_prefix, record_limit) -> RangeJobSource: + engine = args.engine + if not engine: + raise ValueError('--engine is required for --target-source sql (choices: athena, duckdb)') + + raw_sql, hostnames, domains, crawls = _resolve_sql_query_spec(args) + limit = 0 if record_limit is None else record_limit + + if engine == 'athena': + from cdx_toolkit.filter_warc.sources.athena import AthenaSource + if not args.athena_s3_output: + raise ValueError('--athena-s3-output is required for --engine athena') + database = args.athena_database or 'ccindex' + query = raw_sql if raw_sql else build_athena_query( + hostnames, crawls=crawls, limit=limit, url_host_registered_domains=domains, + ) + n_crawls = None if raw_sql else (len(crawls) if crawls else None) + return AthenaSource( + query=query, + database=database, + s3_output_location=args.athena_s3_output, + warc_download_prefix=warc_download_prefix, + n_crawls=n_crawls, + ) + + if engine == 'duckdb': + from cdx_toolkit.filter_warc.sources.duckdb import DuckDbSource + return DuckDbSource( + query=raw_sql, + hostnames=hostnames, + domains=domains, + crawls=crawls, + index_path=args.duckdb_index_path, + warc_download_prefix=warc_download_prefix, + limit=limit, + 
) + + raise ValueError(f'Invalid --engine: {engine} (choices: athena, duckdb)') diff --git a/cdx_toolkit/filter_warc/sources/sql_base.py b/cdx_toolkit/filter_warc/sources/sql_base.py new file mode 100644 index 0000000..e466682 --- /dev/null +++ b/cdx_toolkit/filter_warc/sources/sql_base.py @@ -0,0 +1,150 @@ +import logging +import re +from typing import List, Optional + +from cdx_toolkit.commoncrawl import normalize_crawl, get_cc_endpoints, match_cc_crawls + + +logger = logging.getLogger(__name__) + +# Required output columns that any index query (built or raw) must provide. +REQUIRED_RESULT_COLUMNS = ('warc_filename', 'warc_record_offset', 'warc_record_length') + +# Hostnames, TLDs and crawl names only ever contain these characters. Validating +# against this set both prevents SQL injection and catches malformed input early. +_SQL_LITERAL_RE = re.compile(r'^[A-Za-z0-9.\-]+$') + +# Default Common Crawl index mirror used to resolve --crawl to concrete crawl names. +CC_INDEX_MIRROR = 'https://index.commoncrawl.org/' + + +def escape_sql_literal(value: str) -> str: + """Validate and quote a value for safe inclusion in a SQL string literal. + + Only hostname/TLD/crawl-name characters (letters, digits, dot, hyphen) are + allowed, so the result cannot break out of the quotes or inject SQL.""" + if not isinstance(value, str) or not _SQL_LITERAL_RE.match(value): + raise ValueError( + f'Invalid value for SQL query literal: {value!r} ' + '(allowed characters: letters, digits, dot, hyphen)' + ) + return "'" + value + "'" + + +def build_where_sql( + url_host_names: Optional[List[str]] = None, + crawls: Optional[List[str]] = None, + url_host_registered_domains: Optional[List[str]] = None, +) -> str: + """Build the WHERE body (without the `WHERE` keyword) shared by all SQL engines. + + The guided filter matches on `url_host_name` (exact host, e.g. www.example.com) + and/or `url_host_registered_domain` (e.g. 
example.com, which also covers its + subdomains); predicates are OR-ed together. At least one host or domain is + required. If `crawls` is a non-empty list of crawl names (e.g. ['CC-MAIN-2025-33']), + a `crawl IN (...)` partition filter is added -- the main lever for reducing scan + cost. Engines differ only in their FROM clause (see build_sql).""" + url_host_names = url_host_names or [] + domains = url_host_registered_domains or [] + if not url_host_names and not domains: + raise ValueError('an index query requires at least one hostname or registered domain') + + tlds = sorted({v.split('.')[-1] for v in (list(url_host_names) + list(domains))}) + query_tlds = ' OR '.join(f'url_host_tld = {escape_sql_literal(t)}' for t in tlds) + + host_predicates = [f'url_host_name = {escape_sql_literal(h)}' for h in url_host_names] + host_predicates += [f'url_host_registered_domain = {escape_sql_literal(d)}' for d in domains] + query_hosts = ' OR '.join(host_predicates) + + clauses = [ + "subset = 'warc'", + f'({query_tlds}) -- help the query optimizer', + f'({query_hosts})', + ] + # TODO wire --from/--to into a fetch_time BETWEEN ... clause here + if crawls: + crawl_in = ', '.join(escape_sql_literal(c) for c in crawls) + clauses.append(f'crawl IN ({crawl_in})') + + return '\n AND '.join(clauses) + + +def build_sql( + from_clause: str, + url_host_names: Optional[List[str]] = None, + crawls: Optional[List[str]] = None, + limit: int = 0, + url_host_registered_domains: Optional[List[str]] = None, +) -> str: + """Assemble a full SELECT for the columnar index. + + `from_clause` is the text following FROM (e.g. 
`ccindex` for Athena, or a + `read_parquet(...)` expression for DuckDB).""" + where_sql = build_where_sql(url_host_names, crawls, url_host_registered_domains=url_host_registered_domains) + limit_sql = f'\n LIMIT {limit}' if limit and limit > 0 else '' + + return f""" + SELECT + warc_filename, warc_record_offset, warc_record_length + FROM {from_clause} + WHERE {where_sql}{limit_sql}""" + + +def build_athena_query( + url_host_names: Optional[List[str]] = None, + crawls: Optional[List[str]] = None, + limit: int = 0, + table: str = 'ccindex', + url_host_registered_domains: Optional[List[str]] = None, +) -> str: + """Athena flavour of build_sql (FROM `table`).""" + return build_sql( + table, url_host_names, crawls=crawls, limit=limit, + url_host_registered_domains=url_host_registered_domains, + ) + + +def validate_result_columns(column_names) -> None: + """Raise a clear error if the query result lacks the required columns.""" + missing = [c for c in REQUIRED_RESULT_COLUMNS if c not in (column_names or [])] + if missing: + raise ValueError( + 'Index query result is missing required columns: ' + ', '.join(missing) + + '. The query (including a raw --query) must SELECT ' + + ', '.join(REQUIRED_RESULT_COLUMNS) + '.' + ) + + +def join_warc_url(prefix: Optional[str], warc_filename: str) -> str: + """Join a download prefix with a warc_filename robustly. + + - if warc_filename is already absolute (contains '://'), return it unchanged + (supports custom queries / self-contained CSVs whose value is a full URL); + - if prefix is empty/None, return warc_filename unchanged; + - otherwise join with exactly one '/' (no double slash, no missing slash).""" + if '://' in warc_filename: + return warc_filename + if not prefix: + return warc_filename + return prefix.rstrip('/') + '/' + warc_filename.lstrip('/') + + +def endpoint_to_crawl_name(endpoint: str) -> str: + """Turn a collinfo cdx-api endpoint into its crawl name. + + e.g. 
'https://index.commoncrawl.org/CC-MAIN-2025-33-index' -> 'CC-MAIN-2025-33'.""" + name = endpoint.rstrip('/').split('/')[-1] + if name.endswith('-index'): + name = name[:-len('-index')] + return name + + +def resolve_crawl_names(crawl_arg) -> List[str]: + """Resolve a --crawl value to concrete CC-MAIN crawl names for the partition filter. + + Reuses the CDX path's helpers so `--crawl` accepts the same forms (comma-separated + names or an integer for the most recent N crawls).""" + crawls = normalize_crawl([crawl_arg]) + raw_index_list = get_cc_endpoints(CC_INDEX_MIRROR) + matched = match_cc_crawls(crawls, raw_index_list) + return [endpoint_to_crawl_name(ep) for ep in matched] diff --git a/cdx_toolkit/filter_warc/warc_filter.py b/cdx_toolkit/filter_warc/warc_filter.py index ed1e0b1..8399286 100644 --- a/cdx_toolkit/filter_warc/warc_filter.py +++ b/cdx_toolkit/filter_warc/warc_filter.py @@ -2,20 +2,18 @@ import logging import statistics import sys -from typing import List, Literal, Optional, Dict +from typing import List, Optional, Dict from botocore.config import Config -from cdx_toolkit.filter_warc.athena_job_generator import get_range_jobs_from_athena from cdx_toolkit.filter_warc.s3_utils import ( is_s3_url, ) from cdx_toolkit.filter_warc.data_classes import RangeJob, RangePayload, ThroughputTracker from cdx_toolkit.filter_warc.warc_utils import create_new_writer_with_header -from cdx_toolkit.filter_warc.cdx_utils import ( - iter_cdx_index_from_path, -) +from cdx_toolkit.filter_warc.sources.base import RangeJobSource +from cdx_toolkit.filter_warc.sources.csv import RangeJobCsvWriter from cdx_toolkit.filter_warc.warc_utils import get_bytes_from_warc_record, get_metadata_record_from_path @@ -23,8 +21,6 @@ logger = logging.getLogger(__name__) -TargetSourceType = Literal['cdx', 'athena'] - class WARCFilter: """Filter or extract specific records from WARC files based on CDX indexes. 
@@ -47,11 +43,10 @@ def __init__( self, prefix_path: str, writer_info: Dict, - target_source: TargetSourceType = 'cdx', - cdx_paths: Optional[List[str]] = None, - athena_database: Optional[str] = None, - athena_query: Optional[str] = None, - athena_s3_output_location: Optional[str] = None, + source: RangeJobSource, + range_jobs_output: Optional[str] = None, + no_fetch: bool = False, + csv_self_contained: bool = False, writer_subprefix: Optional[str] = None, write_paths_as_metadata_records: Optional[List[str]] = None, record_limit: int = 0, @@ -75,11 +70,13 @@ def __init__( """Initialize the WARC filter. Args: - target_source: Source of filter targets (Athena query or CDX files). - cdx_paths: List of paths to CDX index files. - athena_database: Database for Athena query. - athena_query: Prepared Athena SQL string to execute (built by the caller). - athena_s3_output_location: S3 output location for Athena query. + source: RangeJobSource that yields the WARC ranges to repackage. + range_jobs_output: Optional path; if set, each generated RangeJob is + written to this CSV (materialization). + no_fetch: If True, only generate range jobs (and write range_jobs_output); + skip fetching/writing WARC records entirely. + csv_self_contained: If True, range_jobs_output stores full URLs instead of + relative filenames. prefix_path: Output path prefix for filtered WARC files. writer_info: Dictionary containing writer metadata. writer_subprefix: Optional subprefix for writer output paths. @@ -102,11 +99,10 @@ def __init__( min_part_size: Minimum part byte size for multipart uploads (default: 5 MiB). max_file_size: Maximum byte size for individual WARC files (default: 1 GiB). 
""" - self.cdx_paths = cdx_paths - self.target_source: TargetSourceType = target_source - self.athena_database = athena_database - self.athena_s3_output_location = athena_s3_output_location - self.athena_query = athena_query + self.source = source + self.range_jobs_output = range_jobs_output + self.no_fetch = no_fetch + self.csv_self_contained = csv_self_contained self.prefix_path = prefix_path self.writer_info = writer_info self.writer_subprefix = writer_subprefix @@ -131,7 +127,6 @@ def __init__( else max(int(self.num_readers / self.fetcher_to_consumer_ratio), 1) ) - # self.gzip = self.cdx_paths[0].endswith('.gz') if self.cdx_paths else False self.gzip = True self.warc_version = warc_version @@ -153,17 +148,15 @@ def filter(self) -> int: return -1 def needs_aws(self) -> bool: - """Returns true if AWS (S3/Athena) is needed at any stage. + """Returns true if the read/write (stage 2/3) S3 clients are needed. - Returns: - bool: True if AWS client is needed for any operation. + Sources own their own stage-1 resource (Athena client / DuckDB connection / + fsspec), so this only concerns WARC reads and output writes. With no_fetch + there are no reads/writes at all. """ - return ( - self.target_source == 'athena' # stage 1 - or (self.cdx_paths is not None and len(self.cdx_paths) > 0 and is_s3_url(self.cdx_paths[0])) # stage 1 - or is_s3_url(self.warc_download_prefix) # stage 3 - or is_s3_url(self.prefix_path) # stage 3 - ) + if self.no_fetch: + return False + return is_s3_url(self.warc_download_prefix) or is_s3_url(self.prefix_path) def get_boto3_base_config(self) -> Dict: """Get boto3 base configuration for AWS client. @@ -184,10 +177,10 @@ def get_boto3_base_config(self) -> Dict: ) async def get_aws_clients(self) -> Optional[Dict]: - """Return S3/Athena clients for job/read/write if needed. + """Return async S3 clients for WARC reads/writes if needed. 
- Returns: - Optional[aioboto3.Session.client]: S3/Athena client context manager if S3/Athena is needed, None otherwise. + Stage-1 clients/connections are owned by the source, so this only builds the + read/write S3 clients used to fetch WARC ranges and write output. Raises: SystemExit: If S3 is needed but Python version is < 3.9. @@ -198,23 +191,9 @@ async def get_aws_clients(self) -> Optional[Dict]: sys.exit(1) import aioboto3 - import boto3 session = aioboto3.Session() - # Lightweight config for CDX index reads - job_config = Config( - max_pool_connections=5, - read_timeout=60, - **self.get_boto3_base_config(), - ) - - if self.target_source == 'athena': - # Athena does not need an async client - job_client = boto3.client('athena', config=job_config) - else: - job_client = session.client('s3', config=job_config) - # High-throughput config for range reads read_config = Config( max_pool_connections=self.num_readers * 3, @@ -232,7 +211,6 @@ async def get_aws_clients(self) -> Optional[Dict]: ) return { - 'job': job_client, 'read': session.client('s3', config=read_config), 'write': session.client('s3', config=write_config), } @@ -245,86 +223,99 @@ async def filter_async(self) -> int: Returns: int: Number of records written. """ + # Materialize-only: just drain the source into the range-jobs CSV. 
+ if self.no_fetch: + return await self._run_materialize_only() + range_jobs_queue: asyncio.Queue = asyncio.Queue(maxsize=self.range_jobs_queue_size) warc_records_queue: asyncio.Queue = asyncio.Queue(maxsize=self.warc_records_queue_size) if self.needs_aws(): clients = await self.get_aws_clients() - - # Handle mixed async/sync clients - Athena client is sync, S3 clients are async - if self.target_source == 'athena': - job_aws_client = clients['job'] # Sync client, no context manager needed - async with clients['read'] as read_aws_client, clients['write'] as write_aws_client: - return await self._run_filter_pipeline( - range_jobs_queue=range_jobs_queue, - warc_records_queue=warc_records_queue, - job_aws_client=job_aws_client, - read_s3_client=read_aws_client, - write_s3_client=write_aws_client, - ) - else: - async with clients['job'] as job_aws_client, clients['read'] as read_aws_client, clients[ - 'write' - ] as write_aws_client: - return await self._run_filter_pipeline( - range_jobs_queue=range_jobs_queue, - warc_records_queue=warc_records_queue, - job_aws_client=job_aws_client, - read_s3_client=read_aws_client, - write_s3_client=write_aws_client, - ) + async with clients['read'] as read_aws_client, clients['write'] as write_aws_client: + return await self._run_filter_pipeline( + range_jobs_queue=range_jobs_queue, + warc_records_queue=warc_records_queue, + read_s3_client=read_aws_client, + write_s3_client=write_aws_client, + ) else: return await self._run_filter_pipeline( range_jobs_queue=range_jobs_queue, warc_records_queue=warc_records_queue, ) + def _make_csv_writer(self) -> Optional[RangeJobCsvWriter]: + if self.range_jobs_output is None: + return None + return RangeJobCsvWriter(self.range_jobs_output, self_contained=self.csv_self_contained) + + async def _produce_range_jobs(self, range_jobs_queue: Optional[asyncio.Queue], csv_writer) -> int: + """Drive the (sync) source in a worker thread, feeding the async queue. 
+ + Owns counting, the record limit, and (when a queue is present) emitting one + _STOP sentinel per reader in a finally -- so readers never hang even if the + source raises mid-iteration.""" + loop = asyncio.get_running_loop() + count = 0 + + def drain() -> int: + nonlocal count + for job in self.source.iter_range_jobs(): + if csv_writer is not None: + csv_writer.write(job) + if range_jobs_queue is not None: + asyncio.run_coroutine_threadsafe(range_jobs_queue.put(job), loop).result() + count += 1 + if self.record_limit and count >= self.record_limit: + logger.warning('Limit reached at %i', count) + break + return count + + try: + await asyncio.to_thread(drain) + finally: + if csv_writer is not None: + csv_writer.close() + if range_jobs_queue is not None: + for _ in range(self.num_readers): + await range_jobs_queue.put(_STOP) + + logger.info('Generated %d range jobs', count) + return count + + async def _run_materialize_only(self) -> int: + """--no-fetch: generate range jobs and write only the range-jobs CSV.""" + csv_writer = self._make_csv_writer() + if csv_writer is None: + logger.warning('--no-fetch set without --range-jobs-output: nothing to do') + count = await self._produce_range_jobs(range_jobs_queue=None, csv_writer=csv_writer) + logger.info('Materialized %d range jobs (no WARC fetch)', count) + return count + async def _run_filter_pipeline( self, range_jobs_queue: asyncio.Queue, warc_records_queue: asyncio.Queue, - job_aws_client=None, read_s3_client=None, write_s3_client=None, ) -> int: """Run the actual filter pipeline with or without S3 client. Args: - range_jobs_queue: Queue for range jobs from CDX index. + range_jobs_queue: Queue for range jobs from the source. warc_records_queue: Queue for WARC record payloads. - job_aws_client: Optional AWS (S3/Athena) client for jobs generation. read_s3_client: Optional S3 client for reads from S3. write_s3_client: Optional S3 client for writes S3. Returns: int: Number of records written. 
""" - # Fetch file paths and ranges (offset, length) from index files logger.info('Starting job generator, %d WARC readers, %d WARC writers', self.num_readers, self.num_writers) - # Generate range jobs from different target sources - if self.target_source == 'cdx': - job_generators = asyncio.create_task( - self.generate_range_jobs_from_cdx( - range_jobs_queue, - s3_client=job_aws_client, - ) - ) - elif self.target_source == 'athena': - job_generators = asyncio.create_task( - get_range_jobs_from_athena( - client=job_aws_client, - query=self.athena_query, - database=self.athena_database, - s3_output_location=self.athena_s3_output_location, - job_queue=range_jobs_queue, - queue_stop_object=_STOP, - warc_download_prefix=self.warc_download_prefix, - num_fetchers=self.num_readers, - ) - ) - else: - raise ValueError(f'Invalid target source: {self.target_source}') + # Generate range jobs from the configured source (bridged sync->async in a thread). + csv_writer = self._make_csv_writer() + job_generators = asyncio.create_task(self._produce_range_jobs(range_jobs_queue, csv_writer)) # Read WARC records based on file paths and ranges warc_readers = [ @@ -414,66 +405,6 @@ async def _coordinate_writer_shutdown(self, warc_readers: List[asyncio.Task], wa for _ in range(self.num_writers): await warc_records_queue.put(_STOP) - async def generate_range_jobs_from_single_cdx( - self, - cdx_path: str, - range_jobs_queue: asyncio.Queue, - count: int = 0, - ) -> int: - """Read a CDX file and generate range jobs based on URLs and offsets.""" - for warc_url, offset, length in iter_cdx_index_from_path( - cdx_path, warc_download_prefix=self.warc_download_prefix - ): - # Convert the CDX record back to a RangeJob - job = RangeJob(url=warc_url, offset=offset, length=length, records_count=1) - await range_jobs_queue.put(job) - count += 1 - - if self.record_limit > 0 and count >= self.record_limit: - logger.warning('Index limit reached at %i', count) - break - - return count - - async def 
generate_range_jobs_from_cdx( - self, - range_jobs_queue: asyncio.Queue, - s3_client=None, - ): - """Read the CDX paths, parse lines -> RangeJob (WARC files and offets) -> key_queue. - - Args: - range_jobs_queue: Queue to put RangeJob objects into. - s3_client: Optional S3 client for reading CDX indexes from S3. - """ - - logger.info('Range index limit: %i', self.record_limit) - count = 0 - - # Iterate over index files - # TODO this could be done in parallel - for index_path in self.cdx_paths: - # Fetch range queries from index - try: - count += await self.generate_range_jobs_from_single_cdx( - cdx_path=index_path, - range_jobs_queue=range_jobs_queue, - count=count, - ) - - except Exception as e: - logger.error('Failed to read CDX index from %s: %s', index_path, e) - - if self.record_limit > 0 and count >= self.record_limit: - logger.warning('Limit reached at %i', count) - break - - # signal fetchers to stop - for _ in range(self.num_readers): - await range_jobs_queue.put(_STOP) - - logger.info('Enqueued %d jobs from %s', count, index_path) - async def read_warc_records( self, reader_id: int, diff --git a/requirements.txt b/requirements.txt index 6b6efa9..5de9c65 100644 --- a/requirements.txt +++ b/requirements.txt @@ -12,6 +12,9 @@ url-is-in>=0.1.1 fsspec[s3] botocore +# optional DuckDB SQL engine (install via cdx_toolkit[duckdb]) +duckdb + # used by Makefile pytest>=6.2.4 pytest-cov>=2.12.1 diff --git a/setup.py b/setup.py index 15ab048..4398197 100755 --- a/setup.py +++ b/setup.py @@ -12,6 +12,7 @@ test_requirements = ['pytest', 'pytest-cov', 'flake8', 'responses'] optional_s3_requirements = ['fsspec[s3]', 'botocore'] +optional_duckdb_requirements = ['duckdb'] package_requirements = ['twine', 'setuptools', 'setuptools-scm'] @@ -19,10 +20,14 @@ extras_require = { 's3': optional_s3_requirements, + 'duckdb': optional_duckdb_requirements, 'test': test_requirements, # setup no longer tests, so make them an extra 'package': package_requirements, 'dev': 
package_requirements, - 'all': test_requirements + package_requirements + dev_requirements + optional_s3_requirements, + 'all': ( + test_requirements + package_requirements + dev_requirements + + optional_s3_requirements + optional_duckdb_requirements + ), } scripts = ['scripts/cdx_size', 'scripts/cdx_iter'] diff --git a/tests/conftest.py b/tests/conftest.py index 50caca6..ce0ce91 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -30,6 +30,12 @@ except ImportError: # pragma: no cover - exercised in minimal installs _HAS_FSSPEC = False +try: + import duckdb # noqa: F401 + _HAS_DUCKDB = True +except ImportError: # pragma: no cover - exercised in minimal installs + _HAS_DUCKDB = False + import functools from typing import Dict, Optional import requests @@ -162,6 +168,13 @@ def requires_aws_athena(func): ) +def requires_duckdb(func): + """Pytest decorator that skips a test if the optional duckdb dependency is missing.""" + return pytest.mark.skipif( + not _HAS_DUCKDB, reason='duckdb is not installed; install cdx_toolkit[duckdb] to enable DuckDB tests.' 
+ )(func) + + @pytest.fixture def s3_tmpdir(): """S3 equivalent of tmpdir - provides a temporary S3 path and handles cleanup.""" diff --git a/tests/filter_warc/test_athena_command_validation.py b/tests/filter_warc/test_athena_command_validation.py deleted file mode 100644 index b352b99..0000000 --- a/tests/filter_warc/test_athena_command_validation.py +++ /dev/null @@ -1,84 +0,0 @@ -from argparse import Namespace -from unittest.mock import patch - -import pytest - -from cdx_toolkit.filter_warc import command -from cdx_toolkit.filter_warc.command import resolve_athena_query - - -def make_args(**kw): - defaults = dict( - athena_query=None, - athena_query_file=None, - athena_hostnames=None, - athena_database='ccindex', - athena_s3_output='s3://commoncrawl-ci-temp/athena-results/', - crawl=None, - limit=None, - ) - defaults.update(kw) - return Namespace(**defaults) - - -def test_query_and_hostnames_mutually_exclusive(): - args = make_args(athena_query='SELECT 1', athena_hostnames=['example.com']) - with pytest.raises(ValueError): - resolve_athena_query(args) - - -def test_query_and_query_file_mutually_exclusive(tmp_path): - f = tmp_path / 'q.sql' - f.write_text('SELECT 1') - args = make_args(athena_query='SELECT 1', athena_query_file=str(f)) - with pytest.raises(ValueError): - resolve_athena_query(args) - - -def test_neither_hostnames_nor_query(): - args = make_args() - with pytest.raises(ValueError): - resolve_athena_query(args) - - -def test_missing_database(): - args = make_args(athena_hostnames=['example.com'], athena_database=None) - with pytest.raises(ValueError): - resolve_athena_query(args) - - -def test_missing_s3_output(): - args = make_args(athena_hostnames=['example.com'], athena_s3_output=None) - with pytest.raises(ValueError): - resolve_athena_query(args) - - -def test_raw_query_is_unbounded(): - args = make_args(athena_query='SELECT warc_filename FROM x') - sql, n_crawls = resolve_athena_query(args) - assert sql == 'SELECT warc_filename FROM x' - assert 
n_crawls is None - - -def test_query_file_is_read(tmp_path): - f = tmp_path / 'q.sql' - f.write_text('SELECT warc_filename, warc_record_offset, warc_record_length FROM x') - args = make_args(athena_query_file=str(f)) - sql, n_crawls = resolve_athena_query(args) - assert 'warc_filename' in sql - assert n_crawls is None - - -def test_built_no_crawl_is_unbounded(): - args = make_args(athena_hostnames=['example.com']) - sql, n_crawls = resolve_athena_query(args) - assert 'example.com' in sql - assert n_crawls is None - - -def test_built_with_crawls_counts(): - args = make_args(athena_hostnames=['example.com'], crawl='CC-MAIN-2025-33,CC-MAIN-2025-30') - with patch.object(command, '_resolve_crawl_names', return_value=['CC-MAIN-2025-33', 'CC-MAIN-2025-30']): - sql, n_crawls = resolve_athena_query(args) - assert n_crawls == 2 - assert 'crawl IN' in sql diff --git a/tests/filter_warc/test_athena_job_generator.py b/tests/filter_warc/test_athena_job_generator.py deleted file mode 100644 index 05ec0b1..0000000 --- a/tests/filter_warc/test_athena_job_generator.py +++ /dev/null @@ -1,74 +0,0 @@ -import asyncio -from cdx_toolkit.filter_warc.warc_filter import _STOP -from cdx_toolkit.filter_warc.athena_job_generator import ( - get_databases, - get_range_jobs_from_athena, - build_athena_query, -) -from tests.conftest import TEST_ATHENA_DATABASE, TEST_ATHENA_S3_LOCATION, requires_aws_athena - -import boto3 - - -@requires_aws_athena -def test_get_databases(): - from botocore.config import Config - import boto3 - - boto_cfg = Config( - region_name='us-east-1', - ) - athena_client = boto3.client('athena', config=boto_cfg) - dbs = get_databases(client=athena_client) - assert 'ccindex' in dbs - - -@requires_aws_athena -def test_get_range_jobs_from_athena(): - async def run_test(): - # Setup test data - warc_download_prefix = 's3://commoncrawl' - - # Create asyncio queues - key_queue = asyncio.Queue() - - # Setup S3 client - from botocore.config import Config - - boto_cfg = Config( - 
region_name='us-east-1', - retries={'max_attempts': 3, 'mode': 'standard'}, - connect_timeout=10, - read_timeout=120, - ) - - athena_client = boto3.client('athena', config=boto_cfg) - - # Build the query and generate range jobs from Athena - query = build_athena_query( - ['oceancolor.sci.gsfc.nasa.gov'], - limit=10, # Use 10 records to ensure we have enough data - ) - await get_range_jobs_from_athena( - client=athena_client, - query=query, - database=TEST_ATHENA_DATABASE, - s3_output_location=TEST_ATHENA_S3_LOCATION, - job_queue=key_queue, - warc_download_prefix=warc_download_prefix, - num_fetchers=1, - queue_stop_object=_STOP, - ) - - # Collect all range jobs - range_jobs = [] - while not key_queue.empty(): - job = await key_queue.get() - if job is not _STOP: - range_jobs.append(job) - key_queue.task_done() - - assert len(range_jobs) == 10, "Invalid range jobs count" - - # Run the async test - asyncio.run(run_test()) diff --git a/tests/filter_warc/test_athena_prompt.py b/tests/filter_warc/test_athena_prompt.py index 6d36958..0b03b54 100644 --- a/tests/filter_warc/test_athena_prompt.py +++ b/tests/filter_warc/test_athena_prompt.py @@ -2,19 +2,30 @@ import pytest -from cdx_toolkit.filter_warc.command import confirm_athena_cost +from cdx_toolkit.filter_warc.command import confirm_cost +from cdx_toolkit.filter_warc.sources.base import CostEstimate + + +def est(n_crawls): + return CostEstimate(n_crawls=n_crawls, engine='athena') + + +def test_none_estimate_never_prompts(): + # cdx/csv sources return None -> no cost prompt + with patch('builtins.input') as inp: + confirm_cost(None, confirmed=False) + inp.assert_not_called() def test_small_crawl_set_no_prompt(): - # <= LARGE_CRAWL_SET_THRESHOLD -> considered cheap, no prompt with patch('builtins.input') as inp: - confirm_athena_cost(n_crawls=5, confirmed=False) + confirm_cost(est(5), confirmed=False) inp.assert_not_called() def test_confirmed_flag_bypasses_prompt(): with patch('builtins.input') as inp: - 
confirm_athena_cost(n_crawls=None, confirmed=True) + confirm_cost(est(None), confirmed=True) inp.assert_not_called() @@ -22,31 +33,31 @@ def test_large_crawl_set_non_tty_aborts(): with patch('cdx_toolkit.filter_warc.command.sys.stdin') as stdin: stdin.isatty.return_value = False with pytest.raises(SystemExit): - confirm_athena_cost(n_crawls=11, confirmed=False) + confirm_cost(est(11), confirmed=False) def test_unknown_crawls_non_tty_aborts(): with patch('cdx_toolkit.filter_warc.command.sys.stdin') as stdin: stdin.isatty.return_value = False with pytest.raises(SystemExit): - confirm_athena_cost(n_crawls=None, confirmed=False) + confirm_cost(est(None), confirmed=False) def test_tty_yes_proceeds(): with patch('cdx_toolkit.filter_warc.command.sys.stdin') as stdin, patch('builtins.input', return_value='y'): stdin.isatty.return_value = True - confirm_athena_cost(n_crawls=None, confirmed=False) # should not raise + confirm_cost(est(None), confirmed=False) # should not raise def test_tty_empty_answer_aborts(): with patch('cdx_toolkit.filter_warc.command.sys.stdin') as stdin, patch('builtins.input', return_value=''): stdin.isatty.return_value = True with pytest.raises(SystemExit): - confirm_athena_cost(n_crawls=None, confirmed=False) + confirm_cost(est(None), confirmed=False) def test_tty_no_answer_aborts(): with patch('cdx_toolkit.filter_warc.command.sys.stdin') as stdin, patch('builtins.input', return_value='n'): stdin.isatty.return_value = True with pytest.raises(SystemExit): - confirm_athena_cost(n_crawls=11, confirmed=False) + confirm_cost(est(11), confirmed=False) diff --git a/tests/filter_warc/test_athena_query_builder.py b/tests/filter_warc/test_athena_query_builder.py index 8e2cf97..304809e 100644 --- a/tests/filter_warc/test_athena_query_builder.py +++ b/tests/filter_warc/test_athena_query_builder.py @@ -1,12 +1,12 @@ import pytest -from cdx_toolkit.filter_warc.athena_job_generator import ( +from cdx_toolkit.filter_warc.sources.sql_base import ( 
build_athena_query, escape_sql_literal, validate_result_columns, join_warc_url, - run_athena_query, ) +from cdx_toolkit.filter_warc.sources.athena import run_athena_query class _FakeAthenaClient: @@ -43,6 +43,26 @@ def test_build_query_hostnames(): assert 'LIMIT' not in q +def test_build_query_domains_only(): + q = build_athena_query(url_host_registered_domains=['example.com']) + assert "url_host_registered_domain = 'example.com'" in q + assert "url_host_tld = 'com'" in q + assert 'url_host_name' not in q + + +def test_build_query_hostnames_and_domains(): + q = build_athena_query(['www.example.com'], url_host_registered_domains=['example.org']) + assert "url_host_name = 'www.example.com'" in q + assert "url_host_registered_domain = 'example.org'" in q + assert "url_host_tld = 'com'" in q + assert "url_host_tld = 'org'" in q + + +def test_build_query_requires_host_or_domain(): + with pytest.raises(ValueError): + build_athena_query() + + def test_build_query_with_crawls(): q = build_athena_query(['example.com'], crawls=['CC-MAIN-2025-33', 'CC-MAIN-2025-30']) assert "crawl IN ('CC-MAIN-2025-33', 'CC-MAIN-2025-30')" in q diff --git a/tests/filter_warc/test_cdx_utils.py b/tests/filter_warc/test_cdx_utils.py index 5ca7035..c6b2f8f 100644 --- a/tests/filter_warc/test_cdx_utils.py +++ b/tests/filter_warc/test_cdx_utils.py @@ -67,9 +67,9 @@ def mock_read_cdx_line(line, warc_download_prefix): # Should have 3 valid results despite 2 invalid lines being skipped assert len(results) == 3 - # Verify the valid results - assert results[0] == ('http://warc-prefix/test.warc.gz', 100, 500) - assert results[1] == ('http://warc-prefix/test2.warc.gz', 600, 300) - assert results[2] == ('http://warc-prefix/test3.warc.gz', 900, 200) + # Verify the valid results (url, offset, length, filename) + assert results[0] == ('http://warc-prefix/test.warc.gz', 100, 500, 'test.warc.gz') + assert results[1] == ('http://warc-prefix/test2.warc.gz', 600, 300, 'test2.warc.gz') + assert results[2] == 
('http://warc-prefix/test3.warc.gz', 900, 200, 'test3.warc.gz') finally: os.unlink(tmp_file_path) diff --git a/tests/filter_warc/test_command.py b/tests/filter_warc/test_command.py index 49ac875..4da1233 100644 --- a/tests/filter_warc/test_command.py +++ b/tests/filter_warc/test_command.py @@ -34,7 +34,7 @@ def assert_cli_warc_by_cdx( args=[ '-v', '--limit=10', - 'warc_by_cdx', + 'repackage', f'--cdx-path={str(index_path)}', '--write-paths-as-metadata-records', str(metadata_record_path), @@ -119,6 +119,121 @@ def test_cli_warc_by_cdx_over_http_in_parallel(tmpdir, caplog): ) +def _produce_range_jobs_csv(tmpdir, csv_name, self_contained=False): + """Run `repackage` over the CDX fixture with --no-fetch to materialize a range-jobs CSV.""" + import csv as _csv + + index_path = fixture_path / 'filtered_CC-MAIN-2024-30_cdx-00187.gz' + csv_path = os.path.join(str(tmpdir), csv_name) + + args = [ + '--limit=10', + 'repackage', + '--target-source=cdx', + f'--cdx-path={str(index_path)}', + f'--range-jobs-output={csv_path}', + '--no-fetch', + ] + if self_contained: + args.append('--csv-self-contained') + main(args=args) + + with open(csv_path, newline='') as f: + rows = list(_csv.DictReader(f)) + return csv_path, rows + + +def _assert_repackaged_warc(warc_path, metadata_record_path): + """Inspect a repackaged WARC and assert the expected fixture content.""" + response_records = [] + response_contents = [] + metadata_record = None + metadata_record_headers = None + + with fsspec.open(warc_path, 'rb') as stream: + for record in ArchiveIterator(stream): + if record.rec_type == 'response': + response_records.append(record) + response_contents.append(record.content_stream().read().decode('utf-8', errors='ignore')) + if record.rec_type == 'metadata': + metadata_record = record + metadata_record_headers = record.rec_headers + + assert len(response_records) == 10, 'Invalid record count' + assert 'Catalogue en ligne Mission de France' in response_contents[0], 'Invalid response content' 
+ assert 'dojo/dijit/themes/tundra/tundra' in response_contents[9], 'Invalid response content' + assert metadata_record is not None, 'Metadata record not set' + assert metadata_record_headers.get('WARC-Payload-Digest') == 'sha1:VXA2A5YUS3TAY36AUO6MACRMNOH5RXG2', ( + 'Invalid metadata block digest' + ) + + +def test_repackage_csv_materialize_filename(tmpdir): + """--no-fetch produces a filename-based range-jobs CSV without fetching WARCs.""" + csv_path, rows = _produce_range_jobs_csv(tmpdir, 'ranges.csv') + assert set(rows[0].keys()) == {'filename', 'offset', 'length'} + assert len(rows) == 10 + # No WARC was written for the default --prefix + assert not any(name.endswith('.warc.gz') for name in os.listdir(str(tmpdir))) + + +def test_repackage_csv_materialize_self_contained(tmpdir): + """--csv-self-contained produces a url-based range-jobs CSV.""" + csv_path, rows = _produce_range_jobs_csv(tmpdir, 'ranges_url.csv', self_contained=True) + assert set(rows[0].keys()) == {'url', 'offset', 'length'} + assert len(rows) == 10 + assert rows[0]['url'].startswith('https://data.commoncrawl.org/') + + +def test_cli_repackage_csv_roundtrip(tmpdir): + """End-to-end: produce a filename-based ranges CSV, then consume it and fetch over HTTP.""" + metadata_record_path = TEST_DATA_PATH / 'filter_cdx/whitelist_10_urls.txt' + csv_path, rows = _produce_range_jobs_csv(tmpdir, 'ranges.csv') + + base_prefix = str(tmpdir) + main( + args=[ + '-v', + 'repackage', + '--target-source=csv', + f'--csv-path={csv_path}', + '--write-paths-as-metadata-records', + str(metadata_record_path), + f'--prefix={base_prefix}/TEST_warc_by_index', + '--creator=foo', + '--operator=bob', + '--warc-download-prefix=https://data.commoncrawl.org', + ] + ) + + warc_path = os.path.join(base_prefix, 'TEST_warc_by_index-000000-001.warc.gz') + _assert_repackaged_warc(warc_path, metadata_record_path) + + +def test_cli_repackage_csv_roundtrip_self_contained(tmpdir): + """End-to-end with self-contained URLs: header 
auto-detected on read; no prefix needed.""" + metadata_record_path = TEST_DATA_PATH / 'filter_cdx/whitelist_10_urls.txt' + csv_path, rows = _produce_range_jobs_csv(tmpdir, 'ranges_url.csv', self_contained=True) + + base_prefix = str(tmpdir) + main( + args=[ + '-v', + 'repackage', + '--target-source=csv', + f'--csv-path={csv_path}', + '--write-paths-as-metadata-records', + str(metadata_record_path), + f'--prefix={base_prefix}/TEST_warc_by_index', + '--creator=foo', + '--operator=bob', + ] + ) + + warc_path = os.path.join(base_prefix, 'TEST_warc_by_index-000000-001.warc.gz') + _assert_repackaged_warc(warc_path, metadata_record_path) + + @requires_aws_s3 def test_cli_warc_by_cdx_over_s3(tmpdir, caplog): assert_cli_warc_by_cdx('s3://commoncrawl', base_prefix=tmpdir, caplog=caplog) @@ -182,7 +297,7 @@ def test_warc_by_cdx_no_index_files_found_exits(tmpdir, caplog): main( args=[ '-v', - 'warc_by_cdx', + 'repackage', f'--cdx-path={str(tmpdir)}', f'--prefix={str(tmpdir)}/TEST', '--cdx-glob=/nonexistent-pattern-*.gz', @@ -201,7 +316,7 @@ def test_warc_by_cdx_subprefix_and_metadata(tmpdir): args=[ '-v', '--limit=1', - 'warc_by_cdx', + 'repackage', f'--cdx-path={str(index_path)}', f'--prefix={str(tmpdir)}/TEST', '--subprefix=SUB', @@ -235,7 +350,7 @@ def test_warc_by_cdx_without_creator_operator(tmpdir): args=[ '-v', '--limit=1', - 'warc_by_cdx', + 'repackage', f'--cdx-path={str(index_path)}', f'--prefix={str(tmpdir)}/TEST_NO_META', ] @@ -276,14 +391,15 @@ def test_cli_warc_by_athena( args=[ '-v', '--limit=10', - 'warc_by_cdx', - '--target-source=athena', + 'repackage', + '--target-source=sql', + '--engine=athena', '--athena-database=ccindex', '--athena-s3-output=s3://commoncrawl-ci-temp/athena-results/', - '--athena-hostnames', + '--hostnames', 'oceancolor.sci.gsfc.nasa.gov', 'example.com', - '--confirm-athena-cost', + '--confirm-cost', f'--prefix={base_prefix}/TEST_warc_by_index', '--creator=foo', '--operator=bob', diff --git a/tests/filter_warc/test_csv_source.py 
b/tests/filter_warc/test_csv_source.py new file mode 100644 index 0000000..d6a6a84 --- /dev/null +++ b/tests/filter_warc/test_csv_source.py @@ -0,0 +1,66 @@ +import csv + +import pytest + +from cdx_toolkit.filter_warc.data_classes import RangeJob +from cdx_toolkit.filter_warc.sources.csv import RangeJobCsvWriter, CsvSource + + +def test_writer_filename_mode_and_read(tmp_path): + path = str(tmp_path / 'ranges.csv') + w = RangeJobCsvWriter(path, self_contained=False) + w.write(RangeJob( + url='https://data.commoncrawl.org/crawl-data/x.warc.gz', + offset=10, length=20, filename='crawl-data/x.warc.gz', + )) + w.close() + + with open(path, newline='') as f: + rows = list(csv.DictReader(f)) + assert rows[0] == {'filename': 'crawl-data/x.warc.gz', 'offset': '10', 'length': '20'} + + jobs = list(CsvSource(path, 'https://data.commoncrawl.org').iter_range_jobs()) + assert jobs[0].url == 'https://data.commoncrawl.org/crawl-data/x.warc.gz' + assert jobs[0].offset == 10 and jobs[0].length == 20 + assert jobs[0].filename == 'crawl-data/x.warc.gz' + + +def test_writer_self_contained_mode_and_read(tmp_path): + path = str(tmp_path / 'ranges_url.csv') + w = RangeJobCsvWriter(path, self_contained=True) + w.write(RangeJob(url='s3://commoncrawl/crawl-data/x.warc.gz', offset=5, length=7, filename='crawl-data/x.warc.gz')) + w.close() + + with open(path, newline='') as f: + rows = list(csv.DictReader(f)) + assert rows[0] == {'url': 's3://commoncrawl/crawl-data/x.warc.gz', 'offset': '5', 'length': '7'} + + # url used as-is regardless of the (ignored) prefix + jobs = list(CsvSource(path, 'https://ignored').iter_range_jobs()) + assert jobs[0].url == 's3://commoncrawl/crawl-data/x.warc.gz' + assert jobs[0].filename is None + + +def test_writer_filename_mode_requires_filename(tmp_path): + w = RangeJobCsvWriter(str(tmp_path / 'r.csv'), self_contained=False) + with pytest.raises(ValueError): + w.write(RangeJob(url='https://x/y.warc.gz', offset=1, length=2, filename=None)) + w.close() + + +def 
test_csv_source_missing_columns(tmp_path): + path = str(tmp_path / 'bad.csv') + with open(path, 'w') as f: + f.write('foo,bar\n1,2\n') + with pytest.raises(ValueError): + list(CsvSource(path, 'https://x').iter_range_jobs()) + + +def test_csv_source_tsv(tmp_path): + path = str(tmp_path / 'ranges.tsv') + with open(path, 'w') as f: + f.write('filename\toffset\tlength\n') + f.write('crawl-data/x.warc.gz\t10\t20\n') + jobs = list(CsvSource(path, 'https://data.commoncrawl.org').iter_range_jobs()) + assert jobs[0].url == 'https://data.commoncrawl.org/crawl-data/x.warc.gz' + assert jobs[0].offset == 10 and jobs[0].length == 20 diff --git a/tests/filter_warc/test_grouped_range_jobs.py b/tests/filter_warc/test_grouped_range_jobs.py index b046709..91ccbaf 100644 --- a/tests/filter_warc/test_grouped_range_jobs.py +++ b/tests/filter_warc/test_grouped_range_jobs.py @@ -5,7 +5,7 @@ def test_iter_cdx_index_from_test_data(): cdx_path = TEST_DATA_PATH / 'warc_by_cdx/filtered_CC-MAIN-2024-30_cdx-00187.gz' results = list(iter_cdx_index_from_path(str(cdx_path), 'http://warc-prefix')) - # [(url, offset, length)] + # [(url, offset, length, filename)] # sort results by offsets results.sort(key=lambda x: x[1]) @@ -20,8 +20,8 @@ def group_neighbor_chunks(items): current_chunk = [items[0]] for i in range(1, len(items)): - prev_url, prev_offset, prev_length = items[i - 1] - curr_url, curr_offset, curr_length = items[i] + prev_url, prev_offset, prev_length = items[i - 1][:3] + curr_url, curr_offset, curr_length = items[i][:3] # Check if current item is a neighbor (same URL and contiguous) if curr_url == prev_url and curr_offset == prev_offset + prev_length + 4: diff --git a/tests/filter_warc/test_make_source.py b/tests/filter_warc/test_make_source.py new file mode 100644 index 0000000..ee60199 --- /dev/null +++ b/tests/filter_warc/test_make_source.py @@ -0,0 +1,161 @@ +from argparse import Namespace +from unittest.mock import patch + +import pytest + +from cdx_toolkit.filter_warc.sources 
import make_source +from cdx_toolkit.filter_warc.sources.athena import AthenaSource +from cdx_toolkit.filter_warc.sources.cdx import CdxSource +from cdx_toolkit.filter_warc.sources.csv import CsvSource +from cdx_toolkit.filter_warc.sources.duckdb import DuckDbSource + + +def make_args(**kw): + defaults = dict( + target_source='cdx', + engine=None, + hostnames=None, + domains=None, + query=None, + query_file=None, + athena_database=None, + athena_s3_output='s3://commoncrawl-ci-temp/athena-results/', + duckdb_index_path='s3://commoncrawl/cc-index/table/cc-main/warc/', + csv_path=None, + cdx_path=None, + cdx_glob=None, + crawl=None, + ) + defaults.update(kw) + return Namespace(**defaults) + + +def build(**kw): + return make_source(make_args(**kw), warc_download_prefix='https://data.commoncrawl.org', record_limit=0) + + +# --- cdx / csv --- + +def test_cdx_source(): + src = build(target_source='cdx', cdx_path='/tmp/index.cdx.gz') + assert isinstance(src, CdxSource) + assert src.estimate_cost() is None + + +def test_csv_requires_path(): + with pytest.raises(ValueError): + build(target_source='csv') + + +def test_csv_source(): + src = build(target_source='csv', csv_path='/tmp/ranges.csv') + assert isinstance(src, CsvSource) + assert src.estimate_cost() is None + + +# --- sql: engine + query validation --- + +def test_sql_requires_engine(): + with pytest.raises(ValueError): + build(target_source='sql', hostnames=['example.com']) + + +def test_sql_hostnames_and_query_mutually_exclusive(): + with pytest.raises(ValueError): + build(target_source='sql', engine='athena', hostnames=['example.com'], query='SELECT 1') + + +def test_sql_query_and_query_file_mutually_exclusive(tmp_path): + f = tmp_path / 'q.sql' + f.write_text('SELECT 1') + with pytest.raises(ValueError): + build(target_source='sql', engine='athena', query='SELECT 1', query_file=str(f)) + + +def test_sql_neither_hostnames_domains_nor_query(): + with pytest.raises(ValueError): + build(target_source='sql', 
engine='athena') + + +def test_sql_domains_and_query_mutually_exclusive(): + with pytest.raises(ValueError): + build(target_source='sql', engine='athena', domains=['example.com'], query='SELECT 1') + + +def test_athena_domains_only(): + src = build(target_source='sql', engine='athena', domains=['example.com']) + assert isinstance(src, AthenaSource) + assert 'url_host_registered_domain = \'example.com\'' in src.query + assert 'url_host_name' not in src.query + + +def test_athena_hostnames_and_domains_combined(): + src = build(target_source='sql', engine='athena', hostnames=['www.example.com'], domains=['example.org']) + assert "url_host_name = 'www.example.com'" in src.query + assert "url_host_registered_domain = 'example.org'" in src.query + # TLDs from both hostnames and domains + assert "url_host_tld = 'com'" in src.query + assert "url_host_tld = 'org'" in src.query + + +def test_duckdb_domains_only(): + src = build(target_source='sql', engine='duckdb', domains=['commoncrawl.org']) + q = src._build_query() + assert "url_host_registered_domain = 'commoncrawl.org'" in q + assert 'read_parquet' in q + + +# --- athena --- + +def test_athena_requires_s3_output(): + with pytest.raises(ValueError): + build(target_source='sql', engine='athena', hostnames=['example.com'], athena_s3_output=None) + + +def test_athena_built_no_crawl_unbounded(): + src = build(target_source='sql', engine='athena', hostnames=['example.com']) + assert isinstance(src, AthenaSource) + est = src.estimate_cost() + assert est.engine == 'athena' and est.n_crawls is None + assert 'ccindex' in src.query and 'example.com' in src.query + + +def test_athena_raw_query_unbounded(): + src = build(target_source='sql', engine='athena', query='SELECT warc_filename FROM x') + assert src.query == 'SELECT warc_filename FROM x' + assert src.estimate_cost().n_crawls is None + + +def test_athena_with_crawls_counts(): + with patch( + 'cdx_toolkit.filter_warc.sources.factory.resolve_crawl_names', + 
return_value=['CC-MAIN-2025-33', 'CC-MAIN-2025-30'], + ): + src = build( + target_source='sql', engine='athena', hostnames=['example.com'], + crawl='CC-MAIN-2025-33,CC-MAIN-2025-30', + ) + assert src.estimate_cost().n_crawls == 2 + assert 'crawl IN' in src.query + + +# --- duckdb --- + +def test_duckdb_built_query_has_read_parquet_and_partition(): + with patch( + 'cdx_toolkit.filter_warc.sources.factory.resolve_crawl_names', + return_value=['CC-MAIN-2026-17'], + ): + src = build(target_source='sql', engine='duckdb', hostnames=['commoncrawl.org'], crawl='CC-MAIN-2026-17') + assert isinstance(src, DuckDbSource) + assert src.estimate_cost().n_crawls == 1 + query = src._build_query() + assert 'read_parquet' in query + assert 'crawl=CC-MAIN-2026-17' in query + assert 'commoncrawl.org' in query + + +def test_duckdb_no_crawl_unbounded(): + src = build(target_source='sql', engine='duckdb', hostnames=['commoncrawl.org']) + assert src.estimate_cost().n_crawls is None + assert 'crawl=*' in src._build_query() diff --git a/tests/filter_warc/test_producer.py b/tests/filter_warc/test_producer.py new file mode 100644 index 0000000..ba6efcd --- /dev/null +++ b/tests/filter_warc/test_producer.py @@ -0,0 +1,68 @@ +import asyncio +import csv + +from cdx_toolkit.filter_warc.data_classes import RangeJob +from cdx_toolkit.filter_warc.sources.base import RangeJobSource +from cdx_toolkit.filter_warc.warc_filter import WARCFilter, _STOP + + +class FakeSource(RangeJobSource): + def __init__(self, jobs, raise_after=None): + self.jobs = jobs + self.raise_after = raise_after + + def iter_range_jobs(self): + for i, job in enumerate(self.jobs): + if self.raise_after is not None and i == self.raise_after: + raise RuntimeError('boom') + yield job + + +def _jobs(n): + return [ + RangeJob(url=f'https://data.commoncrawl.org/{i}.warc.gz', offset=i, length=1, filename=f'{i}.warc.gz') + for i in range(n) + ] + + +def test_no_fetch_materializes_csv(tmp_path): + out = str(tmp_path / 'ranges.csv') + wf = 
WARCFilter( + source=FakeSource(_jobs(3)), + prefix_path=str(tmp_path / 'out'), + writer_info={'isPartOf': 'test'}, + range_jobs_output=out, + no_fetch=True, + ) + n = wf.filter() + assert n == 3 + with open(out, newline='') as f: + rows = list(csv.DictReader(f)) + assert len(rows) == 3 + assert set(rows[0].keys()) == {'filename', 'offset', 'length'} + + +def test_producer_emits_stops_even_when_source_raises(tmp_path): + """Regression: a source that raises mid-iteration must still release the readers.""" + wf = WARCFilter( + source=FakeSource(_jobs(2), raise_after=1), + prefix_path=str(tmp_path), + writer_info={}, + n_parallel=3, + ) + + async def run(): + queue: asyncio.Queue = asyncio.Queue() + try: + await wf._produce_range_jobs(queue, None) + except RuntimeError: + pass + items = [] + while not queue.empty(): + items.append(await queue.get()) + return items + + items = asyncio.run(run()) + stops = sum(1 for it in items if it is _STOP) + assert wf.num_readers == 3 + assert stops == 3 diff --git a/tests/filter_warc/test_sql_sources_gated.py b/tests/filter_warc/test_sql_sources_gated.py new file mode 100644 index 0000000..61fcde5 --- /dev/null +++ b/tests/filter_warc/test_sql_sources_gated.py @@ -0,0 +1,105 @@ +"""Gated end-to-end tests for the SQL sources (Athena, DuckDB). + +These query only a single crawl partition for a single host (cheap, partition-pruned +-- never an all-crawls scan), materialize a range-jobs CSV with --no-fetch, then +consume it and fetch the WARCs over HTTP. They require AWS credentials (and duckdb) +and are skipped in CI. 
+""" +import csv +import os + +import fsspec +from warcio.archiveiterator import ArchiveIterator + +from cdx_toolkit.cli import main +from tests.conftest import requires_aws_athena, requires_aws_s3, requires_duckdb, TEST_ATHENA_S3_LOCATION + + +CRAWL = 'CC-MAIN-2026-17' +HOST = 'commoncrawl.org' + + +def _produce_and_consume(tmpdir, produce_args): + csv_path = os.path.join(str(tmpdir), 'ranges.csv') + + # Produce: run the (cheap, single-crawl) SQL query, write range jobs, no WARC fetch. + main(args=produce_args + [f'--range-jobs-output={csv_path}', '--no-fetch']) + + with open(csv_path, newline='') as f: + rows = list(csv.DictReader(f)) + assert len(rows) > 0, 'expected at least one successful warc fetch for the host/crawl' + + # Consume: read the CSV, fetch the WARC records over HTTP, write a new WARC. + base_prefix = str(tmpdir) + main( + args=[ + 'repackage', + '--target-source=csv', + f'--csv-path={csv_path}', + f'--prefix={base_prefix}/TEST_sql', + '--warc-download-prefix=https://data.commoncrawl.org', + ] + ) + + warc_path = os.path.join(base_prefix, 'TEST_sql-000000-001.warc.gz') + response_count = 0 + with fsspec.open(warc_path, 'rb') as stream: + for record in ArchiveIterator(stream): + if record.rec_type == 'response': + response_count += 1 + target = record.rec_headers.get_header('WARC-Target-URI') or '' + assert HOST in target, f'unexpected target URI: {target}' + + assert response_count == len(rows), 'every range job should yield a response record' + + +@requires_aws_athena +def test_repackage_sql_athena_e2e(tmpdir): + _produce_and_consume( + tmpdir, + [ + '--crawl', CRAWL, + 'repackage', + '--target-source=sql', + '--engine=athena', + '--athena-database=ccindex', + f'--athena-s3-output={TEST_ATHENA_S3_LOCATION}', + '--hostnames', HOST, + '--confirm-cost', + ], + ) + + +@requires_aws_s3 +@requires_duckdb +def test_repackage_sql_duckdb_e2e(tmpdir): + _produce_and_consume( + tmpdir, + [ + '--crawl', CRAWL, + 'repackage', + '--target-source=sql', + 
'--engine=duckdb', + '--hostnames', HOST, + '--confirm-cost', + ], + ) + + +@requires_aws_s3 +@requires_duckdb +def test_repackage_sql_duckdb_domain_e2e(tmpdir): + # Domain filtering (url_host_registered_domain) also matches subdomains; bound it + # with --limit to keep the live verification cheap. + _produce_and_consume( + tmpdir, + [ + '--crawl', CRAWL, + '--limit', '10', + 'repackage', + '--target-source=sql', + '--engine=duckdb', + '--domains', HOST, + '--confirm-cost', + ], + ) diff --git a/tests/filter_warc/test_warc_filter.py b/tests/filter_warc/test_warc_filter.py index 7d8307b..b157f61 100644 --- a/tests/filter_warc/test_warc_filter.py +++ b/tests/filter_warc/test_warc_filter.py @@ -5,9 +5,13 @@ from tests.conftest import TEST_DATA_PATH from cdx_toolkit.filter_warc.warc_filter import WARCFilter +from cdx_toolkit.filter_warc.sources.cdx import CdxSource fixture_path = TEST_DATA_PATH / 'warc_by_cdx' +# A throwaway source for unit tests that only exercise reader/writer/rotate/log methods. 
+_FAKE_SOURCE = CdxSource(['/fake/path'], 'https://data.commoncrawl.org') + def test_filter_keyboard_interrupt_handling(caplog): """Test that KeyboardInterrupt is properly handled in the filter method.""" @@ -16,7 +20,7 @@ def test_filter_keyboard_interrupt_handling(caplog): # Set log level to capture WARNING messages caplog.set_level(logging.WARNING, logger='cdx_toolkit.filter_warc.warc_filter') - warc_filter = WARCFilter(cdx_paths=['/fake/path'], prefix_path='/fake/prefix', writer_info={'writer_id': 1}) + warc_filter = WARCFilter(source=_FAKE_SOURCE, prefix_path='/fake/prefix', writer_info={'writer_id': 1}) # Mock filter_async to raise KeyboardInterrupt with patch.object(warc_filter, 'filter_async', side_effect=KeyboardInterrupt('Simulated user interrupt')): @@ -35,7 +39,7 @@ def test_rotate_files_no_rotation_needed(): async def run_test(): warc_filter = WARCFilter( - cdx_paths=['/fake/path'], + source=_FAKE_SOURCE, prefix_path='/fake/prefix', writer_info={'writer_id': 1}, max_file_size=1000, # 1KB limit @@ -76,7 +80,7 @@ def test_rotate_files_rotation_needed_without_resource_records(): async def run_test(): warc_filter = WARCFilter( - cdx_paths=['/fake/path'], + source=_FAKE_SOURCE, prefix_path='/fake/prefix', writer_info={'writer_id': 1}, max_file_size=1000, # 1KB limit @@ -134,7 +138,7 @@ def test_rotate_files_rotation_needed_with_metadata_records(): async def run_test(): warc_filter = WARCFilter( - cdx_paths=['/fake/path'], + source=_FAKE_SOURCE, prefix_path='/fake/prefix', writer_info={'writer_id': 1}, max_file_size=1000, # 1KB limit @@ -189,7 +193,7 @@ def test_rotate_files_no_max_file_size_set(): async def run_test(): warc_filter = WARCFilter( - cdx_paths=['/fake/path'], + source=_FAKE_SOURCE, prefix_path='/fake/prefix', writer_info={'writer_id': 1}, max_file_size=None, # No limit @@ -230,7 +234,7 @@ def test_rotate_files_edge_case_exact_limit(): async def run_test(): warc_filter = WARCFilter( - cdx_paths=['/fake/path'], + source=_FAKE_SOURCE, 
prefix_path='/fake/prefix', writer_info={'writer_id': 1}, max_file_size=1000, # 1KB limit @@ -271,7 +275,7 @@ def test_rotate_files_edge_case_just_over_limit(): async def run_test(): warc_filter = WARCFilter( - cdx_paths=['/fake/path'], + source=_FAKE_SOURCE, prefix_path='/fake/prefix', writer_info={'writer_id': 1}, max_file_size=1000, # 1KB limit @@ -317,7 +321,7 @@ def test_rotate_files_kwargs_passed_through(): async def run_test(): warc_filter = WARCFilter( - cdx_paths=['/fake/path'], prefix_path='/fake/prefix', writer_info={'writer_id': 1}, max_file_size=1000 + source=_FAKE_SOURCE, prefix_path='/fake/prefix', writer_info={'writer_id': 1}, max_file_size=1000 ) mock_writer = AsyncMock() @@ -370,7 +374,7 @@ async def run_test(): caplog.set_level(logging.INFO, logger='cdx_toolkit.filter_warc.warc_filter') warc_filter = WARCFilter( - cdx_paths=['/fake/path'], prefix_path='/fake/prefix', writer_info={'writer_id': 1}, max_file_size=1000 + source=_FAKE_SOURCE, prefix_path='/fake/prefix', writer_info={'writer_id': 1}, max_file_size=1000 ) mock_writer = AsyncMock() @@ -402,9 +406,11 @@ async def run_test(): def test_log_writer(caplog): """Test log writer.""" + import logging + caplog.set_level(logging.INFO, logger='cdx_toolkit.filter_warc.warc_filter') warc_filter = WARCFilter( - cdx_paths=['/fake/path'], + source=_FAKE_SOURCE, prefix_path='/fake/prefix', writer_info={'writer_id': 1}, log_every_n=2, @@ -419,9 +425,11 @@ def test_log_writer(caplog): def test_log_reader(caplog): """Test log reader.""" + import logging + caplog.set_level(logging.INFO, logger='cdx_toolkit.filter_warc.warc_filter') warc_filter = WARCFilter( - cdx_paths=['/fake/path'], + source=_FAKE_SOURCE, prefix_path='/fake/prefix', writer_info={'writer_id': 1}, log_every_n=2,