diff --git a/cdx_toolkit/cli.py b/cdx_toolkit/cli.py index c650d5c..f74610d 100644 --- a/cdx_toolkit/cli.py +++ b/cdx_toolkit/cli.py @@ -12,8 +12,8 @@ from cdx_toolkit.filter_cdx.command import run_filter_cdx from cdx_toolkit.filter_cdx.args import add_filter_cdx_args -from cdx_toolkit.filter_warc.command import run_warcer_by_cdx -from cdx_toolkit.filter_warc.args import add_warcer_by_cdx_args +from cdx_toolkit.filter_warc.command import run_repackage +from cdx_toolkit.filter_warc.args import add_repackage_args LOGGER = logging.getLogger(__name__) @@ -124,12 +124,12 @@ def main(args=None): warc.add_argument('url') warc.set_defaults(func=warcer) - warc_by_cdx = subparsers.add_parser( - 'warc_by_cdx', - help='iterate over capture content based on an CDX index file, creating a warc' + repackage = subparsers.add_parser( + 'repackage', + help='repackage WARC ranges from a CDX/SQL/CSV source into a new WARC' ) - add_warcer_by_cdx_args(warc_by_cdx) - warc_by_cdx.set_defaults(func=run_warcer_by_cdx) + add_repackage_args(repackage) + repackage.set_defaults(func=run_repackage) filter_cdx = subparsers.add_parser('filter_cdx', help='Filter CDX files based on SURT prefixes whitelist') add_filter_cdx_args(filter_cdx) diff --git a/cdx_toolkit/filter_warc/args.py b/cdx_toolkit/filter_warc/args.py index 1fa3a2e..71afa90 100644 --- a/cdx_toolkit/filter_warc/args.py +++ b/cdx_toolkit/filter_warc/args.py @@ -5,12 +5,13 @@ logger = logging.getLogger(__name__) -def add_warcer_by_cdx_args(parser: argparse.ArgumentParser): +def add_repackage_args(parser: argparse.ArgumentParser): + # --- CDX source --- parser.add_argument( '--cdx-path', type=str, default=None, - help='Path to CDX index file (local or remote, e.g. S3). Required if target source is set to `cdx`.', + help='Path to CDX index file (local or remote, e.g. S3). 
Used when --target-source cdx.', ) parser.add_argument( '--cdx-glob', @@ -18,46 +19,95 @@ def add_warcer_by_cdx_args(parser: argparse.ArgumentParser): default=None, help='a glob pattern for read from multiple CDX indices', ) + # --- SQL source (--target-source sql --engine athena|duckdb) --- parser.add_argument( - '--athena-hostnames', + '--engine', type=str, - nargs="+", default=None, - help=('Hostnames to filter for via Athena (whitelist). Use this OR --athena-query/' - '--athena-query-file (mutually exclusive) when target source is `athena`.'), + choices=['athena', 'duckdb'], + help='SQL engine for the columnar index. Required when --target-source sql.', ) parser.add_argument( - '--athena-query', + '--hostnames', type=str, + nargs='+', default=None, - help=('Raw Athena SQL to run instead of the hostname-based query (power users). The query ' - 'must SELECT the columns warc_filename, warc_record_offset, warc_record_length. ' - 'Mutually exclusive with --athena-hostnames and --athena-query-file.'), + help=('Exact hostnames (url_host_name, e.g. www.example.com) to filter for via the SQL ' + 'index. Combine with --domains; mutually exclusive with --query/--query-file. ' + 'Combine with the global --crawl to restrict the scan to specific crawls ' + '(strongly recommended for cost).'), ) parser.add_argument( - '--athena-query-file', + '--domains', type=str, + nargs='+', default=None, - help='Path to a file containing the raw Athena SQL (alternative to --athena-query).', + help=('Registered domains (url_host_registered_domain, e.g. example.com) to filter for via ' + 'the SQL index; also matches subdomains. Combine with --hostnames; mutually exclusive ' + 'with --query/--query-file.'), + ) + parser.add_argument( + '--query', + type=str, + default=None, + help=('Raw SQL to run instead of the hostname-based query (power users). Must SELECT the ' + 'columns warc_filename, warc_record_offset, warc_record_length. Engine-specific ' + 'dialect. 
Mutually exclusive with --hostnames and --query-file.'), + ) + parser.add_argument( + '--query-file', + type=str, + default=None, + help='Path to a file containing the raw SQL (alternative to --query).', ) parser.add_argument( '--athena-database', type=str, default=None, - help='Athena database. Required if target source is set to `athena`.', + help='Athena database (engine=athena). Defaults to `ccindex`.', ) parser.add_argument( '--athena-s3-output', type=str, default=None, - help='Athena S3 output location. Required if target source is set to `athena`.', + help='Athena S3 output location (engine=athena). Required for engine=athena.', + ) + parser.add_argument( + '--duckdb-index-path', + type=str, + default='s3://commoncrawl/cc-index/table/cc-main/warc/', + help='Base S3 path to the CC columnar index parquet (engine=duckdb).', ) parser.add_argument( - '--confirm-athena-cost', + '--confirm-cost', action='store_true', - help=('Skip the Athena cost-confirmation prompt and run even unpartitioned / large-scan ' + help=('Skip the cost-confirmation prompt and run even unpartitioned / large-scan SQL ' 'queries. Athena bills per TB scanned; restrict with --crawl to reduce cost.'), ) + # --- CSV source --- + parser.add_argument( + '--csv-path', + type=str, + default=None, + help='Path to a range-jobs CSV/TSV (local or remote). 
Used when --target-source csv.', + ) + # --- Range-jobs materialization (any source) --- + parser.add_argument( + '--range-jobs-output', + type=str, + default=None, + help='If set, write each generated RangeJob to this CSV (filename,offset,length by default).', + ) + parser.add_argument( + '--no-fetch', + action='store_true', + help='Only generate range jobs (write --range-jobs-output); skip fetching/writing WARCs.', + ) + parser.add_argument( + '--csv-self-contained', + action='store_true', + help='Write full URLs (url,offset,length) to --range-jobs-output instead of relative filenames.', + ) parser.add_argument('--prefix', default='TEST', help='prefix for the output warc filename') parser.add_argument( '--subprefix', @@ -126,8 +176,9 @@ def add_warcer_by_cdx_args(parser: argparse.ArgumentParser): '--target-source', action='store', default='cdx', - help=('Source from that the filter targets are loaded (available options: `cdx`, `athena`; ' - 'defaults to `cdx`). For `athena`, use the global --crawl to restrict the scan to ' - 'specific crawls (strongly recommended; Athena bills per TB scanned).'), + choices=['cdx', 'sql', 'csv'], + help=('Where range jobs come from: `cdx` (index files), `sql` (columnar index via ' + '--engine athena|duckdb), or `csv` (a range-jobs CSV). Defaults to `cdx`. For `sql`, ' + 'use the global --crawl to restrict the scan to specific crawls (recommended for cost).'), ) return parser diff --git a/cdx_toolkit/filter_warc/athena_job_generator.py b/cdx_toolkit/filter_warc/athena_job_generator.py deleted file mode 100644 index c4c4323..0000000 --- a/cdx_toolkit/filter_warc/athena_job_generator.py +++ /dev/null @@ -1,251 +0,0 @@ -import asyncio -import logging -import re -import time -from typing import Any, Iterable, List, Optional - -from cdx_toolkit.filter_warc.data_classes import RangeJob - - -logger = logging.getLogger(__name__) - -# Required output columns that any Athena query (built or raw) must provide. 
-REQUIRED_RESULT_COLUMNS = ('warc_filename', 'warc_record_offset', 'warc_record_length') - -# Athena pricing is ~$5 per TB scanned (used for the post-run cost estimate). -ATHENA_USD_PER_TB = 5.0 - -# Hostnames, TLDs and crawl names only ever contain these characters. Validating -# against this set both prevents SQL injection and catches malformed input early. -_SQL_LITERAL_RE = re.compile(r'^[A-Za-z0-9.\-]+$') - - -def escape_sql_literal(value: str) -> str: - """Validate and quote a value for safe inclusion in an Athena SQL string literal. - - Only hostname/TLD/crawl-name characters (letters, digits, dot, hyphen) are - allowed, so the result cannot break out of the quotes or inject SQL.""" - if not isinstance(value, str) or not _SQL_LITERAL_RE.match(value): - raise ValueError( - f'Invalid value for Athena query literal: {value!r} ' - '(allowed characters: letters, digits, dot, hyphen)' - ) - return "'" + value + "'" - - -def build_athena_query( - url_host_names: List[str], - crawls: Optional[List[str]] = None, - limit: int = 0, - table: str = 'ccindex', -) -> str: - """Build the Athena SQL returning warc_filename/offset/length for the hostnames. - - CommonCrawl provides an index via AWS Athena that we can use to find the file - names, offsets, and byte lengths for WARC filtering. See - https://commoncrawl.org/blog/index-to-warc-files-and-urls-in-columnar-format - - If `crawls` is a non-empty list of crawl names (e.g. 
['CC-MAIN-2025-33']), a - `crawl IN (...)` partition filter is added -- this is the main lever for - reducing Athena scan cost.""" - if not url_host_names: - raise ValueError('build_athena_query requires at least one hostname') - - tlds = sorted({url.split('.')[-1] for url in url_host_names}) - query_tlds = ' OR '.join(f'url_host_tld = {escape_sql_literal(tld)}' for tld in tlds) - query_hostnames = ' OR '.join(f'url_host_name = {escape_sql_literal(h)}' for h in url_host_names) - - where_clauses = [ - "subset = 'warc'", - f'({query_tlds}) -- help the query optimizer', - f'({query_hostnames})', - ] - # TODO wire --from/--to into a fetch_time BETWEEN ... clause here - if crawls: - crawl_in = ', '.join(escape_sql_literal(c) for c in crawls) - where_clauses.append(f'crawl IN ({crawl_in})') - - where_sql = '\n AND '.join(where_clauses) - query_limit = f'\n LIMIT {limit}' if limit > 0 else '' - - return f""" - SELECT - warc_filename, warc_record_offset, warc_record_length - FROM {table} - WHERE {where_sql}{query_limit}""" - - -def validate_result_columns(column_names) -> None: - """Raise a clear error if the query result lacks the required columns.""" - missing = [c for c in REQUIRED_RESULT_COLUMNS if c not in (column_names or [])] - if missing: - raise ValueError( - 'Athena query result is missing required columns: ' + ', '.join(missing) + - '. The query (including a raw --athena-query) must SELECT ' + - ', '.join(REQUIRED_RESULT_COLUMNS) + '.' - ) - - -def join_warc_url(prefix: Optional[str], warc_filename: str) -> str: - """Join a download prefix with a warc_filename robustly. 
- - - if warc_filename is already absolute (contains '://'), return it unchanged - (supports custom queries whose warc_filename is a full s3://-/https:// URL); - - if prefix is empty/None, return warc_filename unchanged; - - otherwise join with exactly one '/' (no double slash, no missing slash).""" - if '://' in warc_filename: - return warc_filename - if not prefix: - return warc_filename - return prefix.rstrip('/') + '/' + warc_filename.lstrip('/') - - -def run_athena_query(client, query: str, database: str, s3_output_location: str, max_wait_time: int = 300) -> str: - """Start an Athena query and block until it completes; return the execution id. - - Raises if the query does not reach the SUCCEEDED state. If the wait is - interrupted (Ctrl-C / cancellation) or times out, the query is cancelled - server-side so we don't keep paying for a scan whose results we'll never read.""" - logger.info('Executing Athena query: %s', query) - - response = client.start_query_execution( - QueryString=query, - QueryExecutionContext={'Database': database}, - ResultConfiguration={'OutputLocation': s3_output_location}, - ) - - query_execution_id = response['QueryExecutionId'] - logger.info('Query execution started. ID: %s', query_execution_id) - - try: - status = _wait_for_query_completion(client, query_execution_id, max_wait_time) - except BaseException: - # Ctrl-C, asyncio cancellation, or timeout: stop the query server-side - # to bound the Athena scan cost, then propagate the original exception. 
- _stop_query(client, query_execution_id) - raise - - if status != 'SUCCEEDED': - raise Exception(f'Query failed with status: {status}') - - return query_execution_id - - -def _stop_query(client, query_execution_id: str) -> None: - """Best-effort cancellation of a running Athena query.""" - logger.warning('Cancelling Athena query %s ...', query_execution_id) - try: - client.stop_query_execution(QueryExecutionId=query_execution_id) - logger.warning('Athena query %s cancelled', query_execution_id) - except Exception as e: # pragma: no cover - best-effort cleanup - logger.warning('Failed to cancel Athena query %s: %r', query_execution_id, e) - - -def report_query_cost(client, query_execution_id: str) -> None: - """Log the bytes scanned and an estimated USD cost for a completed query.""" - try: - response = client.get_query_execution(QueryExecutionId=query_execution_id) - scanned = response['QueryExecution'].get('Statistics', {}).get('DataScannedInBytes') - except Exception as e: # pragma: no cover - best-effort reporting - logger.debug('unable to read Athena query statistics: %r', e) - return - - if scanned is None: - return - - gb = scanned / 1e9 - usd = scanned / 1e12 * ATHENA_USD_PER_TB - logger.info('Athena scanned %.2f GB, estimated cost ~$%.4f', gb, usd) - - -async def get_range_jobs_from_athena( - client, - query: str, - database: str, - s3_output_location: str, - job_queue: asyncio.Queue, - queue_stop_object: Any, - warc_download_prefix: str, - num_fetchers: int, - max_wait_time: int = 300, -) -> int: - """Execute a prepared Athena query and enqueue a RangeJob per result row. - - The query string is built and validated by the caller (see build_athena_query - and cdx_toolkit.filter_warc.command). 
This function only executes it, maps the - results to RangeJob objects, pushes them to the asyncio queue, signals the - fetchers to stop, and logs the scan cost.""" - count = 0 - - query_execution_id = run_athena_query(client, query, database, s3_output_location, max_wait_time) - - for range_job in iter_range_jobs(client, query_execution_id, warc_download_prefix): - await job_queue.put(range_job) - count += 1 - - report_query_cost(client, query_execution_id) - - # Signal fetchers to stop - for _ in range(num_fetchers): - await job_queue.put(queue_stop_object) - - logger.info('Athena query enqueued %d jobs', count) - - return count - - -def _wait_for_query_completion(client, query_execution_id: str, max_wait_time: int) -> str: - """Wait for query to complete and return final status""" - start_time = time.time() - - while time.time() - start_time < max_wait_time: - response = client.get_query_execution(QueryExecutionId=query_execution_id) - - status = response['QueryExecution']['Status']['State'] - logger.info(f'Query status: {status}') - - if status in ['SUCCEEDED', 'FAILED', 'CANCELLED']: - if status == 'FAILED': - error_reason = response['QueryExecution']['Status'].get('StateChangeReason', 'Unknown error') - logger.info(f'Query failed: {error_reason}') - return status - - time.sleep(2) - - raise TimeoutError(f'Query did not complete within {max_wait_time} seconds') - - -def iter_range_jobs(client, query_execution_id: str, warc_download_prefix: str) -> Iterable[RangeJob]: - """Retrieve query results and convert each row to a RangeJob""" - # Get query results - paginator = client.get_paginator('get_query_results') - page_iterator = paginator.paginate(QueryExecutionId=query_execution_id) - column_names = None - - for page in page_iterator: - rows = page['ResultSet']['Rows'] - - # Get column names from first page - if column_names is None and rows: - column_names = [col['VarCharValue'] for col in rows[0]['Data']] - validate_result_columns(column_names) - rows = 
rows[1:] # Skip header row - - # Process data rows - for row in rows: - row_data = [] - for cell in row['Data']: - value = cell.get('VarCharValue', None) - row_data.append(value) - - row = dict(zip(column_names, row_data)) - - warc_url = join_warc_url(warc_download_prefix, row['warc_filename']) - - yield RangeJob(url=warc_url, offset=int(row['warc_record_offset']), length=int(row['warc_record_length'])) - - -def get_databases(client) -> list: - """Get list of available databases""" - response = client.list_databases(CatalogName='AwsDataCatalog') - return [db['Name'] for db in response['DatabaseList']] diff --git a/cdx_toolkit/filter_warc/cdx_utils.py b/cdx_toolkit/filter_warc/cdx_utils.py index 6b0c584..45f32a2 100644 --- a/cdx_toolkit/filter_warc/cdx_utils.py +++ b/cdx_toolkit/filter_warc/cdx_utils.py @@ -27,7 +27,7 @@ def get_index_as_string_from_path( return f.read() -def read_cdx_line(line: str, warc_download_prefix: str) -> Tuple[str, int, int]: +def read_cdx_line(line: str, warc_download_prefix: str) -> Tuple[str, int, int, str]: cols = line.split(' ', maxsplit=2) if len(cols) == 3: @@ -49,10 +49,10 @@ def read_cdx_line(line: str, warc_download_prefix: str) -> Tuple[str, int, int]: warc_url = warc_download_prefix + '/' + filename - return (warc_url, offset, length) + return (warc_url, offset, length, filename) -def iter_cdx_index_from_path(index_path: str, warc_download_prefix: str) -> Iterable[Tuple[str, int, int]]: +def iter_cdx_index_from_path(index_path: str, warc_download_prefix: str) -> Iterable[Tuple[str, int, int, str]]: """ Iterate CDX records from a file path (gzipped; local or remote). 
""" diff --git a/cdx_toolkit/filter_warc/command.py b/cdx_toolkit/filter_warc/command.py index 2328f64..8a2d187 100644 --- a/cdx_toolkit/filter_warc/command.py +++ b/cdx_toolkit/filter_warc/command.py @@ -1,7 +1,5 @@ -from cdx_toolkit.filter_warc.cdx_utils import get_cdx_paths from cdx_toolkit.filter_warc.warc_filter import WARCFilter -from cdx_toolkit.filter_warc.athena_job_generator import build_athena_query -from cdx_toolkit.commoncrawl import normalize_crawl, get_cc_endpoints, match_cc_crawls +from cdx_toolkit.filter_warc.sources import make_source from cdx_toolkit.utils import get_version @@ -14,116 +12,54 @@ logger = logging.getLogger(__name__) -# A built Athena query restricted to at most this many crawls is considered cheap -# enough to run without a cost-confirmation prompt. +# A SQL query restricted to at most this many crawls is considered cheap enough to +# run without a cost-confirmation prompt. LARGE_CRAWL_SET_THRESHOLD = 10 -# Default Common Crawl index mirror used to resolve --crawl to concrete crawl names. -CC_INDEX_MIRROR = 'https://index.commoncrawl.org/' +def confirm_cost(estimate, confirmed) -> None: + """Prompt before running a potentially expensive index scan. -def _endpoint_to_crawl_name(endpoint: str) -> str: - """Turn a collinfo cdx-api endpoint into its crawl name. - - e.g. 'https://index.commoncrawl.org/CC-MAIN-2025-33-index' -> 'CC-MAIN-2025-33'.""" - name = endpoint.rstrip('/').split('/')[-1] - if name.endswith('-index'): - name = name[:-len('-index')] - return name - - -def _resolve_crawl_names(crawl_arg) -> list: - """Resolve a --crawl value to concrete CC-MAIN crawl names for the partition filter. 
- - Reuses the CDX path's helpers so `--crawl` accepts the same forms (comma-separated - names or an integer for the most recent N crawls).""" - crawls = normalize_crawl([crawl_arg]) - raw_index_list = get_cc_endpoints(CC_INDEX_MIRROR) - matched = match_cc_crawls(crawls, raw_index_list) - return [_endpoint_to_crawl_name(ep) for ep in matched] - - -def resolve_athena_query(args): - """Validate the Athena args and return (sql, n_crawls). - - n_crawls is the number of crawls the query is restricted to, or None when the - query is unrestricted (scans all crawls) or its pruning cannot be verified - (raw --athena-query/--athena-query-file).""" - raw_sql = args.athena_query - if args.athena_query_file: - if raw_sql: - raise ValueError('--athena-query and --athena-query-file are mutually exclusive') - with open(args.athena_query_file) as f: - raw_sql = f.read() - - if raw_sql and args.athena_hostnames: - raise ValueError('--athena-query/--athena-query-file are mutually exclusive with --athena-hostnames') - if not raw_sql and not args.athena_hostnames: - raise ValueError('athena target requires either --athena-hostnames or --athena-query/--athena-query-file') - - if not args.athena_database: - raise ValueError('--athena-database is required for target source `athena`') - if not args.athena_s3_output: - raise ValueError('--athena-s3-output is required for target source `athena`') - - if raw_sql: - # Crawl-partition pruning of a raw query cannot be verified -> treat as unbounded. - return raw_sql, None - - # Guided/built path - limit = 0 if args.limit is None else args.limit - if args.crawl: - crawl_names = _resolve_crawl_names(args.crawl) - sql = build_athena_query(args.athena_hostnames, crawls=crawl_names, limit=limit) - return sql, len(crawl_names) - - sql = build_athena_query(args.athena_hostnames, limit=limit) - return sql, None - - -def confirm_athena_cost(n_crawls, confirmed) -> None: - """Prompt before running a potentially expensive Athena query. 
- - A built query restricted to <= LARGE_CRAWL_SET_THRESHOLD crawls runs without a - prompt. Otherwise (no crawl filter, a large crawl set, or unverifiable raw SQL) - we confirm interactively, abort in non-interactive sessions, unless `confirmed` - (--confirm-athena-cost) is set.""" - if confirmed: + `estimate` is a CostEstimate (or None for sources that never bill, e.g. cdx/csv). + A scan bounded to <= LARGE_CRAWL_SET_THRESHOLD crawls runs without a prompt. + Otherwise (no crawl filter, a large crawl set, or unverifiable raw SQL) we confirm + interactively, abort in non-interactive sessions, unless `confirmed` (--confirm-cost).""" + if confirmed or estimate is None: return + + n_crawls = estimate.n_crawls if n_crawls is not None and n_crawls <= LARGE_CRAWL_SET_THRESHOLD: return + engine = estimate.engine if n_crawls is None: - reason = ('This Athena query is not restricted to specific crawls (or uses custom SQL whose ' - 'crawl partition pruning could not be verified) and may scan ALL crawls.') + reason = (f'This {engine} query is not restricted to specific crawls (or uses custom SQL ' + 'whose crawl partition pruning could not be verified) and may scan ALL crawls.') else: - reason = (f'This Athena query is restricted to {n_crawls} crawls (more than ' + reason = (f'This {engine} query is restricted to {n_crawls} crawls (more than ' f'{LARGE_CRAWL_SET_THRESHOLD}) and may scan a large amount of data.') if not sys.stdin.isatty(): raise SystemExit( - reason + ' Refusing to run a potentially expensive Athena scan in non-interactive mode. ' - 'Restrict with --crawl (<=10 crawls), or pass --confirm-athena-cost.' + reason + ' Refusing to run a potentially expensive scan in non-interactive mode. ' + 'Restrict with --crawl (<=10 crawls), or pass --confirm-cost.' ) logger.warning(reason) - answer = input('Athena bills per TB scanned. Proceed? [y/N] ') + answer = input('Index SQL scans can be expensive (Athena bills per TB scanned). Proceed? 
[y/N] ') if answer.strip().lower() not in ('y', 'yes'): raise SystemExit('Aborted by user.') -def run_warcer_by_cdx(args, cmdline): - """Like warcer but fetches WARC records based on one or more CDX index files. - - The CDX files can be filtered using the `filter_cdx` commands based a given URL/SURT list. +def run_repackage(args, cmdline): + """Repackage WARC records from a pluggable range-job source (cdx / sql / csv). Approach: - - Iterate over one or more CDX files to extract capture object (file, offset, length) - - Fetch WARC record based on capture object - - Write to new WARC file including metadata records with index. - - The CDX metadata record is written to the WARC directly before for response records that matches to the CDX. + - Generate RangeJobs (WARC file + byte range) from the selected source. + - Optionally materialize them to a CSV (--range-jobs-output; --no-fetch to skip fetching). + - Fetch each WARC record and write a new WARC, including metadata records. """ - logger.info('Filtering WARC files based on CDX') + logger.info('Repackaging WARC files (target source: %s)', args.target_source) # Start timing start_time = time.time() @@ -142,7 +78,7 @@ def run_warcer_by_cdx(args, cmdline): 'isPartOf': ispartof, 'description': args.description if args.description - else 'warc extraction based on CDX generated with: ' + cmdline, + else 'warc extraction generated with: ' + cmdline, 'format': 'WARC file version 1.0', } if args.creator: @@ -154,33 +90,22 @@ def run_warcer_by_cdx(args, cmdline): log_every_n = args.log_every_n limit = 0 if args.limit is None else args.limit prefix_path = str(args.prefix) - prefix_fs, prefix_fs_path = fsspec.url_to_fs(prefix_path) - # make sure the base dir exists - prefix_fs.makedirs(prefix_fs._parent(prefix_fs_path), exist_ok=True) + # Build the source (validates source/engine/query options) and confirm scan cost + # up front (synchronously) before launching the pipeline. 
+ source = make_source(args, warc_download_prefix=args.warc_download_prefix, record_limit=limit) + confirm_cost(source.estimate_cost(), args.confirm_cost) - # target source handling - athena_query = None - if args.target_source == 'cdx': - cdx_paths = get_cdx_paths( - args.cdx_path, - args.cdx_glob, - ) - elif args.target_source == "athena": - cdx_paths = None - # Build/validate the Athena query up front (synchronously) so we can warn - # about expensive (unpartitioned / large) scans before launching the pipeline. - athena_query, n_crawls = resolve_athena_query(args) - confirm_athena_cost(n_crawls, args.confirm_athena_cost) - else: - raise ValueError(f'Invalid target source specified: {args.target_source} (available: cdx, athena)') + # make sure the output base dir exists (only needed when actually writing WARCs) + if not args.no_fetch: + prefix_fs, prefix_fs_path = fsspec.url_to_fs(prefix_path) + prefix_fs.makedirs(prefix_fs._parent(prefix_fs_path), exist_ok=True) warc_filter = WARCFilter( - target_source=args.target_source, - cdx_paths=cdx_paths, - athena_database=args.athena_database, - athena_s3_output_location=args.athena_s3_output, - athena_query=athena_query, + source=source, + range_jobs_output=args.range_jobs_output, + no_fetch=args.no_fetch, + csv_self_contained=args.csv_self_contained, prefix_path=prefix_path, writer_info=info, writer_subprefix=args.subprefix, @@ -190,7 +115,6 @@ def run_warcer_by_cdx(args, cmdline): warc_download_prefix=args.warc_download_prefix, n_parallel=n_parallel, max_file_size=args.size, - # writer_kwargs=writer_kwargs, ) records_n = warc_filter.filter() diff --git a/cdx_toolkit/filter_warc/data_classes.py b/cdx_toolkit/filter_warc/data_classes.py index 8f36325..40f190e 100644 --- a/cdx_toolkit/filter_warc/data_classes.py +++ b/cdx_toolkit/filter_warc/data_classes.py @@ -2,7 +2,7 @@ from dataclasses import dataclass from cdx_toolkit.filter_warc.s3_utils import is_s3_url, parse_s3_uri, with_retries -from typing import Tuple +from 
typing import Optional, Tuple from cdx_toolkit.myrequests import myrequests_get @@ -47,6 +47,10 @@ class RangeJob: offset: int length: int records_count: int = 1 + # Relative WARC filename (e.g. crawl-data/...warc.gz) as known by the source, + # used when materializing a non-self-contained range-jobs CSV. `url` stays + # authoritative for fetching. + filename: Optional[str] = None def is_s3(self): return is_s3_url(self.url) diff --git a/cdx_toolkit/filter_warc/sources/__init__.py b/cdx_toolkit/filter_warc/sources/__init__.py new file mode 100644 index 0000000..d1208cf --- /dev/null +++ b/cdx_toolkit/filter_warc/sources/__init__.py @@ -0,0 +1,4 @@ +from cdx_toolkit.filter_warc.sources.base import RangeJobSource, CostEstimate +from cdx_toolkit.filter_warc.sources.factory import make_source + +__all__ = ['RangeJobSource', 'CostEstimate', 'make_source'] diff --git a/cdx_toolkit/filter_warc/sources/athena.py b/cdx_toolkit/filter_warc/sources/athena.py new file mode 100644 index 0000000..203048d --- /dev/null +++ b/cdx_toolkit/filter_warc/sources/athena.py @@ -0,0 +1,177 @@ +import logging +import time +from typing import Iterator, Iterable, Optional + +from cdx_toolkit.filter_warc.data_classes import RangeJob +from cdx_toolkit.filter_warc.sources.base import RangeJobSource, CostEstimate +from cdx_toolkit.filter_warc.sources.sql_base import validate_result_columns, join_warc_url + + +logger = logging.getLogger(__name__) + +# Athena pricing is ~$5 per TB scanned (used for the post-run cost estimate). 
+ATHENA_USD_PER_TB = 5.0 + + +class AthenaSource(RangeJobSource): + """RangeJobs from a query against the CC columnar index via AWS Athena.""" + + def __init__( + self, + *, + query: str, + database: str, + s3_output_location: str, + warc_download_prefix: Optional[str], + n_crawls: Optional[int] = None, + region_name: str = 'us-east-1', + max_wait_time: int = 300, + ): + self.query = query + self.database = database + self.s3_output_location = s3_output_location + self.warc_download_prefix = warc_download_prefix + self._n_crawls = n_crawls + self.region_name = region_name + self.max_wait_time = max_wait_time + + def estimate_cost(self) -> CostEstimate: + return CostEstimate(n_crawls=self._n_crawls, engine='athena') + + def _make_client(self): + import boto3 + from botocore.config import Config + + config = Config( + region_name=self.region_name, + read_timeout=60, + retries={'max_attempts': 3, 'mode': 'adaptive'}, + ) + return boto3.client('athena', config=config) + + def iter_range_jobs(self) -> Iterator[RangeJob]: + client = self._make_client() + query_execution_id = run_athena_query( + client, self.query, self.database, self.s3_output_location, self.max_wait_time + ) + try: + for job in iter_range_jobs(client, query_execution_id, self.warc_download_prefix): + yield job + finally: + report_query_cost(client, query_execution_id) + + +def run_athena_query(client, query: str, database: str, s3_output_location: str, max_wait_time: int = 300) -> str: + """Start an Athena query and block until it completes; return the execution id. + + Raises if the query does not reach the SUCCEEDED state. 
If the wait is + interrupted (Ctrl-C / cancellation) or times out, the query is cancelled + server-side so we don't keep paying for a scan whose results we'll never read.""" + logger.info('Executing Athena query: %s', query) + + response = client.start_query_execution( + QueryString=query, + QueryExecutionContext={'Database': database}, + ResultConfiguration={'OutputLocation': s3_output_location}, + ) + + query_execution_id = response['QueryExecutionId'] + logger.info('Query execution started. ID: %s', query_execution_id) + + try: + status = _wait_for_query_completion(client, query_execution_id, max_wait_time) + except BaseException: + # Ctrl-C, asyncio cancellation, or timeout: stop the query server-side + # to bound the Athena scan cost, then propagate the original exception. + _stop_query(client, query_execution_id) + raise + + if status != 'SUCCEEDED': + raise Exception(f'Query failed with status: {status}') + + return query_execution_id + + +def _stop_query(client, query_execution_id: str) -> None: + """Best-effort cancellation of a running Athena query.""" + logger.warning('Cancelling Athena query %s ...', query_execution_id) + try: + client.stop_query_execution(QueryExecutionId=query_execution_id) + logger.warning('Athena query %s cancelled', query_execution_id) + except Exception as e: # pragma: no cover - best-effort cleanup + logger.warning('Failed to cancel Athena query %s: %r', query_execution_id, e) + + +def report_query_cost(client, query_execution_id: str) -> None: + """Log the bytes scanned and an estimated USD cost for a completed query.""" + try: + response = client.get_query_execution(QueryExecutionId=query_execution_id) + scanned = response['QueryExecution'].get('Statistics', {}).get('DataScannedInBytes') + except Exception as e: # pragma: no cover - best-effort reporting + logger.debug('unable to read Athena query statistics: %r', e) + return + + if scanned is None: + return + + gb = scanned / 1e9 + usd = scanned / 1e12 * ATHENA_USD_PER_TB + 
logger.info('Athena scanned %.2f GB, estimated cost ~$%.4f', gb, usd) + + +def _wait_for_query_completion(client, query_execution_id: str, max_wait_time: int) -> str: + """Wait for query to complete and return final status""" + start_time = time.time() + + while time.time() - start_time < max_wait_time: + response = client.get_query_execution(QueryExecutionId=query_execution_id) + + status = response['QueryExecution']['Status']['State'] + logger.info(f'Query status: {status}') + + if status in ['SUCCEEDED', 'FAILED', 'CANCELLED']: + if status == 'FAILED': + error_reason = response['QueryExecution']['Status'].get('StateChangeReason', 'Unknown error') + logger.info(f'Query failed: {error_reason}') + return status + + time.sleep(2) + + raise TimeoutError(f'Query did not complete within {max_wait_time} seconds') + + +def iter_range_jobs(client, query_execution_id: str, warc_download_prefix: Optional[str]) -> Iterable[RangeJob]: + """Retrieve query results and convert each row to a RangeJob""" + paginator = client.get_paginator('get_query_results') + page_iterator = paginator.paginate(QueryExecutionId=query_execution_id) + column_names = None + + for page in page_iterator: + rows = page['ResultSet']['Rows'] + + # Get column names from first page + if column_names is None and rows: + column_names = [col['VarCharValue'] for col in rows[0]['Data']] + validate_result_columns(column_names) + rows = rows[1:] # Skip header row + + # Process data rows + for row in rows: + row_data = [cell.get('VarCharValue', None) for cell in row['Data']] + row = dict(zip(column_names, row_data)) + + warc_filename = row['warc_filename'] + warc_url = join_warc_url(warc_download_prefix, warc_filename) + + yield RangeJob( + url=warc_url, + offset=int(row['warc_record_offset']), + length=int(row['warc_record_length']), + filename=warc_filename, + ) + + +def get_databases(client) -> list: + """Get list of available databases""" + response = client.list_databases(CatalogName='AwsDataCatalog') + 
class CostEstimate(NamedTuple):
    """What a source is about to scan, as input for the cost-confirmation guard.

    Fields:
        n_crawls: How many crawls the scan is restricted to. ``None`` means the
            scan is unbounded or its pruning cannot be verified (for example a
            raw SQL query or a glob over every crawl).
        engine: Name of the engine that runs the scan.
    """

    # Number of crawls bounding the scan, or None when unbounded/unverifiable.
    n_crawls: Optional[int]
    # Engine identifier string.
    engine: str
Each RangeJob carries `url` (authoritative for + fetching) and, where known, the relative `filename`.""" + raise NotImplementedError diff --git a/cdx_toolkit/filter_warc/sources/cdx.py b/cdx_toolkit/filter_warc/sources/cdx.py new file mode 100644 index 0000000..bb9356d --- /dev/null +++ b/cdx_toolkit/filter_warc/sources/cdx.py @@ -0,0 +1,28 @@ +import logging +from typing import Iterator, List + +from cdx_toolkit.filter_warc.cdx_utils import iter_cdx_index_from_path +from cdx_toolkit.filter_warc.data_classes import RangeJob +from cdx_toolkit.filter_warc.sources.base import RangeJobSource + + +logger = logging.getLogger(__name__) + + +class CdxSource(RangeJobSource): + """RangeJobs read from one or more CDX index files (local or remote via fsspec).""" + + def __init__(self, cdx_paths: List[str], warc_download_prefix: str): + self.cdx_paths = cdx_paths + self.warc_download_prefix = warc_download_prefix + + def iter_range_jobs(self) -> Iterator[RangeJob]: + for index_path in self.cdx_paths: + try: + for warc_url, offset, length, filename in iter_cdx_index_from_path( + index_path, self.warc_download_prefix + ): + yield RangeJob(url=warc_url, offset=offset, length=length, filename=filename) + except Exception as e: + # Preserve the previous behaviour of skipping a bad index file. 
class CsvSource(RangeJobSource):
    """Read RangeJobs from a CSV or TSV file opened through fsspec.

    The header decides how each row is interpreted:

    * a ``url`` column: the URL is used as-is and ``filename`` is ``None``;
    * otherwise a ``filename`` column: the WARC download prefix is joined in
      front of the filename to form the URL.

    ``offset`` and ``length`` columns are required in both modes. A path
    ending in ``.tsv`` or ``.tsv.gz`` is parsed as tab-separated, and any path
    ending in ``.gz`` is opened with gzip decompression.
    """

    def __init__(self, path: str, warc_download_prefix: Optional[str]):
        self.path = path
        self.warc_download_prefix = warc_download_prefix

    def iter_range_jobs(self) -> Iterator[RangeJob]:
        """Yield one RangeJob per data row.

        Raises:
            ValueError: If the header has neither a ``url`` nor a ``filename``
                column, if it lacks ``offset`` or ``length``, or if an
                offset/length value is not an integer.
        """
        path = str(self.path)
        delimiter = '\t' if path.endswith(('.tsv', '.tsv.gz')) else ','
        compression = 'gzip' if path.endswith('.gz') else None

        with fsspec.open(self.path, 'rt', newline='', compression=compression) as fh:
            reader = csv.DictReader(fh, delimiter=delimiter)
            fields = set(reader.fieldnames or [])
            if 'url' in fields:
                mode_url = True
            elif 'filename' in fields:
                mode_url = False
            else:
                raise ValueError(
                    f'range-jobs CSV {self.path} must have a `url` or `filename` column '
                    f'(got header: {reader.fieldnames})'
                )

            # Fail on the header rather than with a KeyError on the first row,
            # so a malformed file is rejected before any job is emitted.
            missing = [name for name in ('offset', 'length') if name not in fields]
            if missing:
                missing_text = ', '.join(missing)
                raise ValueError(
                    f'range-jobs CSV {self.path} is missing required column(s): '
                    f'{missing_text} (got header: {reader.fieldnames})'
                )

            for row in reader:
                offset = int(row['offset'])
                length = int(row['length'])
                if mode_url:
                    url = row['url']
                    filename = None
                else:
                    filename = row['filename']
                    url = join_warc_url(self.warc_download_prefix, filename)
                yield RangeJob(url=url, offset=offset, length=length, filename=filename)
_HAS_DUCKDB = False + + +def _build_from_clause(index_path: str, crawls: Optional[List[str]]) -> str: + """Build a read_parquet(...) FROM expression over the CC columnar index. + + When crawls are given we glob only those crawl partitions (the pruning lever); + otherwise we glob every crawl (expensive -> the cost guard fires).""" + base = index_path.rstrip('/') + if crawls: + globs = [f"'{base}/crawl={c}/subset=warc/*.parquet'" for c in crawls] + else: + globs = [f"'{base}/crawl=*/subset=warc/*.parquet'"] + return f"read_parquet([{', '.join(globs)}], hive_partitioning=true)" + + +class DuckDbSource(RangeJobSource): + """RangeJobs from a query against the CC columnar index via DuckDB (read_parquet on S3). + + Reads the public CommonCrawl parquet directly; AWS region/credentials come from + the environment (the public bucket is readable with valid credentials).""" + + def __init__( + self, + *, + query: Optional[str] = None, + hostnames: Optional[List[str]] = None, + domains: Optional[List[str]] = None, + crawls: Optional[List[str]] = None, + index_path: str, + warc_download_prefix: Optional[str], + limit: int = 0, + region_name: str = 'us-east-1', + ): + self.raw_query = query + self.hostnames = hostnames + self.domains = domains + self.crawls = crawls + self.index_path = index_path + self.warc_download_prefix = warc_download_prefix + self.limit = limit + self.region_name = region_name + + def estimate_cost(self) -> CostEstimate: + if self.raw_query is not None: + return CostEstimate(n_crawls=None, engine='duckdb') + return CostEstimate(n_crawls=len(self.crawls) if self.crawls else None, engine='duckdb') + + def _build_query(self) -> str: + if self.raw_query is not None: + return self.raw_query + from_clause = _build_from_clause(self.index_path, self.crawls) + # crawl pruning is done in the FROM glob, so no crawl IN (...) 
in the WHERE + return build_sql( + from_clause, self.hostnames, crawls=None, limit=self.limit, + url_host_registered_domains=self.domains, + ) + + def iter_range_jobs(self) -> Iterator[RangeJob]: + if not _HAS_DUCKDB: + raise RuntimeError( + 'DuckDB engine requires optional dependencies. Install cdx_toolkit[duckdb].' + ) + + query = self._build_query() + logger.info('Executing DuckDB query: %s', query) + + con = duckdb.connect() + try: + con.execute('INSTALL httpfs; LOAD httpfs;') + con.execute(f"SET s3_region='{self.region_name}';") + # Be resilient to transient S3 read timeouts on large parquet partitions. + for stmt in ('SET http_timeout=120000;', 'SET http_retries=5;'): + try: + con.execute(stmt) + except Exception as e: # pragma: no cover - setting unsupported on this duckdb + logger.debug('duckdb setting skipped: %s (%r)', stmt, e) + cur = con.execute(query) + + col_names = [d[0] for d in cur.description] + validate_result_columns(col_names) + idx = {name: i for i, name in enumerate(col_names)} + + while True: + rows = cur.fetchmany(1000) + if not rows: + break + for row in rows: + warc_filename = row[idx['warc_filename']] + warc_url = join_warc_url(self.warc_download_prefix, warc_filename) + yield RangeJob( + url=warc_url, + offset=int(row[idx['warc_record_offset']]), + length=int(row[idx['warc_record_length']]), + filename=warc_filename, + ) + finally: + con.close() diff --git a/cdx_toolkit/filter_warc/sources/factory.py b/cdx_toolkit/filter_warc/sources/factory.py new file mode 100644 index 0000000..ba0e098 --- /dev/null +++ b/cdx_toolkit/filter_warc/sources/factory.py @@ -0,0 +1,99 @@ +import logging +from typing import Optional, Tuple + +from cdx_toolkit.filter_warc.cdx_utils import get_cdx_paths +from cdx_toolkit.filter_warc.sources.base import RangeJobSource +from cdx_toolkit.filter_warc.sources.sql_base import build_athena_query, resolve_crawl_names + + +logger = logging.getLogger(__name__) + + +def make_source(args, *, warc_download_prefix: 
Optional[str], record_limit: int) -> RangeJobSource: + """Build the RangeJobSource selected by --target-source (+ --engine for sql). + + Centralises all source/engine validation (engine required iff sql; + hostnames/query/query-file mutual exclusivity; required connection options).""" + target = args.target_source + + if target == 'cdx': + from cdx_toolkit.filter_warc.sources.cdx import CdxSource + cdx_paths = get_cdx_paths(args.cdx_path, args.cdx_glob) + return CdxSource(cdx_paths, warc_download_prefix) + + if target == 'csv': + from cdx_toolkit.filter_warc.sources.csv import CsvSource + if not args.csv_path: + raise ValueError('--csv-path is required for --target-source csv') + return CsvSource(args.csv_path, warc_download_prefix) + + if target == 'sql': + return _make_sql_source(args, warc_download_prefix, record_limit) + + raise ValueError(f'Invalid target source: {target} (available: cdx, sql, csv)') + + +def _resolve_sql_query_spec(args) -> Tuple[Optional[str], Optional[list], Optional[list], Optional[list]]: + """Validate the query-defining flags and return (raw_sql, hostnames, domains, crawls). + + The guided path (--hostnames and/or --domains) and a raw query + (--query/--query-file) are mutually exclusive. 
For the guided path, --crawl is + resolved to concrete crawl names.""" + raw_sql = args.query + if args.query_file: + if raw_sql: + raise ValueError('--query and --query-file are mutually exclusive') + with open(args.query_file) as f: + raw_sql = f.read() + + has_guided = bool(args.hostnames) or bool(args.domains) + if raw_sql and has_guided: + raise ValueError('--query/--query-file are mutually exclusive with --hostnames/--domains') + if not raw_sql and not has_guided: + raise ValueError('the sql target requires --hostnames, --domains, or --query/--query-file') + + if raw_sql: + return raw_sql, None, None, None + + crawls = resolve_crawl_names(args.crawl) if args.crawl else None + return None, args.hostnames, args.domains, crawls + + +def _make_sql_source(args, warc_download_prefix, record_limit) -> RangeJobSource: + engine = args.engine + if not engine: + raise ValueError('--engine is required for --target-source sql (choices: athena, duckdb)') + + raw_sql, hostnames, domains, crawls = _resolve_sql_query_spec(args) + limit = 0 if record_limit is None else record_limit + + if engine == 'athena': + from cdx_toolkit.filter_warc.sources.athena import AthenaSource + if not args.athena_s3_output: + raise ValueError('--athena-s3-output is required for --engine athena') + database = args.athena_database or 'ccindex' + query = raw_sql if raw_sql else build_athena_query( + hostnames, crawls=crawls, limit=limit, url_host_registered_domains=domains, + ) + n_crawls = None if raw_sql else (len(crawls) if crawls else None) + return AthenaSource( + query=query, + database=database, + s3_output_location=args.athena_s3_output, + warc_download_prefix=warc_download_prefix, + n_crawls=n_crawls, + ) + + if engine == 'duckdb': + from cdx_toolkit.filter_warc.sources.duckdb import DuckDbSource + return DuckDbSource( + query=raw_sql, + hostnames=hostnames, + domains=domains, + crawls=crawls, + index_path=args.duckdb_index_path, + warc_download_prefix=warc_download_prefix, + limit=limit, + 
def escape_sql_literal(value: str) -> str:
    """Validate a value and wrap it in single quotes for use as a SQL string literal.

    Only letters, digits, dot and hyphen are accepted, so the quoted result
    cannot break out of its quotes or inject SQL.

    Args:
        value: Candidate literal, e.g. a hostname, TLD or crawl name.

    Returns:
        ``value`` surrounded by single quotes.

    Raises:
        ValueError: If ``value`` is not a string, is empty, or contains any
            character outside the allowed set.
    """
    # fullmatch(), not match(): with match() the pattern's trailing `$` also
    # matches just before a final newline, letting 'example.com\n' through.
    if not isinstance(value, str) or not _SQL_LITERAL_RE.fullmatch(value):
        raise ValueError(
            f'Invalid value for SQL query literal: {value!r} '
            '(allowed characters: letters, digits, dot, hyphen)'
        )
    return "'" + value + "'"
example.com, which also covers its + subdomains); predicates are OR-ed together. At least one host or domain is + required. If `crawls` is a non-empty list of crawl names (e.g. ['CC-MAIN-2025-33']), + a `crawl IN (...)` partition filter is added -- the main lever for reducing scan + cost. Engines differ only in their FROM clause (see build_sql).""" + url_host_names = url_host_names or [] + domains = url_host_registered_domains or [] + if not url_host_names and not domains: + raise ValueError('an index query requires at least one hostname or registered domain') + + tlds = sorted({v.split('.')[-1] for v in (list(url_host_names) + list(domains))}) + query_tlds = ' OR '.join(f'url_host_tld = {escape_sql_literal(t)}' for t in tlds) + + host_predicates = [f'url_host_name = {escape_sql_literal(h)}' for h in url_host_names] + host_predicates += [f'url_host_registered_domain = {escape_sql_literal(d)}' for d in domains] + query_hosts = ' OR '.join(host_predicates) + + clauses = [ + "subset = 'warc'", + f'({query_tlds}) -- help the query optimizer', + f'({query_hosts})', + ] + # TODO wire --from/--to into a fetch_time BETWEEN ... clause here + if crawls: + crawl_in = ', '.join(escape_sql_literal(c) for c in crawls) + clauses.append(f'crawl IN ({crawl_in})') + + return '\n AND '.join(clauses) + + +def build_sql( + from_clause: str, + url_host_names: Optional[List[str]] = None, + crawls: Optional[List[str]] = None, + limit: int = 0, + url_host_registered_domains: Optional[List[str]] = None, +) -> str: + """Assemble a full SELECT for the columnar index. + + `from_clause` is the text following FROM (e.g. 
`ccindex` for Athena, or a + `read_parquet(...)` expression for DuckDB).""" + where_sql = build_where_sql(url_host_names, crawls, url_host_registered_domains=url_host_registered_domains) + limit_sql = f'\n LIMIT {limit}' if limit and limit > 0 else '' + + return f""" + SELECT + warc_filename, warc_record_offset, warc_record_length + FROM {from_clause} + WHERE {where_sql}{limit_sql}""" + + +def build_athena_query( + url_host_names: Optional[List[str]] = None, + crawls: Optional[List[str]] = None, + limit: int = 0, + table: str = 'ccindex', + url_host_registered_domains: Optional[List[str]] = None, +) -> str: + """Athena flavour of build_sql (FROM