From 0c755ad4786fc600639c72fa4fedcd0327295af5 Mon Sep 17 00:00:00 2001 From: malteos Date: Fri, 12 Jun 2026 11:45:34 +0200 Subject: [PATCH 1/6] feat: pluggable RangeJob sources for repackage (cdx/sql{athena,duckdb}/csv) Rename the warc_by_cdx command to 'repackage' and make the range-job stage pluggable behind a RangeJobSource abstraction: - sources/: base (RangeJobSource + CostEstimate), sql_base (shared query builder + crawl resolution), cdx, athena, duckdb (optional dep), csv (reader + RangeJobCsvWriter), and a make_source factory. - CLI: --target-source {cdx,sql,csv} with --engine {athena,duckdb}; shared --hostnames/--query/--query-file; --duckdb-index-path; --csv-path; CSV materialization via --range-jobs-output/--no-fetch/--csv-self-contained; generalized --confirm-cost guard. - WARCFilter: takes an injected source; orchestrator owns queueing, the record limit, counting, and _STOP emission in a finally (fixes the prior hung-readers bug when a source raised). Sources own their own stage-1 client/connection; WARCFilter manages only read/write S3 clients. - DuckDB reads the CC columnar parquet directly via read_parquet with per-crawl partition globbing; Athena unchanged in behaviour. - setup.py: cdx_toolkit[duckdb] extra; conftest: requires_duckdb. Tests: unit (sql_base, csv round-trip, make_source, confirm_cost, producer stop-sentinel regression, --no-fetch); CSV round-trip e2e from the CDX fixture; gated Athena/DuckDB e2e (CC-MAIN-2026-17/commoncrawl.org, single partition). DuckDB e2e verified live; 205 passed. --- cdx_toolkit/cli.py | 14 +- cdx_toolkit/filter_warc/args.py | 79 ++++-- .../filter_warc/athena_job_generator.py | 251 ------------------ cdx_toolkit/filter_warc/cdx_utils.py | 6 +- cdx_toolkit/filter_warc/command.py | 152 +++-------- cdx_toolkit/filter_warc/data_classes.py | 6 +- cdx_toolkit/filter_warc/sources/__init__.py | 4 + cdx_toolkit/filter_warc/sources/athena.py | 177 ++++++++++++ cdx_toolkit/filter_warc/sources/base.py | 35 +++ cdx_toolkit/filter_warc/sources/cdx.py | 28 ++ cdx_toolkit/filter_warc/sources/csv.py | 90 +++++++ cdx_toolkit/filter_warc/sources/duckdb.py | 102 +++++++ cdx_toolkit/filter_warc/sources/factory.py | 94 +++++++ cdx_toolkit/filter_warc/sources/sql_base.py | 133 ++++++++++ cdx_toolkit/filter_warc/warc_filter.py | 251 +++++++----------- requirements.txt | 3 + setup.py | 7 +- tests/conftest.py | 13 + .../test_athena_command_validation.py | 84 ------ .../filter_warc/test_athena_job_generator.py | 74 ------ tests/filter_warc/test_athena_prompt.py | 29 +- .../filter_warc/test_athena_query_builder.py | 4 +- tests/filter_warc/test_cdx_utils.py | 8 +- tests/filter_warc/test_command.py | 132 ++++++++- tests/filter_warc/test_csv_source.py | 66 +++++ tests/filter_warc/test_grouped_range_jobs.py | 6 +- tests/filter_warc/test_make_source.py | 132 +++++++++ tests/filter_warc/test_producer.py | 68 +++++ tests/filter_warc/test_sql_sources_gated.py | 86 ++++++ tests/filter_warc/test_warc_filter.py | 30 ++- 30 files changed, 1413 insertions(+), 751 deletions(-) delete mode 100644 cdx_toolkit/filter_warc/athena_job_generator.py create mode 100644 cdx_toolkit/filter_warc/sources/__init__.py create mode 100644 cdx_toolkit/filter_warc/sources/athena.py create mode 100644 cdx_toolkit/filter_warc/sources/base.py create mode 100644 cdx_toolkit/filter_warc/sources/cdx.py create mode 100644 cdx_toolkit/filter_warc/sources/csv.py create mode 100644 cdx_toolkit/filter_warc/sources/duckdb.py create mode 100644 cdx_toolkit/filter_warc/sources/factory.py create mode 100644 cdx_toolkit/filter_warc/sources/sql_base.py delete mode 100644 tests/filter_warc/test_athena_command_validation.py delete mode 100644 tests/filter_warc/test_athena_job_generator.py create mode 100644 tests/filter_warc/test_csv_source.py create mode 100644 tests/filter_warc/test_make_source.py create mode 100644 tests/filter_warc/test_producer.py create mode 100644 tests/filter_warc/test_sql_sources_gated.py diff --git a/cdx_toolkit/cli.py b/cdx_toolkit/cli.py index c650d5c..f74610d 100644 --- a/cdx_toolkit/cli.py +++ b/cdx_toolkit/cli.py @@ -12,8 +12,8 @@ from cdx_toolkit.filter_cdx.command import run_filter_cdx from cdx_toolkit.filter_cdx.args import add_filter_cdx_args -from cdx_toolkit.filter_warc.command import run_warcer_by_cdx -from cdx_toolkit.filter_warc.args import add_warcer_by_cdx_args +from cdx_toolkit.filter_warc.command import run_repackage +from cdx_toolkit.filter_warc.args import add_repackage_args LOGGER = logging.getLogger(__name__) @@ -124,12 +124,12 @@ def main(args=None): warc.add_argument('url') warc.set_defaults(func=warcer) - warc_by_cdx = subparsers.add_parser( - 'warc_by_cdx', - help='iterate over capture content based on an CDX index file, creating a warc' + repackage = subparsers.add_parser( + 'repackage', + help='repackage WARC ranges from a CDX/SQL/CSV source into a new WARC' ) - add_warcer_by_cdx_args(warc_by_cdx) - warc_by_cdx.set_defaults(func=run_warcer_by_cdx) + add_repackage_args(repackage) + repackage.set_defaults(func=run_repackage) filter_cdx = subparsers.add_parser('filter_cdx', help='Filter CDX files based on SURT prefixes whitelist') add_filter_cdx_args(filter_cdx) diff --git a/cdx_toolkit/filter_warc/args.py b/cdx_toolkit/filter_warc/args.py index 1fa3a2e..d238134 100644 --- a/cdx_toolkit/filter_warc/args.py +++ b/cdx_toolkit/filter_warc/args.py @@ -5,12 +5,13 @@ logger = logging.getLogger(__name__) -def add_warcer_by_cdx_args(parser: argparse.ArgumentParser): +def add_repackage_args(parser: argparse.ArgumentParser): + # --- CDX source --- parser.add_argument( '--cdx-path', type=str, default=None, - help='Path to CDX index file (local or remote, e.g. S3). Required if target source is set to `cdx`.', + help='Path to CDX index file (local or remote, e.g. S3). Used when --target-source cdx.', ) parser.add_argument( '--cdx-glob', @@ -18,46 +19,85 @@ def add_warcer_by_cdx_args(parser: argparse.ArgumentParser): default=None, help='a glob pattern for read from multiple CDX indices', ) + # --- SQL source (--target-source sql --engine athena|duckdb) --- parser.add_argument( - '--athena-hostnames', + '--engine', type=str, - nargs="+", default=None, - help=('Hostnames to filter for via Athena (whitelist). Use this OR --athena-query/' - '--athena-query-file (mutually exclusive) when target source is `athena`.'), + choices=['athena', 'duckdb'], + help='SQL engine for the columnar index. Required when --target-source sql.', ) parser.add_argument( - '--athena-query', + '--hostnames', type=str, + nargs='+', default=None, - help=('Raw Athena SQL to run instead of the hostname-based query (power users). The query ' - 'must SELECT the columns warc_filename, warc_record_offset, warc_record_length. ' - 'Mutually exclusive with --athena-hostnames and --athena-query-file.'), + help=('Hostnames to filter for (whitelist) via the SQL index. Use this OR ' + '--query/--query-file (mutually exclusive). Combine with the global --crawl to ' + 'restrict the scan to specific crawls (strongly recommended for cost).'), ) parser.add_argument( - '--athena-query-file', + '--query', type=str, default=None, - help='Path to a file containing the raw Athena SQL (alternative to --athena-query).', + help=('Raw SQL to run instead of the hostname-based query (power users). Must SELECT the ' + 'columns warc_filename, warc_record_offset, warc_record_length. Engine-specific ' + 'dialect. Mutually exclusive with --hostnames and --query-file.'), + ) + parser.add_argument( + '--query-file', + type=str, + default=None, + help='Path to a file containing the raw SQL (alternative to --query).', ) parser.add_argument( '--athena-database', type=str, default=None, - help='Athena database. Required if target source is set to `athena`.', + help='Athena database (engine=athena). Defaults to `ccindex`.', ) parser.add_argument( '--athena-s3-output', type=str, default=None, - help='Athena S3 output location. Required if target source is set to `athena`.', + help='Athena S3 output location (engine=athena). Required for engine=athena.', ) parser.add_argument( - '--confirm-athena-cost', + '--duckdb-index-path', + type=str, + default='s3://commoncrawl/cc-index/table/cc-main/warc/', + help='Base S3 path to the CC columnar index parquet (engine=duckdb).', + ) + parser.add_argument( + '--confirm-cost', action='store_true', - help=('Skip the Athena cost-confirmation prompt and run even unpartitioned / large-scan ' + help=('Skip the cost-confirmation prompt and run even unpartitioned / large-scan SQL ' 'queries. Athena bills per TB scanned; restrict with --crawl to reduce cost.'), ) + # --- CSV source --- + parser.add_argument( + '--csv-path', + type=str, + default=None, + help='Path to a range-jobs CSV/TSV (local or remote). Used when --target-source csv.', + ) + # --- Range-jobs materialization (any source) --- + parser.add_argument( + '--range-jobs-output', + type=str, + default=None, + help='If set, write each generated RangeJob to this CSV (filename,offset,length by default).', + ) + parser.add_argument( + '--no-fetch', + action='store_true', + help='Only generate range jobs (write --range-jobs-output); skip fetching/writing WARCs.', + ) + parser.add_argument( + '--csv-self-contained', + action='store_true', + help='Write full URLs (url,offset,length) to --range-jobs-output instead of relative filenames.', + ) parser.add_argument('--prefix', default='TEST', help='prefix for the output warc filename') parser.add_argument( '--subprefix', @@ -126,8 +166,9 @@ def add_warcer_by_cdx_args(parser: argparse.ArgumentParser): '--target-source', action='store', default='cdx', - help=('Source from that the filter targets are loaded (available options: `cdx`, `athena`; ' - 'defaults to `cdx`). For `athena`, use the global --crawl to restrict the scan to ' - 'specific crawls (strongly recommended; Athena bills per TB scanned).'), + choices=['cdx', 'sql', 'csv'], + help=('Where range jobs come from: `cdx` (index files), `sql` (columnar index via ' + '--engine athena|duckdb), or `csv` (a range-jobs CSV). Defaults to `cdx`. For `sql`, ' + 'use the global --crawl to restrict the scan to specific crawls (recommended for cost).'), ) return parser diff --git a/cdx_toolkit/filter_warc/athena_job_generator.py b/cdx_toolkit/filter_warc/athena_job_generator.py deleted file mode 100644 index c4c4323..0000000 --- a/cdx_toolkit/filter_warc/athena_job_generator.py +++ /dev/null @@ -1,251 +0,0 @@ -import asyncio -import logging -import re -import time -from typing import Any, Iterable, List, Optional - -from cdx_toolkit.filter_warc.data_classes import RangeJob - - -logger = logging.getLogger(__name__) - -# Required output columns that any Athena query (built or raw) must provide. -REQUIRED_RESULT_COLUMNS = ('warc_filename', 'warc_record_offset', 'warc_record_length') - -# Athena pricing is ~$5 per TB scanned (used for the post-run cost estimate). -ATHENA_USD_PER_TB = 5.0 - -# Hostnames, TLDs and crawl names only ever contain these characters. Validating -# against this set both prevents SQL injection and catches malformed input early. -_SQL_LITERAL_RE = re.compile(r'^[A-Za-z0-9.\-]+$') - - -def escape_sql_literal(value: str) -> str: - """Validate and quote a value for safe inclusion in an Athena SQL string literal. - - Only hostname/TLD/crawl-name characters (letters, digits, dot, hyphen) are - allowed, so the result cannot break out of the quotes or inject SQL.""" - if not isinstance(value, str) or not _SQL_LITERAL_RE.match(value): - raise ValueError( - f'Invalid value for Athena query literal: {value!r} ' - '(allowed characters: letters, digits, dot, hyphen)' - ) - return "'" + value + "'" - - -def build_athena_query( - url_host_names: List[str], - crawls: Optional[List[str]] = None, - limit: int = 0, - table: str = 'ccindex', -) -> str: - """Build the Athena SQL returning warc_filename/offset/length for the hostnames. - - CommonCrawl provides an index via AWS Athena that we can use to find the file - names, offsets, and byte lengths for WARC filtering. See - https://commoncrawl.org/blog/index-to-warc-files-and-urls-in-columnar-format - - If `crawls` is a non-empty list of crawl names (e.g. ['CC-MAIN-2025-33']), a - `crawl IN (...)` partition filter is added -- this is the main lever for - reducing Athena scan cost.""" - if not url_host_names: - raise ValueError('build_athena_query requires at least one hostname') - - tlds = sorted({url.split('.')[-1] for url in url_host_names}) - query_tlds = ' OR '.join(f'url_host_tld = {escape_sql_literal(tld)}' for tld in tlds) - query_hostnames = ' OR '.join(f'url_host_name = {escape_sql_literal(h)}' for h in url_host_names) - - where_clauses = [ - "subset = 'warc'", - f'({query_tlds}) -- help the query optimizer', - f'({query_hostnames})', - ] - # TODO wire --from/--to into a fetch_time BETWEEN ... clause here - if crawls: - crawl_in = ', '.join(escape_sql_literal(c) for c in crawls) - where_clauses.append(f'crawl IN ({crawl_in})') - - where_sql = '\n AND '.join(where_clauses) - query_limit = f'\n LIMIT {limit}' if limit > 0 else '' - - return f""" - SELECT - warc_filename, warc_record_offset, warc_record_length - FROM {table} - WHERE {where_sql}{query_limit}""" - - -def validate_result_columns(column_names) -> None: - """Raise a clear error if the query result lacks the required columns.""" - missing = [c for c in REQUIRED_RESULT_COLUMNS if c not in (column_names or [])] - if missing: - raise ValueError( - 'Athena query result is missing required columns: ' + ', '.join(missing) + - '. The query (including a raw --athena-query) must SELECT ' + - ', '.join(REQUIRED_RESULT_COLUMNS) + '.' - ) - - -def join_warc_url(prefix: Optional[str], warc_filename: str) -> str: - """Join a download prefix with a warc_filename robustly. - - - if warc_filename is already absolute (contains '://'), return it unchanged - (supports custom queries whose warc_filename is a full s3://-/https:// URL); - - if prefix is empty/None, return warc_filename unchanged; - - otherwise join with exactly one '/' (no double slash, no missing slash).""" - if '://' in warc_filename: - return warc_filename - if not prefix: - return warc_filename - return prefix.rstrip('/') + '/' + warc_filename.lstrip('/') - - -def run_athena_query(client, query: str, database: str, s3_output_location: str, max_wait_time: int = 300) -> str: - """Start an Athena query and block until it completes; return the execution id. - - Raises if the query does not reach the SUCCEEDED state. If the wait is - interrupted (Ctrl-C / cancellation) or times out, the query is cancelled - server-side so we don't keep paying for a scan whose results we'll never read.""" - logger.info('Executing Athena query: %s', query) - - response = client.start_query_execution( - QueryString=query, - QueryExecutionContext={'Database': database}, - ResultConfiguration={'OutputLocation': s3_output_location}, - ) - - query_execution_id = response['QueryExecutionId'] - logger.info('Query execution started. ID: %s', query_execution_id) - - try: - status = _wait_for_query_completion(client, query_execution_id, max_wait_time) - except BaseException: - # Ctrl-C, asyncio cancellation, or timeout: stop the query server-side - # to bound the Athena scan cost, then propagate the original exception. - _stop_query(client, query_execution_id) - raise - - if status != 'SUCCEEDED': - raise Exception(f'Query failed with status: {status}') - - return query_execution_id - - -def _stop_query(client, query_execution_id: str) -> None: - """Best-effort cancellation of a running Athena query.""" - logger.warning('Cancelling Athena query %s ...', query_execution_id) - try: - client.stop_query_execution(QueryExecutionId=query_execution_id) - logger.warning('Athena query %s cancelled', query_execution_id) - except Exception as e: # pragma: no cover - best-effort cleanup - logger.warning('Failed to cancel Athena query %s: %r', query_execution_id, e) - - -def report_query_cost(client, query_execution_id: str) -> None: - """Log the bytes scanned and an estimated USD cost for a completed query.""" - try: - response = client.get_query_execution(QueryExecutionId=query_execution_id) - scanned = response['QueryExecution'].get('Statistics', {}).get('DataScannedInBytes') - except Exception as e: # pragma: no cover - best-effort reporting - logger.debug('unable to read Athena query statistics: %r', e) - return - - if scanned is None: - return - - gb = scanned / 1e9 - usd = scanned / 1e12 * ATHENA_USD_PER_TB - logger.info('Athena scanned %.2f GB, estimated cost ~$%.4f', gb, usd) - - -async def get_range_jobs_from_athena( - client, - query: str, - database: str, - s3_output_location: str, - job_queue: asyncio.Queue, - queue_stop_object: Any, - warc_download_prefix: str, - num_fetchers: int, - max_wait_time: int = 300, -) -> int: - """Execute a prepared Athena query and enqueue a RangeJob per result row. - - The query string is built and validated by the caller (see build_athena_query - and cdx_toolkit.filter_warc.command). This function only executes it, maps the - results to RangeJob objects, pushes them to the asyncio queue, signals the - fetchers to stop, and logs the scan cost.""" - count = 0 - - query_execution_id = run_athena_query(client, query, database, s3_output_location, max_wait_time) - - for range_job in iter_range_jobs(client, query_execution_id, warc_download_prefix): - await job_queue.put(range_job) - count += 1 - - report_query_cost(client, query_execution_id) - - # Signal fetchers to stop - for _ in range(num_fetchers): - await job_queue.put(queue_stop_object) - - logger.info('Athena query enqueued %d jobs', count) - - return count - - -def _wait_for_query_completion(client, query_execution_id: str, max_wait_time: int) -> str: - """Wait for query to complete and return final status""" - start_time = time.time() - - while time.time() - start_time < max_wait_time: - response = client.get_query_execution(QueryExecutionId=query_execution_id) - - status = response['QueryExecution']['Status']['State'] - logger.info(f'Query status: {status}') - - if status in ['SUCCEEDED', 'FAILED', 'CANCELLED']: - if status == 'FAILED': - error_reason = response['QueryExecution']['Status'].get('StateChangeReason', 'Unknown error') - logger.info(f'Query failed: {error_reason}') - return status - - time.sleep(2) - - raise TimeoutError(f'Query did not complete within {max_wait_time} seconds') - - -def iter_range_jobs(client, query_execution_id: str, warc_download_prefix: str) -> Iterable[RangeJob]: - """Retrieve query results and convert each row to a RangeJob""" - # Get query results - paginator = client.get_paginator('get_query_results') - page_iterator = paginator.paginate(QueryExecutionId=query_execution_id) - column_names = None - - for page in page_iterator: - rows = page['ResultSet']['Rows'] - - # Get column names from first page - if column_names is None and rows: - column_names = [col['VarCharValue'] for col in rows[0]['Data']] - validate_result_columns(column_names) - rows = rows[1:] # Skip header row - - # Process data rows - for row in rows: - row_data = [] - for cell in row['Data']: - value = cell.get('VarCharValue', None) - row_data.append(value) - - row = dict(zip(column_names, row_data)) - - warc_url = join_warc_url(warc_download_prefix, row['warc_filename']) - - yield RangeJob(url=warc_url, offset=int(row['warc_record_offset']), length=int(row['warc_record_length'])) - - -def get_databases(client) -> list: - """Get list of available databases""" - response = client.list_databases(CatalogName='AwsDataCatalog') - return [db['Name'] for db in response['DatabaseList']] diff --git a/cdx_toolkit/filter_warc/cdx_utils.py b/cdx_toolkit/filter_warc/cdx_utils.py index 6b0c584..45f32a2 100644 --- a/cdx_toolkit/filter_warc/cdx_utils.py +++ b/cdx_toolkit/filter_warc/cdx_utils.py @@ -27,7 +27,7 @@ def get_index_as_string_from_path( return f.read() -def read_cdx_line(line: str, warc_download_prefix: str) -> Tuple[str, int, int]: +def read_cdx_line(line: str, warc_download_prefix: str) -> Tuple[str, int, int, str]: cols = line.split(' ', maxsplit=2) if len(cols) == 3: @@ -49,10 +49,10 @@ def read_cdx_line(line: str, warc_download_prefix: str) -> Tuple[str, int, int]: warc_url = warc_download_prefix + '/' + filename - return (warc_url, offset, length) + return (warc_url, offset, length, filename) -def iter_cdx_index_from_path(index_path: str, warc_download_prefix: str) -> Iterable[Tuple[str, int, int]]: +def iter_cdx_index_from_path(index_path: str, warc_download_prefix: str) -> Iterable[Tuple[str, int, int, str]]: """ Iterate CDX records from a file path (gzipped; local or remote). """ diff --git a/cdx_toolkit/filter_warc/command.py b/cdx_toolkit/filter_warc/command.py index 2328f64..8a2d187 100644 --- a/cdx_toolkit/filter_warc/command.py +++ b/cdx_toolkit/filter_warc/command.py @@ -1,7 +1,5 @@ -from cdx_toolkit.filter_warc.cdx_utils import get_cdx_paths from cdx_toolkit.filter_warc.warc_filter import WARCFilter -from cdx_toolkit.filter_warc.athena_job_generator import build_athena_query -from cdx_toolkit.commoncrawl import normalize_crawl, get_cc_endpoints, match_cc_crawls +from cdx_toolkit.filter_warc.sources import make_source from cdx_toolkit.utils import get_version @@ -14,116 +12,54 @@ logger = logging.getLogger(__name__) -# A built Athena query restricted to at most this many crawls is considered cheap -# enough to run without a cost-confirmation prompt. +# A SQL query restricted to at most this many crawls is considered cheap enough to +# run without a cost-confirmation prompt. LARGE_CRAWL_SET_THRESHOLD = 10 -# Default Common Crawl index mirror used to resolve --crawl to concrete crawl names. -CC_INDEX_MIRROR = 'https://index.commoncrawl.org/' +def confirm_cost(estimate, confirmed) -> None: + """Prompt before running a potentially expensive index scan. -def _endpoint_to_crawl_name(endpoint: str) -> str: - """Turn a collinfo cdx-api endpoint into its crawl name. - - e.g. 'https://index.commoncrawl.org/CC-MAIN-2025-33-index' -> 'CC-MAIN-2025-33'.""" - name = endpoint.rstrip('/').split('/')[-1] - if name.endswith('-index'): - name = name[:-len('-index')] - return name - - -def _resolve_crawl_names(crawl_arg) -> list: - """Resolve a --crawl value to concrete CC-MAIN crawl names for the partition filter. - - Reuses the CDX path's helpers so `--crawl` accepts the same forms (comma-separated - names or an integer for the most recent N crawls).""" - crawls = normalize_crawl([crawl_arg]) - raw_index_list = get_cc_endpoints(CC_INDEX_MIRROR) - matched = match_cc_crawls(crawls, raw_index_list) - return [_endpoint_to_crawl_name(ep) for ep in matched] - - -def resolve_athena_query(args): - """Validate the Athena args and return (sql, n_crawls). - - n_crawls is the number of crawls the query is restricted to, or None when the - query is unrestricted (scans all crawls) or its pruning cannot be verified - (raw --athena-query/--athena-query-file).""" - raw_sql = args.athena_query - if args.athena_query_file: - if raw_sql: - raise ValueError('--athena-query and --athena-query-file are mutually exclusive') - with open(args.athena_query_file) as f: - raw_sql = f.read() - - if raw_sql and args.athena_hostnames: - raise ValueError('--athena-query/--athena-query-file are mutually exclusive with --athena-hostnames') - if not raw_sql and not args.athena_hostnames: - raise ValueError('athena target requires either --athena-hostnames or --athena-query/--athena-query-file') - - if not args.athena_database: - raise ValueError('--athena-database is required for target source `athena`') - if not args.athena_s3_output: - raise ValueError('--athena-s3-output is required for target source `athena`') - - if raw_sql: - # Crawl-partition pruning of a raw query cannot be verified -> treat as unbounded. - return raw_sql, None - - # Guided/built path - limit = 0 if args.limit is None else args.limit - if args.crawl: - crawl_names = _resolve_crawl_names(args.crawl) - sql = build_athena_query(args.athena_hostnames, crawls=crawl_names, limit=limit) - return sql, len(crawl_names) - - sql = build_athena_query(args.athena_hostnames, limit=limit) - return sql, None - - -def confirm_athena_cost(n_crawls, confirmed) -> None: - """Prompt before running a potentially expensive Athena query. - - A built query restricted to <= LARGE_CRAWL_SET_THRESHOLD crawls runs without a - prompt. Otherwise (no crawl filter, a large crawl set, or unverifiable raw SQL) - we confirm interactively, abort in non-interactive sessions, unless `confirmed` - (--confirm-athena-cost) is set.""" - if confirmed: + `estimate` is a CostEstimate (or None for sources that never bill, e.g. cdx/csv). + A scan bounded to <= LARGE_CRAWL_SET_THRESHOLD crawls runs without a prompt. + Otherwise (no crawl filter, a large crawl set, or unverifiable raw SQL) we confirm + interactively, abort in non-interactive sessions, unless `confirmed` (--confirm-cost).""" + if confirmed or estimate is None: return + + n_crawls = estimate.n_crawls if n_crawls is not None and n_crawls <= LARGE_CRAWL_SET_THRESHOLD: return + engine = estimate.engine if n_crawls is None: - reason = ('This Athena query is not restricted to specific crawls (or uses custom SQL whose ' - 'crawl partition pruning could not be verified) and may scan ALL crawls.') + reason = (f'This {engine} query is not restricted to specific crawls (or uses custom SQL ' + 'whose crawl partition pruning could not be verified) and may scan ALL crawls.') else: - reason = (f'This Athena query is restricted to {n_crawls} crawls (more than ' + reason = (f'This {engine} query is restricted to {n_crawls} crawls (more than ' f'{LARGE_CRAWL_SET_THRESHOLD}) and may scan a large amount of data.') if not sys.stdin.isatty(): raise SystemExit( - reason + ' Refusing to run a potentially expensive Athena scan in non-interactive mode. ' - 'Restrict with --crawl (<=10 crawls), or pass --confirm-athena-cost.' + reason + ' Refusing to run a potentially expensive scan in non-interactive mode. ' + 'Restrict with --crawl (<=10 crawls), or pass --confirm-cost.' ) logger.warning(reason) - answer = input('Athena bills per TB scanned. Proceed? [y/N] ') + answer = input('Index SQL scans can be expensive (Athena bills per TB scanned). Proceed? [y/N] ') if answer.strip().lower() not in ('y', 'yes'): raise SystemExit('Aborted by user.') -def run_warcer_by_cdx(args, cmdline): - """Like warcer but fetches WARC records based on one or more CDX index files. - - The CDX files can be filtered using the `filter_cdx` commands based a given URL/SURT list. +def run_repackage(args, cmdline): + """Repackage WARC records from a pluggable range-job source (cdx / sql / csv). Approach: - - Iterate over one or more CDX files to extract capture object (file, offset, length) - - Fetch WARC record based on capture object - - Write to new WARC file including metadata records with index. - - The CDX metadata record is written to the WARC directly before for response records that matches to the CDX. + - Generate RangeJobs (WARC file + byte range) from the selected source. + - Optionally materialize them to a CSV (--range-jobs-output; --no-fetch to skip fetching). + - Fetch each WARC record and write a new WARC, including metadata records. """ - logger.info('Filtering WARC files based on CDX') + logger.info('Repackaging WARC files (target source: %s)', args.target_source) # Start timing start_time = time.time() @@ -142,7 +78,7 @@ def run_warcer_by_cdx(args, cmdline): 'isPartOf': ispartof, 'description': args.description if args.description - else 'warc extraction based on CDX generated with: ' + cmdline, + else 'warc extraction generated with: ' + cmdline, 'format': 'WARC file version 1.0', } if args.creator: @@ -154,33 +90,22 @@ def run_warcer_by_cdx(args, cmdline): log_every_n = args.log_every_n limit = 0 if args.limit is None else args.limit prefix_path = str(args.prefix) - prefix_fs, prefix_fs_path = fsspec.url_to_fs(prefix_path) - # make sure the base dir exists - prefix_fs.makedirs(prefix_fs._parent(prefix_fs_path), exist_ok=True) + # Build the source (validates source/engine/query options) and confirm scan cost + # up front (synchronously) before launching the pipeline. + source = make_source(args, warc_download_prefix=args.warc_download_prefix, record_limit=limit) + confirm_cost(source.estimate_cost(), args.confirm_cost) - # target source handling - athena_query = None - if args.target_source == 'cdx': - cdx_paths = get_cdx_paths( - args.cdx_path, - args.cdx_glob, - ) - elif args.target_source == "athena": - cdx_paths = None - # Build/validate the Athena query up front (synchronously) so we can warn - # about expensive (unpartitioned / large) scans before launching the pipeline. - athena_query, n_crawls = resolve_athena_query(args) - confirm_athena_cost(n_crawls, args.confirm_athena_cost) - else: - raise ValueError(f'Invalid target source specified: {args.target_source} (available: cdx, athena)') + # make sure the output base dir exists (only needed when actually writing WARCs) + if not args.no_fetch: + prefix_fs, prefix_fs_path = fsspec.url_to_fs(prefix_path) + prefix_fs.makedirs(prefix_fs._parent(prefix_fs_path), exist_ok=True) warc_filter = WARCFilter( - target_source=args.target_source, - cdx_paths=cdx_paths, - athena_database=args.athena_database, - athena_s3_output_location=args.athena_s3_output, - athena_query=athena_query, + source=source, + range_jobs_output=args.range_jobs_output, + no_fetch=args.no_fetch, + csv_self_contained=args.csv_self_contained, prefix_path=prefix_path, writer_info=info, writer_subprefix=args.subprefix, @@ -190,7 +115,6 @@ def run_warcer_by_cdx(args, cmdline): warc_download_prefix=args.warc_download_prefix, n_parallel=n_parallel, max_file_size=args.size, - # writer_kwargs=writer_kwargs, ) records_n = warc_filter.filter() diff --git a/cdx_toolkit/filter_warc/data_classes.py b/cdx_toolkit/filter_warc/data_classes.py index 8f36325..40f190e 100644 --- a/cdx_toolkit/filter_warc/data_classes.py +++ b/cdx_toolkit/filter_warc/data_classes.py @@ -2,7 +2,7 @@ from dataclasses import dataclass from cdx_toolkit.filter_warc.s3_utils import is_s3_url, parse_s3_uri, with_retries -from typing import Tuple +from typing import Optional, Tuple from cdx_toolkit.myrequests import myrequests_get @@ -47,6 +47,10 @@ class RangeJob: offset: int length: int records_count: int = 1 + # Relative WARC filename (e.g. crawl-data/...warc.gz) as known by the source, + # used when materializing a non-self-contained range-jobs CSV. `url` stays + # authoritative for fetching. + filename: Optional[str] = None def is_s3(self): return is_s3_url(self.url) diff --git a/cdx_toolkit/filter_warc/sources/__init__.py b/cdx_toolkit/filter_warc/sources/__init__.py new file mode 100644 index 0000000..d1208cf --- /dev/null +++ b/cdx_toolkit/filter_warc/sources/__init__.py @@ -0,0 +1,4 @@ +from cdx_toolkit.filter_warc.sources.base import RangeJobSource, CostEstimate +from cdx_toolkit.filter_warc.sources.factory import make_source + +__all__ = ['RangeJobSource', 'CostEstimate', 'make_source'] diff --git a/cdx_toolkit/filter_warc/sources/athena.py b/cdx_toolkit/filter_warc/sources/athena.py new file mode 100644 index 0000000..203048d --- /dev/null +++ b/cdx_toolkit/filter_warc/sources/athena.py @@ -0,0 +1,177 @@ +import logging +import time +from typing import Iterator, Iterable, Optional + +from cdx_toolkit.filter_warc.data_classes import RangeJob +from cdx_toolkit.filter_warc.sources.base import RangeJobSource, CostEstimate +from cdx_toolkit.filter_warc.sources.sql_base import validate_result_columns, join_warc_url + + +logger = logging.getLogger(__name__) + +# Athena pricing is ~$5 per TB scanned (used for the post-run cost estimate). +ATHENA_USD_PER_TB = 5.0 + + +class AthenaSource(RangeJobSource): + """RangeJobs from a query against the CC columnar index via AWS Athena.""" + + def __init__( + self, + *, + query: str, + database: str, + s3_output_location: str, + warc_download_prefix: Optional[str], + n_crawls: Optional[int] = None, + region_name: str = 'us-east-1', + max_wait_time: int = 300, + ): + self.query = query + self.database = database + self.s3_output_location = s3_output_location + self.warc_download_prefix = warc_download_prefix + self._n_crawls = n_crawls + self.region_name = region_name + self.max_wait_time = max_wait_time + + def estimate_cost(self) -> CostEstimate: + return CostEstimate(n_crawls=self._n_crawls, engine='athena') + + def _make_client(self): + import boto3 + from botocore.config import Config + + config = Config( + region_name=self.region_name, + read_timeout=60, + retries={'max_attempts': 3, 'mode': 'adaptive'}, + ) + return boto3.client('athena', config=config) + + def iter_range_jobs(self) -> Iterator[RangeJob]: + client = self._make_client() + query_execution_id = run_athena_query( + client, self.query, self.database, self.s3_output_location, self.max_wait_time + ) + try: + for job in iter_range_jobs(client, query_execution_id, self.warc_download_prefix): + yield job + finally: + report_query_cost(client, query_execution_id) + + +def run_athena_query(client, query: str, database: str, s3_output_location: str, max_wait_time: int = 300) -> str: + """Start an Athena query and block until it completes; return the execution id. + + Raises if the query does not reach the SUCCEEDED state. If the wait is + interrupted (Ctrl-C / cancellation) or times out, the query is cancelled + server-side so we don't keep paying for a scan whose results we'll never read.""" + logger.info('Executing Athena query: %s', query) + + response = client.start_query_execution( + QueryString=query, + QueryExecutionContext={'Database': database}, + ResultConfiguration={'OutputLocation': s3_output_location}, + ) + + query_execution_id = response['QueryExecutionId'] + logger.info('Query execution started. ID: %s', query_execution_id) + + try: + status = _wait_for_query_completion(client, query_execution_id, max_wait_time) + except BaseException: + # Ctrl-C, asyncio cancellation, or timeout: stop the query server-side + # to bound the Athena scan cost, then propagate the original exception. + _stop_query(client, query_execution_id) + raise + + if status != 'SUCCEEDED': + raise Exception(f'Query failed with status: {status}') + + return query_execution_id + + +def _stop_query(client, query_execution_id: str) -> None: + """Best-effort cancellation of a running Athena query.""" + logger.warning('Cancelling Athena query %s ...', query_execution_id) + try: + client.stop_query_execution(QueryExecutionId=query_execution_id) + logger.warning('Athena query %s cancelled', query_execution_id) + except Exception as e: # pragma: no cover - best-effort cleanup + logger.warning('Failed to cancel Athena query %s: %r', query_execution_id, e) + + +def report_query_cost(client, query_execution_id: str) -> None: + """Log the bytes scanned and an estimated USD cost for a completed query.""" + try: + response = client.get_query_execution(QueryExecutionId=query_execution_id) + scanned = response['QueryExecution'].get('Statistics', {}).get('DataScannedInBytes') + except Exception as e: # pragma: no cover - best-effort reporting + logger.debug('unable to read Athena query statistics: %r', e) + return + + if scanned is None: + return + + gb = scanned / 1e9 + usd = scanned / 1e12 * ATHENA_USD_PER_TB + logger.info('Athena scanned %.2f GB, estimated cost ~$%.4f', gb, usd) + + +def _wait_for_query_completion(client, query_execution_id: str, max_wait_time: int) -> str: + """Wait for query to complete and return final status""" + start_time = time.time() + + while time.time() - start_time < max_wait_time: + response = client.get_query_execution(QueryExecutionId=query_execution_id) + + status = response['QueryExecution']['Status']['State'] + logger.info(f'Query status: {status}') + + if status in ['SUCCEEDED', 'FAILED', 'CANCELLED']: + if status == 'FAILED': + error_reason = response['QueryExecution']['Status'].get('StateChangeReason', 'Unknown error') + logger.info(f'Query failed: {error_reason}') + return status + + time.sleep(2) + + raise TimeoutError(f'Query did not complete within {max_wait_time} seconds') + + +def iter_range_jobs(client, query_execution_id: str, warc_download_prefix: Optional[str]) -> Iterable[RangeJob]: + """Retrieve query results and convert each row to a RangeJob""" + paginator = client.get_paginator('get_query_results') + page_iterator = paginator.paginate(QueryExecutionId=query_execution_id) + column_names = None + + for page in page_iterator: + rows = page['ResultSet']['Rows'] + + # Get column names from first page + if column_names is None and rows: + column_names = [col['VarCharValue'] for col in rows[0]['Data']] + validate_result_columns(column_names) + rows = rows[1:] # Skip header row + + # Process data rows + for row in rows: + row_data = [cell.get('VarCharValue', None) for cell in row['Data']] + row = dict(zip(column_names, row_data)) + + warc_filename = row['warc_filename'] + warc_url = join_warc_url(warc_download_prefix, warc_filename) + + yield RangeJob( + url=warc_url, + offset=int(row['warc_record_offset']), + length=int(row['warc_record_length']), + filename=warc_filename, + ) + + +def get_databases(client) -> list: + """Get list of available databases""" + response = client.list_databases(CatalogName='AwsDataCatalog') + return [db['Name'] for db in response['DatabaseList']] diff --git a/cdx_toolkit/filter_warc/sources/base.py b/cdx_toolkit/filter_warc/sources/base.py new file mode 100644 index 0000000..f833cce --- /dev/null +++ b/cdx_toolkit/filter_warc/sources/base.py @@ -0,0 +1,35 @@ +from abc import ABC, abstractmethod +from typing import Iterator, NamedTuple, Optional + +from cdx_toolkit.filter_warc.data_classes import RangeJob + + +class CostEstimate(NamedTuple): + """Describes the scan a source is about to run, for the cost-confirmation guard. + + n_crawls is the number of crawls the scan is bounded to, or None when the scan + is unbounded / its pruning cannot be verified (e.g. raw SQL, all-crawls glob).""" + + n_crawls: Optional[int] + engine: str + + +class RangeJobSource(ABC): + """A source of RangeJobs for the repackage pipeline. + + A source owns its own stage-1 resource (Athena client, DuckDB connection, or an + fsspec file handle) and yields RangeJobs synchronously; the pipeline orchestrator + bridges the sync generator into the async fetch/write stages, and owns queueing, + the record limit, counting, and stop-sentinel emission.""" + + def estimate_cost(self) -> Optional[CostEstimate]: + """Return a CostEstimate for the cost guard, or None for sources that never + incur a per-scan charge (cdx files, csv).""" + return None + + @abstractmethod + def iter_range_jobs(self) -> Iterator[RangeJob]: + """Yield RangeJobs. Implementations open their own client/connection lazily + and close it in a finally. Each RangeJob carries `url` (authoritative for + fetching) and, where known, the relative `filename`.""" + raise NotImplementedError diff --git a/cdx_toolkit/filter_warc/sources/cdx.py b/cdx_toolkit/filter_warc/sources/cdx.py new file mode 100644 index 0000000..bb9356d --- /dev/null +++ b/cdx_toolkit/filter_warc/sources/cdx.py @@ -0,0 +1,28 @@ +import logging +from typing import Iterator, List + +from cdx_toolkit.filter_warc.cdx_utils import iter_cdx_index_from_path +from cdx_toolkit.filter_warc.data_classes import RangeJob +from cdx_toolkit.filter_warc.sources.base import RangeJobSource + + +logger = logging.getLogger(__name__) + + +class CdxSource(RangeJobSource): + """RangeJobs read from one or more CDX index files (local or remote via fsspec).""" + + def __init__(self, cdx_paths: List[str], warc_download_prefix: str): + self.cdx_paths = cdx_paths + self.warc_download_prefix = warc_download_prefix + + def iter_range_jobs(self) -> Iterator[RangeJob]: + for index_path in self.cdx_paths: + try: + for warc_url, offset, length, filename in iter_cdx_index_from_path( + index_path, self.warc_download_prefix + ): + yield RangeJob(url=warc_url, offset=offset, length=length, filename=filename) + except Exception as e: + # Preserve the previous behaviour of skipping a bad index file. + logger.error('Failed to read CDX index from %s: %s', index_path, e) diff --git a/cdx_toolkit/filter_warc/sources/csv.py b/cdx_toolkit/filter_warc/sources/csv.py new file mode 100644 index 0000000..f9ac5aa --- /dev/null +++ b/cdx_toolkit/filter_warc/sources/csv.py @@ -0,0 +1,90 @@ +import csv +import logging +from typing import Iterator, Optional + +import fsspec + +from cdx_toolkit.filter_warc.data_classes import RangeJob +from cdx_toolkit.filter_warc.sources.base import RangeJobSource +from cdx_toolkit.filter_warc.sources.sql_base import join_warc_url + + +logger = logging.getLogger(__name__) + +FILENAME_FIELDS = ['filename', 'offset', 'length'] +URL_FIELDS = ['url', 'offset', 'length'] + + +class RangeJobCsvWriter: + """Write RangeJobs to a CSV (local or remote via fsspec). + + Default mode writes the relative `filename` column (the consumer prepends the + WARC download prefix); `self_contained` mode writes the full `url` column.""" + + def __init__(self, path: str, self_contained: bool = False): + self.path = path + self.self_contained = self_contained + self._fields = URL_FIELDS if self_contained else FILENAME_FIELDS + self._ctx = fsspec.open(path, 'wt', newline='') + self._fh = self._ctx.__enter__() + self._writer = csv.DictWriter(self._fh, fieldnames=self._fields) + self._writer.writeheader() + + def write(self, job: RangeJob) -> None: + if self.self_contained: + row = {'url': job.url, 'offset': job.offset, 'length': job.length} + else: + if job.filename is None: + raise ValueError( + 'cannot write a non-self-contained range-jobs CSV: RangeJob.filename is ' + 'missing; pass --csv-self-contained to write full URLs instead' + ) + row = {'filename': job.filename, 'offset': job.offset, 'length': job.length} + self._writer.writerow(row) + + def close(self) -> None: + if self._ctx is not None: + self._ctx.__exit__(None, None, None) + self._ctx = None + self._fh = None + + +class CsvSource(RangeJobSource): + """RangeJobs read from a CSV/TSV. + + Mode is auto-detected from the header: a `url` column => self-contained (used + as-is); a `filename` column => the WARC download prefix is prepended. TSV is + detected from a `.tsv`/`.tsv.gz` extension; `.gz` inputs are decompressed.""" + + def __init__(self, path: str, warc_download_prefix: Optional[str]): + self.path = path + self.warc_download_prefix = warc_download_prefix + + def iter_range_jobs(self) -> Iterator[RangeJob]: + path = str(self.path) + delimiter = '\t' if path.endswith(('.tsv', '.tsv.gz')) else ',' + compression = 'gzip' if path.endswith('.gz') else None + + with fsspec.open(self.path, 'rt', newline='', compression=compression) as fh: + reader = csv.DictReader(fh, delimiter=delimiter) + fields = set(reader.fieldnames or []) + if 'url' in fields: + mode_url = True + elif 'filename' in fields: + mode_url = False + else: + raise ValueError( + f'range-jobs CSV {self.path} must have a `url` or `filename` column ' + f'(got header: {reader.fieldnames})' + ) + + for row in reader: + offset = int(row['offset']) + length = int(row['length']) + if mode_url: + url = row['url'] + filename = None + else: + filename = row['filename'] + url = join_warc_url(self.warc_download_prefix, filename) + yield RangeJob(url=url, offset=offset, length=length, filename=filename) diff --git a/cdx_toolkit/filter_warc/sources/duckdb.py b/cdx_toolkit/filter_warc/sources/duckdb.py new file mode 100644 index 0000000..f109cd9 --- /dev/null +++ b/cdx_toolkit/filter_warc/sources/duckdb.py @@ -0,0 +1,102 @@ +import logging +from typing import Iterator, List, Optional + +from cdx_toolkit.filter_warc.data_classes import RangeJob +from cdx_toolkit.filter_warc.sources.base import RangeJobSource, CostEstimate +from cdx_toolkit.filter_warc.sources.sql_base import build_sql, validate_result_columns, join_warc_url + + +logger = logging.getLogger(__name__) + +try: + import duckdb + _HAS_DUCKDB = True +except ImportError: # pragma: no cover - exercised in minimal installs + duckdb = None + _HAS_DUCKDB = False + + +def _build_from_clause(index_path: str, crawls: Optional[List[str]]) -> str: + """Build a read_parquet(...) FROM expression over the CC columnar index. + + When crawls are given we glob only those crawl partitions (the pruning lever); + otherwise we glob every crawl (expensive -> the cost guard fires).""" + base = index_path.rstrip('/') + if crawls: + globs = [f"'{base}/crawl={c}/subset=warc/*.parquet'" for c in crawls] + else: + globs = [f"'{base}/crawl=*/subset=warc/*.parquet'"] + return f"read_parquet([{', '.join(globs)}], hive_partitioning=true)" + + +class DuckDbSource(RangeJobSource): + """RangeJobs from a query against the CC columnar index via DuckDB (read_parquet on S3). + + Reads the public CommonCrawl parquet directly; AWS region/credentials come from + the environment (the public bucket is readable with valid credentials).""" + + def __init__( + self, + *, + query: Optional[str] = None, + hostnames: Optional[List[str]] = None, + crawls: Optional[List[str]] = None, + index_path: str, + warc_download_prefix: Optional[str], + limit: int = 0, + region_name: str = 'us-east-1', + ): + self.raw_query = query + self.hostnames = hostnames + self.crawls = crawls + self.index_path = index_path + self.warc_download_prefix = warc_download_prefix + self.limit = limit + self.region_name = region_name + + def estimate_cost(self) -> CostEstimate: + if self.raw_query is not None: + return CostEstimate(n_crawls=None, engine='duckdb') + return CostEstimate(n_crawls=len(self.crawls) if self.crawls else None, engine='duckdb') + + def _build_query(self) -> str: + if self.raw_query is not None: + return self.raw_query + from_clause = _build_from_clause(self.index_path, self.crawls) + # crawl pruning is done in the FROM glob, so no crawl IN (...) in the WHERE + return build_sql(from_clause, self.hostnames, crawls=None, limit=self.limit) + + def iter_range_jobs(self) -> Iterator[RangeJob]: + if not _HAS_DUCKDB: + raise RuntimeError( + 'DuckDB engine requires optional dependencies. Install cdx_toolkit[duckdb].' + ) + + query = self._build_query() + logger.info('Executing DuckDB query: %s', query) + + con = duckdb.connect() + try: + con.execute('INSTALL httpfs; LOAD httpfs;') + con.execute(f"SET s3_region='{self.region_name}';") + cur = con.execute(query) + + col_names = [d[0] for d in cur.description] + validate_result_columns(col_names) + idx = {name: i for i, name in enumerate(col_names)} + + while True: + rows = cur.fetchmany(1000) + if not rows: + break + for row in rows: + warc_filename = row[idx['warc_filename']] + warc_url = join_warc_url(self.warc_download_prefix, warc_filename) + yield RangeJob( + url=warc_url, + offset=int(row[idx['warc_record_offset']]), + length=int(row[idx['warc_record_length']]), + filename=warc_filename, + ) + finally: + con.close() diff --git a/cdx_toolkit/filter_warc/sources/factory.py b/cdx_toolkit/filter_warc/sources/factory.py new file mode 100644 index 0000000..269ef3f --- /dev/null +++ b/cdx_toolkit/filter_warc/sources/factory.py @@ -0,0 +1,94 @@ +import logging +from typing import Optional, Tuple + +from cdx_toolkit.filter_warc.cdx_utils import get_cdx_paths +from cdx_toolkit.filter_warc.sources.base import RangeJobSource +from cdx_toolkit.filter_warc.sources.sql_base import build_athena_query, resolve_crawl_names + + +logger = logging.getLogger(__name__) + + +def make_source(args, *, warc_download_prefix: Optional[str], record_limit: int) -> RangeJobSource: + """Build the RangeJobSource selected by --target-source (+ --engine for sql). + + Centralises all source/engine validation (engine required iff sql; + hostnames/query/query-file mutual exclusivity; required connection options).""" + target = args.target_source + + if target == 'cdx': + from cdx_toolkit.filter_warc.sources.cdx import CdxSource + cdx_paths = get_cdx_paths(args.cdx_path, args.cdx_glob) + return CdxSource(cdx_paths, warc_download_prefix) + + if target == 'csv': + from cdx_toolkit.filter_warc.sources.csv import CsvSource + if not args.csv_path: + raise ValueError('--csv-path is required for --target-source csv') + return CsvSource(args.csv_path, warc_download_prefix) + + if target == 'sql': + return _make_sql_source(args, warc_download_prefix, record_limit) + + raise ValueError(f'Invalid target source: {target} (available: cdx, sql, csv)') + + +def _resolve_sql_query_spec(args) -> Tuple[Optional[str], Optional[list], Optional[list]]: + """Validate the query-defining flags and return (raw_sql, hostnames, crawls). + + Exactly one of {--hostnames} / {--query|--query-file} must be given. For the + guided (hostnames) path, --crawl is resolved to concrete crawl names.""" + raw_sql = args.query + if args.query_file: + if raw_sql: + raise ValueError('--query and --query-file are mutually exclusive') + with open(args.query_file) as f: + raw_sql = f.read() + + if raw_sql and args.hostnames: + raise ValueError('--query/--query-file are mutually exclusive with --hostnames') + if not raw_sql and not args.hostnames: + raise ValueError('the sql target requires either --hostnames or --query/--query-file') + + if raw_sql: + return raw_sql, None, None + + crawls = resolve_crawl_names(args.crawl) if args.crawl else None + return None, args.hostnames, crawls + + +def _make_sql_source(args, warc_download_prefix, record_limit) -> RangeJobSource: + engine = args.engine + if not engine: + raise ValueError('--engine is required for --target-source sql (choices: athena, duckdb)') + + raw_sql, hostnames, crawls = _resolve_sql_query_spec(args) + limit = 0 if record_limit is None else record_limit + + if engine == 'athena': + from cdx_toolkit.filter_warc.sources.athena import AthenaSource + if not args.athena_s3_output: + raise ValueError('--athena-s3-output is required for --engine athena') + database = args.athena_database or 'ccindex' + query = raw_sql if raw_sql else build_athena_query(hostnames, crawls=crawls, limit=limit) + n_crawls = None if raw_sql else (len(crawls) if crawls else None) + return AthenaSource( + query=query, + database=database, + s3_output_location=args.athena_s3_output, + warc_download_prefix=warc_download_prefix, + n_crawls=n_crawls, + ) + + if engine == 'duckdb': + from cdx_toolkit.filter_warc.sources.duckdb import DuckDbSource + return DuckDbSource( + query=raw_sql, + hostnames=hostnames, + crawls=crawls, + index_path=args.duckdb_index_path, + warc_download_prefix=warc_download_prefix, + limit=limit, + ) + + raise ValueError(f'Invalid --engine: {engine} (choices: athena, duckdb)') diff --git a/cdx_toolkit/filter_warc/sources/sql_base.py b/cdx_toolkit/filter_warc/sources/sql_base.py new file mode 100644 index 0000000..c6f527e --- /dev/null +++ b/cdx_toolkit/filter_warc/sources/sql_base.py @@ -0,0 +1,133 @@ +import logging +import re +from typing import List, Optional + +from cdx_toolkit.commoncrawl import normalize_crawl, get_cc_endpoints, match_cc_crawls + + +logger = logging.getLogger(__name__) + +# Required output columns that any index query (built or raw) must provide. +REQUIRED_RESULT_COLUMNS = ('warc_filename', 'warc_record_offset', 'warc_record_length') + +# Hostnames, TLDs and crawl names only ever contain these characters. Validating +# against this set both prevents SQL injection and catches malformed input early. +_SQL_LITERAL_RE = re.compile(r'^[A-Za-z0-9.\-]+$') + +# Default Common Crawl index mirror used to resolve --crawl to concrete crawl names. +CC_INDEX_MIRROR = 'https://index.commoncrawl.org/' + + +def escape_sql_literal(value: str) -> str: + """Validate and quote a value for safe inclusion in a SQL string literal. + + Only hostname/TLD/crawl-name characters (letters, digits, dot, hyphen) are + allowed, so the result cannot break out of the quotes or inject SQL.""" + if not isinstance(value, str) or not _SQL_LITERAL_RE.match(value): + raise ValueError( + f'Invalid value for SQL query literal: {value!r} ' + '(allowed characters: letters, digits, dot, hyphen)' + ) + return "'" + value + "'" + + +def build_where_sql(url_host_names: List[str], crawls: Optional[List[str]] = None) -> str: + """Build the WHERE body (without the `WHERE` keyword) shared by all SQL engines. + + If `crawls` is a non-empty list of crawl names (e.g. ['CC-MAIN-2025-33']), a + `crawl IN (...)` partition filter is added -- the main lever for reducing scan + cost. Engines differ only in their FROM clause (see build_sql).""" + if not url_host_names: + raise ValueError('an index query requires at least one hostname') + + tlds = sorted({h.split('.')[-1] for h in url_host_names}) + query_tlds = ' OR '.join(f'url_host_tld = {escape_sql_literal(t)}' for t in tlds) + query_hosts = ' OR '.join(f'url_host_name = {escape_sql_literal(h)}' for h in url_host_names) + + clauses = [ + "subset = 'warc'", + f'({query_tlds}) -- help the query optimizer', + f'({query_hosts})', + ] + # TODO wire --from/--to into a fetch_time BETWEEN ... clause here + if crawls: + crawl_in = ', '.join(escape_sql_literal(c) for c in crawls) + clauses.append(f'crawl IN ({crawl_in})') + + return '\n AND '.join(clauses) + + +def build_sql( + from_clause: str, + url_host_names: List[str], + crawls: Optional[List[str]] = None, + limit: int = 0, +) -> str: + """Assemble a full SELECT for the columnar index. + + `from_clause` is the text following FROM (e.g. `ccindex` for Athena, or a + `read_parquet(...)` expression for DuckDB).""" + where_sql = build_where_sql(url_host_names, crawls) + limit_sql = f'\n LIMIT {limit}' if limit and limit > 0 else '' + + return f""" + SELECT + warc_filename, warc_record_offset, warc_record_length + FROM {from_clause} + WHERE {where_sql}{limit_sql}""" + + +def build_athena_query( + url_host_names: List[str], + crawls: Optional[List[str]] = None, + limit: int = 0, + table: str = 'ccindex', +) -> str: + """Athena flavour of build_sql (FROM ). Kept for the athena_job_generator shim.""" + return build_sql(table, url_host_names, crawls=crawls, limit=limit) + + +def validate_result_columns(column_names) -> None: + """Raise a clear error if the query result lacks the required columns.""" + missing = [c for c in REQUIRED_RESULT_COLUMNS if c not in (column_names or [])] + if missing: + raise ValueError( + 'Index query result is missing required columns: ' + ', '.join(missing) + + '. The query (including a raw --query) must SELECT ' + + ', '.join(REQUIRED_RESULT_COLUMNS) + '.' + ) + + +def join_warc_url(prefix: Optional[str], warc_filename: str) -> str: + """Join a download prefix with a warc_filename robustly. + + - if warc_filename is already absolute (contains '://'), return it unchanged + (supports custom queries / self-contained CSVs whose value is a full URL); + - if prefix is empty/None, return warc_filename unchanged; + - otherwise join with exactly one '/' (no double slash, no missing slash).""" + if '://' in warc_filename: + return warc_filename + if not prefix: + return warc_filename + return prefix.rstrip('/') + '/' + warc_filename.lstrip('/') + + +def endpoint_to_crawl_name(endpoint: str) -> str: + """Turn a collinfo cdx-api endpoint into its crawl name. + + e.g. 'https://index.commoncrawl.org/CC-MAIN-2025-33-index' -> 'CC-MAIN-2025-33'.""" + name = endpoint.rstrip('/').split('/')[-1] + if name.endswith('-index'): + name = name[:-len('-index')] + return name + + +def resolve_crawl_names(crawl_arg) -> List[str]: + """Resolve a --crawl value to concrete CC-MAIN crawl names for the partition filter. + + Reuses the CDX path's helpers so `--crawl` accepts the same forms (comma-separated + names or an integer for the most recent N crawls).""" + crawls = normalize_crawl([crawl_arg]) + raw_index_list = get_cc_endpoints(CC_INDEX_MIRROR) + matched = match_cc_crawls(crawls, raw_index_list) + return [endpoint_to_crawl_name(ep) for ep in matched] diff --git a/cdx_toolkit/filter_warc/warc_filter.py b/cdx_toolkit/filter_warc/warc_filter.py index ed1e0b1..8399286 100644 --- a/cdx_toolkit/filter_warc/warc_filter.py +++ b/cdx_toolkit/filter_warc/warc_filter.py @@ -2,20 +2,18 @@ import logging import statistics import sys -from typing import List, Literal, Optional, Dict +from typing import List, Optional, Dict from botocore.config import Config -from cdx_toolkit.filter_warc.athena_job_generator import get_range_jobs_from_athena from cdx_toolkit.filter_warc.s3_utils import ( is_s3_url, ) from cdx_toolkit.filter_warc.data_classes import RangeJob, RangePayload, ThroughputTracker from cdx_toolkit.filter_warc.warc_utils import create_new_writer_with_header -from cdx_toolkit.filter_warc.cdx_utils import ( - iter_cdx_index_from_path, -) +from cdx_toolkit.filter_warc.sources.base import RangeJobSource +from cdx_toolkit.filter_warc.sources.csv import RangeJobCsvWriter from cdx_toolkit.filter_warc.warc_utils import get_bytes_from_warc_record, get_metadata_record_from_path @@ -23,8 +21,6 @@ logger = logging.getLogger(__name__) -TargetSourceType = Literal['cdx', 'athena'] - class WARCFilter: """Filter or extract specific records from WARC files based on CDX indexes. @@ -47,11 +43,10 @@ def __init__( self, prefix_path: str, writer_info: Dict, - target_source: TargetSourceType = 'cdx', - cdx_paths: Optional[List[str]] = None, - athena_database: Optional[str] = None, - athena_query: Optional[str] = None, - athena_s3_output_location: Optional[str] = None, + source: RangeJobSource, + range_jobs_output: Optional[str] = None, + no_fetch: bool = False, + csv_self_contained: bool = False, writer_subprefix: Optional[str] = None, write_paths_as_metadata_records: Optional[List[str]] = None, record_limit: int = 0, @@ -75,11 +70,13 @@ def __init__( """Initialize the WARC filter. Args: - target_source: Source of filter targets (Athena query or CDX files). - cdx_paths: List of paths to CDX index files. - athena_database: Database for Athena query. - athena_query: Prepared Athena SQL string to execute (built by the caller). - athena_s3_output_location: S3 output location for Athena query. + source: RangeJobSource that yields the WARC ranges to repackage. + range_jobs_output: Optional path; if set, each generated RangeJob is + written to this CSV (materialization). + no_fetch: If True, only generate range jobs (and write range_jobs_output); + skip fetching/writing WARC records entirely. + csv_self_contained: If True, range_jobs_output stores full URLs instead of + relative filenames. prefix_path: Output path prefix for filtered WARC files. writer_info: Dictionary containing writer metadata. writer_subprefix: Optional subprefix for writer output paths. @@ -102,11 +99,10 @@ def __init__( min_part_size: Minimum part byte size for multipart uploads (default: 5 MiB). max_file_size: Maximum byte size for individual WARC files (default: 1 GiB). """ - self.cdx_paths = cdx_paths - self.target_source: TargetSourceType = target_source - self.athena_database = athena_database - self.athena_s3_output_location = athena_s3_output_location - self.athena_query = athena_query + self.source = source + self.range_jobs_output = range_jobs_output + self.no_fetch = no_fetch + self.csv_self_contained = csv_self_contained self.prefix_path = prefix_path self.writer_info = writer_info self.writer_subprefix = writer_subprefix @@ -131,7 +127,6 @@ def __init__( else max(int(self.num_readers / self.fetcher_to_consumer_ratio), 1) ) - # self.gzip = self.cdx_paths[0].endswith('.gz') if self.cdx_paths else False self.gzip = True self.warc_version = warc_version @@ -153,17 +148,15 @@ def filter(self) -> int: return -1 def needs_aws(self) -> bool: - """Returns true if AWS (S3/Athena) is needed at any stage. + """Returns true if the read/write (stage 2/3) S3 clients are needed. - Returns: - bool: True if AWS client is needed for any operation. + Sources own their own stage-1 resource (Athena client / DuckDB connection / + fsspec), so this only concerns WARC reads and output writes. With no_fetch + there are no reads/writes at all. """ - return ( - self.target_source == 'athena' # stage 1 - or (self.cdx_paths is not None and len(self.cdx_paths) > 0 and is_s3_url(self.cdx_paths[0])) # stage 1 - or is_s3_url(self.warc_download_prefix) # stage 3 - or is_s3_url(self.prefix_path) # stage 3 - ) + if self.no_fetch: + return False + return is_s3_url(self.warc_download_prefix) or is_s3_url(self.prefix_path) def get_boto3_base_config(self) -> Dict: """Get boto3 base configuration for AWS client. @@ -184,10 +177,10 @@ def get_boto3_base_config(self) -> Dict: ) async def get_aws_clients(self) -> Optional[Dict]: - """Return S3/Athena clients for job/read/write if needed. + """Return async S3 clients for WARC reads/writes if needed. - Returns: - Optional[aioboto3.Session.client]: S3/Athena client context manager if S3/Athena is needed, None otherwise. + Stage-1 clients/connections are owned by the source, so this only builds the + read/write S3 clients used to fetch WARC ranges and write output. Raises: SystemExit: If S3 is needed but Python version is < 3.9. @@ -198,23 +191,9 @@ async def get_aws_clients(self) -> Optional[Dict]: sys.exit(1) import aioboto3 - import boto3 session = aioboto3.Session() - # Lightweight config for CDX index reads - job_config = Config( - max_pool_connections=5, - read_timeout=60, - **self.get_boto3_base_config(), - ) - - if self.target_source == 'athena': - # Athena does not need an async client - job_client = boto3.client('athena', config=job_config) - else: - job_client = session.client('s3', config=job_config) - # High-throughput config for range reads read_config = Config( max_pool_connections=self.num_readers * 3, @@ -232,7 +211,6 @@ async def get_aws_clients(self) -> Optional[Dict]: ) return { - 'job': job_client, 'read': session.client('s3', config=read_config), 'write': session.client('s3', config=write_config), } @@ -245,86 +223,99 @@ async def filter_async(self) -> int: Returns: int: Number of records written. """ + # Materialize-only: just drain the source into the range-jobs CSV. + if self.no_fetch: + return await self._run_materialize_only() + range_jobs_queue: asyncio.Queue = asyncio.Queue(maxsize=self.range_jobs_queue_size) warc_records_queue: asyncio.Queue = asyncio.Queue(maxsize=self.warc_records_queue_size) if self.needs_aws(): clients = await self.get_aws_clients() - - # Handle mixed async/sync clients - Athena client is sync, S3 clients are async - if self.target_source == 'athena': - job_aws_client = clients['job'] # Sync client, no context manager needed - async with clients['read'] as read_aws_client, clients['write'] as write_aws_client: - return await self._run_filter_pipeline( - range_jobs_queue=range_jobs_queue, - warc_records_queue=warc_records_queue, - job_aws_client=job_aws_client, - read_s3_client=read_aws_client, - write_s3_client=write_aws_client, - ) - else: - async with clients['job'] as job_aws_client, clients['read'] as read_aws_client, clients[ - 'write' - ] as write_aws_client: - return await self._run_filter_pipeline( - range_jobs_queue=range_jobs_queue, - warc_records_queue=warc_records_queue, - job_aws_client=job_aws_client, - read_s3_client=read_aws_client, - write_s3_client=write_aws_client, - ) + async with clients['read'] as read_aws_client, clients['write'] as write_aws_client: + return await self._run_filter_pipeline( + range_jobs_queue=range_jobs_queue, + warc_records_queue=warc_records_queue, + read_s3_client=read_aws_client, + write_s3_client=write_aws_client, + ) else: return await self._run_filter_pipeline( range_jobs_queue=range_jobs_queue, warc_records_queue=warc_records_queue, ) + def _make_csv_writer(self) -> Optional[RangeJobCsvWriter]: + if self.range_jobs_output is None: + return None + return RangeJobCsvWriter(self.range_jobs_output, self_contained=self.csv_self_contained) + + async def _produce_range_jobs(self, range_jobs_queue: Optional[asyncio.Queue], csv_writer) -> int: + """Drive the (sync) source in a worker thread, feeding the async queue. + + Owns counting, the record limit, and (when a queue is present) emitting one + _STOP sentinel per reader in a finally -- so readers never hang even if the + source raises mid-iteration.""" + loop = asyncio.get_running_loop() + count = 0 + + def drain() -> int: + nonlocal count + for job in self.source.iter_range_jobs(): + if csv_writer is not None: + csv_writer.write(job) + if range_jobs_queue is not None: + asyncio.run_coroutine_threadsafe(range_jobs_queue.put(job), loop).result() + count += 1 + if self.record_limit and count >= self.record_limit: + logger.warning('Limit reached at %i', count) + break + return count + + try: + await asyncio.to_thread(drain) + finally: + if csv_writer is not None: + csv_writer.close() + if range_jobs_queue is not None: + for _ in range(self.num_readers): + await range_jobs_queue.put(_STOP) + + logger.info('Generated %d range jobs', count) + return count + + async def _run_materialize_only(self) -> int: + """--no-fetch: generate range jobs and write only the range-jobs CSV.""" + csv_writer = self._make_csv_writer() + if csv_writer is None: + logger.warning('--no-fetch set without --range-jobs-output: nothing to do') + count = await self._produce_range_jobs(range_jobs_queue=None, csv_writer=csv_writer) + logger.info('Materialized %d range jobs (no WARC fetch)', count) + return count + async def _run_filter_pipeline( self, range_jobs_queue: asyncio.Queue, warc_records_queue: asyncio.Queue, - job_aws_client=None, read_s3_client=None, write_s3_client=None, ) -> int: """Run the actual filter pipeline with or without S3 client. Args: - range_jobs_queue: Queue for range jobs from CDX index. + range_jobs_queue: Queue for range jobs from the source. warc_records_queue: Queue for WARC record payloads. - job_aws_client: Optional AWS (S3/Athena) client for jobs generation. read_s3_client: Optional S3 client for reads from S3. write_s3_client: Optional S3 client for writes S3. Returns: int: Number of records written. """ - # Fetch file paths and ranges (offset, length) from index files logger.info('Starting job generator, %d WARC readers, %d WARC writers', self.num_readers, self.num_writers) - # Generate range jobs from different target sources - if self.target_source == 'cdx': - job_generators = asyncio.create_task( - self.generate_range_jobs_from_cdx( - range_jobs_queue, - s3_client=job_aws_client, - ) - ) - elif self.target_source == 'athena': - job_generators = asyncio.create_task( - get_range_jobs_from_athena( - client=job_aws_client, - query=self.athena_query, - database=self.athena_database, - s3_output_location=self.athena_s3_output_location, - job_queue=range_jobs_queue, - queue_stop_object=_STOP, - warc_download_prefix=self.warc_download_prefix, - num_fetchers=self.num_readers, - ) - ) - else: - raise ValueError(f'Invalid target source: {self.target_source}') + # Generate range jobs from the configured source (bridged sync->async in a thread). + csv_writer = self._make_csv_writer() + job_generators = asyncio.create_task(self._produce_range_jobs(range_jobs_queue, csv_writer)) # Read WARC records based on file paths and ranges warc_readers = [ @@ -414,66 +405,6 @@ async def _coordinate_writer_shutdown(self, warc_readers: List[asyncio.Task], wa for _ in range(self.num_writers): await warc_records_queue.put(_STOP) - async def generate_range_jobs_from_single_cdx( - self, - cdx_path: str, - range_jobs_queue: asyncio.Queue, - count: int = 0, - ) -> int: - """Read a CDX file and generate range jobs based on URLs and offsets.""" - for warc_url, offset, length in iter_cdx_index_from_path( - cdx_path, warc_download_prefix=self.warc_download_prefix - ): - # Convert the CDX record back to a RangeJob - job = RangeJob(url=warc_url, offset=offset, length=length, records_count=1) - await range_jobs_queue.put(job) - count += 1 - - if self.record_limit > 0 and count >= self.record_limit: - logger.warning('Index limit reached at %i', count) - break - - return count - - async def generate_range_jobs_from_cdx( - self, - range_jobs_queue: asyncio.Queue, - s3_client=None, - ): - """Read the CDX paths, parse lines -> RangeJob (WARC files and offets) -> key_queue. - - Args: - range_jobs_queue: Queue to put RangeJob objects into. - s3_client: Optional S3 client for reading CDX indexes from S3. - """ - - logger.info('Range index limit: %i', self.record_limit) - count = 0 - - # Iterate over index files - # TODO this could be done in parallel - for index_path in self.cdx_paths: - # Fetch range queries from index - try: - count += await self.generate_range_jobs_from_single_cdx( - cdx_path=index_path, - range_jobs_queue=range_jobs_queue, - count=count, - ) - - except Exception as e: - logger.error('Failed to read CDX index from %s: %s', index_path, e) - - if self.record_limit > 0 and count >= self.record_limit: - logger.warning('Limit reached at %i', count) - break - - # signal fetchers to stop - for _ in range(self.num_readers): - await range_jobs_queue.put(_STOP) - - logger.info('Enqueued %d jobs from %s', count, index_path) - async def read_warc_records( self, reader_id: int, diff --git a/requirements.txt b/requirements.txt index 6b6efa9..5de9c65 100644 --- a/requirements.txt +++ b/requirements.txt @@ -12,6 +12,9 @@ url-is-in>=0.1.1 fsspec[s3] botocore +# optional DuckDB SQL engine (install via cdx_toolkit[duckdb]) +duckdb + # used by Makefile pytest>=6.2.4 pytest-cov>=2.12.1 diff --git a/setup.py b/setup.py index 15ab048..4398197 100755 --- a/setup.py +++ b/setup.py @@ -12,6 +12,7 @@ test_requirements = ['pytest', 'pytest-cov', 'flake8', 'responses'] optional_s3_requirements = ['fsspec[s3]', 'botocore'] +optional_duckdb_requirements = ['duckdb'] package_requirements = ['twine', 'setuptools', 'setuptools-scm'] @@ -19,10 +20,14 @@ extras_require = { 's3': optional_s3_requirements, + 'duckdb': optional_duckdb_requirements, 'test': test_requirements, # setup no longer tests, so make them an extra 'package': package_requirements, 'dev': package_requirements, - 'all': test_requirements + package_requirements + dev_requirements + optional_s3_requirements, + 'all': ( + test_requirements + package_requirements + dev_requirements + + optional_s3_requirements + optional_duckdb_requirements + ), } scripts = ['scripts/cdx_size', 'scripts/cdx_iter'] diff --git a/tests/conftest.py b/tests/conftest.py index 50caca6..ce0ce91 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -30,6 +30,12 @@ except ImportError: # pragma: no cover - exercised in minimal installs _HAS_FSSPEC = False +try: + import duckdb # noqa: F401 + _HAS_DUCKDB = True +except ImportError: # pragma: no cover - exercised in minimal installs + _HAS_DUCKDB = False + import functools from typing import Dict, Optional import requests @@ -162,6 +168,13 @@ def requires_aws_athena(func): ) +def requires_duckdb(func): + """Pytest decorator that skips a test if the optional duckdb dependency is missing.""" + return pytest.mark.skipif( + not _HAS_DUCKDB, reason='duckdb is not installed; install cdx_toolkit[duckdb] to enable DuckDB tests.' + )(func) + + @pytest.fixture def s3_tmpdir(): """S3 equivalent of tmpdir - provides a temporary S3 path and handles cleanup.""" diff --git a/tests/filter_warc/test_athena_command_validation.py b/tests/filter_warc/test_athena_command_validation.py deleted file mode 100644 index b352b99..0000000 --- a/tests/filter_warc/test_athena_command_validation.py +++ /dev/null @@ -1,84 +0,0 @@ -from argparse import Namespace -from unittest.mock import patch - -import pytest - -from cdx_toolkit.filter_warc import command -from cdx_toolkit.filter_warc.command import resolve_athena_query - - -def make_args(**kw): - defaults = dict( - athena_query=None, - athena_query_file=None, - athena_hostnames=None, - athena_database='ccindex', - athena_s3_output='s3://commoncrawl-ci-temp/athena-results/', - crawl=None, - limit=None, - ) - defaults.update(kw) - return Namespace(**defaults) - - -def test_query_and_hostnames_mutually_exclusive(): - args = make_args(athena_query='SELECT 1', athena_hostnames=['example.com']) - with pytest.raises(ValueError): - resolve_athena_query(args) - - -def test_query_and_query_file_mutually_exclusive(tmp_path): - f = tmp_path / 'q.sql' - f.write_text('SELECT 1') - args = make_args(athena_query='SELECT 1', athena_query_file=str(f)) - with pytest.raises(ValueError): - resolve_athena_query(args) - - -def test_neither_hostnames_nor_query(): - args = make_args() - with pytest.raises(ValueError): - resolve_athena_query(args) - - -def test_missing_database(): - args = make_args(athena_hostnames=['example.com'], athena_database=None) - with pytest.raises(ValueError): - resolve_athena_query(args) - - -def test_missing_s3_output(): - args = make_args(athena_hostnames=['example.com'], athena_s3_output=None) - with pytest.raises(ValueError): - resolve_athena_query(args) - - -def test_raw_query_is_unbounded(): - args = make_args(athena_query='SELECT warc_filename FROM x') - sql, n_crawls = resolve_athena_query(args) - assert sql == 'SELECT warc_filename FROM x' - assert n_crawls is None - - -def test_query_file_is_read(tmp_path): - f = tmp_path / 'q.sql' - f.write_text('SELECT warc_filename, warc_record_offset, warc_record_length FROM x') - args = make_args(athena_query_file=str(f)) - sql, n_crawls = resolve_athena_query(args) - assert 'warc_filename' in sql - assert n_crawls is None - - -def test_built_no_crawl_is_unbounded(): - args = make_args(athena_hostnames=['example.com']) - sql, n_crawls = resolve_athena_query(args) - assert 'example.com' in sql - assert n_crawls is None - - -def test_built_with_crawls_counts(): - args = make_args(athena_hostnames=['example.com'], crawl='CC-MAIN-2025-33,CC-MAIN-2025-30') - with patch.object(command, '_resolve_crawl_names', return_value=['CC-MAIN-2025-33', 'CC-MAIN-2025-30']): - sql, n_crawls = resolve_athena_query(args) - assert n_crawls == 2 - assert 'crawl IN' in sql diff --git a/tests/filter_warc/test_athena_job_generator.py b/tests/filter_warc/test_athena_job_generator.py deleted file mode 100644 index 05ec0b1..0000000 --- a/tests/filter_warc/test_athena_job_generator.py +++ /dev/null @@ -1,74 +0,0 @@ -import asyncio -from cdx_toolkit.filter_warc.warc_filter import _STOP -from cdx_toolkit.filter_warc.athena_job_generator import ( - get_databases, - get_range_jobs_from_athena, - build_athena_query, -) -from tests.conftest import TEST_ATHENA_DATABASE, TEST_ATHENA_S3_LOCATION, requires_aws_athena - -import boto3 - - -@requires_aws_athena -def test_get_databases(): - from botocore.config import Config - import boto3 - - boto_cfg = Config( - region_name='us-east-1', - ) - athena_client = boto3.client('athena', config=boto_cfg) - dbs = get_databases(client=athena_client) - assert 'ccindex' in dbs - - -@requires_aws_athena -def test_get_range_jobs_from_athena(): - async def run_test(): - # Setup test data - warc_download_prefix = 's3://commoncrawl' - - # Create asyncio queues - key_queue = asyncio.Queue() - - # Setup S3 client - from botocore.config import Config - - boto_cfg = Config( - region_name='us-east-1', - retries={'max_attempts': 3, 'mode': 'standard'}, - connect_timeout=10, - read_timeout=120, - ) - - athena_client = boto3.client('athena', config=boto_cfg) - - # Build the query and generate range jobs from Athena - query = build_athena_query( - ['oceancolor.sci.gsfc.nasa.gov'], - limit=10, # Use 10 records to ensure we have enough data - ) - await get_range_jobs_from_athena( - client=athena_client, - query=query, - database=TEST_ATHENA_DATABASE, - s3_output_location=TEST_ATHENA_S3_LOCATION, - job_queue=key_queue, - warc_download_prefix=warc_download_prefix, - num_fetchers=1, - queue_stop_object=_STOP, - ) - - # Collect all range jobs - range_jobs = [] - while not key_queue.empty(): - job = await key_queue.get() - if job is not _STOP: - range_jobs.append(job) - key_queue.task_done() - - assert len(range_jobs) == 10, "Invalid range jobs count" - - # Run the async test - asyncio.run(run_test()) diff --git a/tests/filter_warc/test_athena_prompt.py b/tests/filter_warc/test_athena_prompt.py index 6d36958..0b03b54 100644 --- a/tests/filter_warc/test_athena_prompt.py +++ b/tests/filter_warc/test_athena_prompt.py @@ -2,19 +2,30 @@ import pytest -from cdx_toolkit.filter_warc.command import confirm_athena_cost +from cdx_toolkit.filter_warc.command import confirm_cost +from cdx_toolkit.filter_warc.sources.base import CostEstimate + + +def est(n_crawls): + return CostEstimate(n_crawls=n_crawls, engine='athena') + + +def test_none_estimate_never_prompts(): + # cdx/csv sources return None -> no cost prompt + with patch('builtins.input') as inp: + confirm_cost(None, confirmed=False) + inp.assert_not_called() def test_small_crawl_set_no_prompt(): - # <= LARGE_CRAWL_SET_THRESHOLD -> considered cheap, no prompt with patch('builtins.input') as inp: - confirm_athena_cost(n_crawls=5, confirmed=False) + confirm_cost(est(5), confirmed=False) inp.assert_not_called() def test_confirmed_flag_bypasses_prompt(): with patch('builtins.input') as inp: - confirm_athena_cost(n_crawls=None, confirmed=True) + confirm_cost(est(None), confirmed=True) inp.assert_not_called() @@ -22,31 +33,31 @@ def test_large_crawl_set_non_tty_aborts(): with patch('cdx_toolkit.filter_warc.command.sys.stdin') as stdin: stdin.isatty.return_value = False with pytest.raises(SystemExit): - confirm_athena_cost(n_crawls=11, confirmed=False) + confirm_cost(est(11), confirmed=False) def test_unknown_crawls_non_tty_aborts(): with patch('cdx_toolkit.filter_warc.command.sys.stdin') as stdin: stdin.isatty.return_value = False with pytest.raises(SystemExit): - confirm_athena_cost(n_crawls=None, confirmed=False) + confirm_cost(est(None), confirmed=False) def test_tty_yes_proceeds(): with patch('cdx_toolkit.filter_warc.command.sys.stdin') as stdin, patch('builtins.input', return_value='y'): stdin.isatty.return_value = True - confirm_athena_cost(n_crawls=None, confirmed=False) # should not raise + confirm_cost(est(None), confirmed=False) # should not raise def test_tty_empty_answer_aborts(): with patch('cdx_toolkit.filter_warc.command.sys.stdin') as stdin, patch('builtins.input', return_value=''): stdin.isatty.return_value = True with pytest.raises(SystemExit): - confirm_athena_cost(n_crawls=None, confirmed=False) + confirm_cost(est(None), confirmed=False) def test_tty_no_answer_aborts(): with patch('cdx_toolkit.filter_warc.command.sys.stdin') as stdin, patch('builtins.input', return_value='n'): stdin.isatty.return_value = True with pytest.raises(SystemExit): - confirm_athena_cost(n_crawls=11, confirmed=False) + confirm_cost(est(11), confirmed=False) diff --git a/tests/filter_warc/test_athena_query_builder.py b/tests/filter_warc/test_athena_query_builder.py index 8e2cf97..8b2be57 100644 --- a/tests/filter_warc/test_athena_query_builder.py +++ b/tests/filter_warc/test_athena_query_builder.py @@ -1,12 +1,12 @@ import pytest -from cdx_toolkit.filter_warc.athena_job_generator import ( +from cdx_toolkit.filter_warc.sources.sql_base import ( build_athena_query, escape_sql_literal, validate_result_columns, join_warc_url, - run_athena_query, ) +from cdx_toolkit.filter_warc.sources.athena import run_athena_query class _FakeAthenaClient: diff --git a/tests/filter_warc/test_cdx_utils.py b/tests/filter_warc/test_cdx_utils.py index 5ca7035..c6b2f8f 100644 --- a/tests/filter_warc/test_cdx_utils.py +++ b/tests/filter_warc/test_cdx_utils.py @@ -67,9 +67,9 @@ def mock_read_cdx_line(line, warc_download_prefix): # Should have 3 valid results despite 2 invalid lines being skipped assert len(results) == 3 - # Verify the valid results - assert results[0] == ('http://warc-prefix/test.warc.gz', 100, 500) - assert results[1] == ('http://warc-prefix/test2.warc.gz', 600, 300) - assert results[2] == ('http://warc-prefix/test3.warc.gz', 900, 200) + # Verify the valid results (url, offset, length, filename) + assert results[0] == ('http://warc-prefix/test.warc.gz', 100, 500, 'test.warc.gz') + assert results[1] == ('http://warc-prefix/test2.warc.gz', 600, 300, 'test2.warc.gz') + assert results[2] == ('http://warc-prefix/test3.warc.gz', 900, 200, 'test3.warc.gz') finally: os.unlink(tmp_file_path) diff --git a/tests/filter_warc/test_command.py b/tests/filter_warc/test_command.py index 49ac875..4da1233 100644 --- a/tests/filter_warc/test_command.py +++ b/tests/filter_warc/test_command.py @@ -34,7 +34,7 @@ def assert_cli_warc_by_cdx( args=[ '-v', '--limit=10', - 'warc_by_cdx', + 'repackage', f'--cdx-path={str(index_path)}', '--write-paths-as-metadata-records', str(metadata_record_path), @@ -119,6 +119,121 @@ def test_cli_warc_by_cdx_over_http_in_parallel(tmpdir, caplog): ) +def _produce_range_jobs_csv(tmpdir, csv_name, self_contained=False): + """Run `repackage` over the CDX fixture with --no-fetch to materialize a range-jobs CSV.""" + import csv as _csv + + index_path = fixture_path / 'filtered_CC-MAIN-2024-30_cdx-00187.gz' + csv_path = os.path.join(str(tmpdir), csv_name) + + args = [ + '--limit=10', + 'repackage', + '--target-source=cdx', + f'--cdx-path={str(index_path)}', + f'--range-jobs-output={csv_path}', + '--no-fetch', + ] + if self_contained: + args.append('--csv-self-contained') + main(args=args) + + with open(csv_path, newline='') as f: + rows = list(_csv.DictReader(f)) + return csv_path, rows + + +def _assert_repackaged_warc(warc_path, metadata_record_path): + """Inspect a repackaged WARC and assert the expected fixture content.""" + response_records = [] + response_contents = [] + metadata_record = None + metadata_record_headers = None + + with fsspec.open(warc_path, 'rb') as stream: + for record in ArchiveIterator(stream): + if record.rec_type == 'response': + response_records.append(record) + response_contents.append(record.content_stream().read().decode('utf-8', errors='ignore')) + if record.rec_type == 'metadata': + metadata_record = record + metadata_record_headers = record.rec_headers + + assert len(response_records) == 10, 'Invalid record count' + assert 'Catalogue en ligne Mission de France' in response_contents[0], 'Invalid response content' + assert 'dojo/dijit/themes/tundra/tundra' in response_contents[9], 'Invalid response content' + assert metadata_record is not None, 'Metadata record not set' + assert metadata_record_headers.get('WARC-Payload-Digest') == 'sha1:VXA2A5YUS3TAY36AUO6MACRMNOH5RXG2', ( + 'Invalid metadata block digest' + ) + + +def test_repackage_csv_materialize_filename(tmpdir): + """--no-fetch produces a filename-based range-jobs CSV without fetching WARCs.""" + csv_path, rows = _produce_range_jobs_csv(tmpdir, 'ranges.csv') + assert set(rows[0].keys()) == {'filename', 'offset', 'length'} + assert len(rows) == 10 + # No WARC was written for the default --prefix + assert not any(name.endswith('.warc.gz') for name in os.listdir(str(tmpdir))) + + +def test_repackage_csv_materialize_self_contained(tmpdir): + """--csv-self-contained produces a url-based range-jobs CSV.""" + csv_path, rows = _produce_range_jobs_csv(tmpdir, 'ranges_url.csv', self_contained=True) + assert set(rows[0].keys()) == {'url', 'offset', 'length'} + assert len(rows) == 10 + assert rows[0]['url'].startswith('https://data.commoncrawl.org/') + + +def test_cli_repackage_csv_roundtrip(tmpdir): + """End-to-end: produce a filename-based ranges CSV, then consume it and fetch over HTTP.""" + metadata_record_path = TEST_DATA_PATH / 'filter_cdx/whitelist_10_urls.txt' + csv_path, rows = _produce_range_jobs_csv(tmpdir, 'ranges.csv') + + base_prefix = str(tmpdir) + main( + args=[ + '-v', + 'repackage', + '--target-source=csv', + f'--csv-path={csv_path}', + '--write-paths-as-metadata-records', + str(metadata_record_path), + f'--prefix={base_prefix}/TEST_warc_by_index', + '--creator=foo', + '--operator=bob', + '--warc-download-prefix=https://data.commoncrawl.org', + ] + ) + + warc_path = os.path.join(base_prefix, 'TEST_warc_by_index-000000-001.warc.gz') + _assert_repackaged_warc(warc_path, metadata_record_path) + + +def test_cli_repackage_csv_roundtrip_self_contained(tmpdir): + """End-to-end with self-contained URLs: header auto-detected on read; no prefix needed.""" + metadata_record_path = TEST_DATA_PATH / 'filter_cdx/whitelist_10_urls.txt' + csv_path, rows = _produce_range_jobs_csv(tmpdir, 'ranges_url.csv', self_contained=True) + + base_prefix = str(tmpdir) + main( + args=[ + '-v', + 'repackage', + '--target-source=csv', + f'--csv-path={csv_path}', + '--write-paths-as-metadata-records', + str(metadata_record_path), + f'--prefix={base_prefix}/TEST_warc_by_index', + '--creator=foo', + '--operator=bob', + ] + ) + + warc_path = os.path.join(base_prefix, 'TEST_warc_by_index-000000-001.warc.gz') + _assert_repackaged_warc(warc_path, metadata_record_path) + + @requires_aws_s3 def test_cli_warc_by_cdx_over_s3(tmpdir, caplog): assert_cli_warc_by_cdx('s3://commoncrawl', base_prefix=tmpdir, caplog=caplog) @@ -182,7 +297,7 @@ def test_warc_by_cdx_no_index_files_found_exits(tmpdir, caplog): main( args=[ '-v', - 'warc_by_cdx', + 'repackage', f'--cdx-path={str(tmpdir)}', f'--prefix={str(tmpdir)}/TEST', '--cdx-glob=/nonexistent-pattern-*.gz', @@ -201,7 +316,7 @@ def test_warc_by_cdx_subprefix_and_metadata(tmpdir): args=[ '-v', '--limit=1', - 'warc_by_cdx', + 'repackage', f'--cdx-path={str(index_path)}', f'--prefix={str(tmpdir)}/TEST', '--subprefix=SUB', @@ -235,7 +350,7 @@ def test_warc_by_cdx_without_creator_operator(tmpdir): args=[ '-v', '--limit=1', - 'warc_by_cdx', + 'repackage', f'--cdx-path={str(index_path)}', f'--prefix={str(tmpdir)}/TEST_NO_META', ] @@ -276,14 +391,15 @@ def test_cli_warc_by_athena( args=[ '-v', '--limit=10', - 'warc_by_cdx', - '--target-source=athena', + 'repackage', + '--target-source=sql', + '--engine=athena', '--athena-database=ccindex', '--athena-s3-output=s3://commoncrawl-ci-temp/athena-results/', - '--athena-hostnames', + '--hostnames', 'oceancolor.sci.gsfc.nasa.gov', 'example.com', - '--confirm-athena-cost', + '--confirm-cost', f'--prefix={base_prefix}/TEST_warc_by_index', '--creator=foo', '--operator=bob', diff --git a/tests/filter_warc/test_csv_source.py b/tests/filter_warc/test_csv_source.py new file mode 100644 index 0000000..d6a6a84 --- /dev/null +++ b/tests/filter_warc/test_csv_source.py @@ -0,0 +1,66 @@ +import csv + +import pytest + +from cdx_toolkit.filter_warc.data_classes import RangeJob +from cdx_toolkit.filter_warc.sources.csv import RangeJobCsvWriter, CsvSource + + +def test_writer_filename_mode_and_read(tmp_path): + path = str(tmp_path / 'ranges.csv') + w = RangeJobCsvWriter(path, self_contained=False) + w.write(RangeJob( + url='https://data.commoncrawl.org/crawl-data/x.warc.gz', + offset=10, length=20, filename='crawl-data/x.warc.gz', + )) + w.close() + + with open(path, newline='') as f: + rows = list(csv.DictReader(f)) + assert rows[0] == {'filename': 'crawl-data/x.warc.gz', 'offset': '10', 'length': '20'} + + jobs = list(CsvSource(path, 'https://data.commoncrawl.org').iter_range_jobs()) + assert jobs[0].url == 'https://data.commoncrawl.org/crawl-data/x.warc.gz' + assert jobs[0].offset == 10 and jobs[0].length == 20 + assert jobs[0].filename == 'crawl-data/x.warc.gz' + + +def test_writer_self_contained_mode_and_read(tmp_path): + path = str(tmp_path / 'ranges_url.csv') + w = RangeJobCsvWriter(path, self_contained=True) + w.write(RangeJob(url='s3://commoncrawl/crawl-data/x.warc.gz', offset=5, length=7, filename='crawl-data/x.warc.gz')) + w.close() + + with open(path, newline='') as f: + rows = list(csv.DictReader(f)) + assert rows[0] == {'url': 's3://commoncrawl/crawl-data/x.warc.gz', 'offset': '5', 'length': '7'} + + # url used as-is regardless of the (ignored) prefix + jobs = list(CsvSource(path, 'https://ignored').iter_range_jobs()) + assert jobs[0].url == 's3://commoncrawl/crawl-data/x.warc.gz' + assert jobs[0].filename is None + + +def test_writer_filename_mode_requires_filename(tmp_path): + w = RangeJobCsvWriter(str(tmp_path / 'r.csv'), self_contained=False) + with pytest.raises(ValueError): + w.write(RangeJob(url='https://x/y.warc.gz', offset=1, length=2, filename=None)) + w.close() + + +def test_csv_source_missing_columns(tmp_path): + path = str(tmp_path / 'bad.csv') + with open(path, 'w') as f: + f.write('foo,bar\n1,2\n') + with pytest.raises(ValueError): + list(CsvSource(path, 'https://x').iter_range_jobs()) + + +def test_csv_source_tsv(tmp_path): + path = str(tmp_path / 'ranges.tsv') + with open(path, 'w') as f: + f.write('filename\toffset\tlength\n') + f.write('crawl-data/x.warc.gz\t10\t20\n') + jobs = list(CsvSource(path, 'https://data.commoncrawl.org').iter_range_jobs()) + assert jobs[0].url == 'https://data.commoncrawl.org/crawl-data/x.warc.gz' + assert jobs[0].offset == 10 and jobs[0].length == 20 diff --git a/tests/filter_warc/test_grouped_range_jobs.py b/tests/filter_warc/test_grouped_range_jobs.py index b046709..91ccbaf 100644 --- a/tests/filter_warc/test_grouped_range_jobs.py +++ b/tests/filter_warc/test_grouped_range_jobs.py @@ -5,7 +5,7 @@ def test_iter_cdx_index_from_test_data(): cdx_path = TEST_DATA_PATH / 'warc_by_cdx/filtered_CC-MAIN-2024-30_cdx-00187.gz' results = list(iter_cdx_index_from_path(str(cdx_path), 'http://warc-prefix')) - # [(url, offset, length)] + # [(url, offset, length, filename)] # sort results by offsets results.sort(key=lambda x: x[1]) @@ -20,8 +20,8 @@ def group_neighbor_chunks(items): current_chunk = [items[0]] for i in range(1, len(items)): - prev_url, prev_offset, prev_length = items[i - 1] - curr_url, curr_offset, curr_length = items[i] + prev_url, prev_offset, prev_length = items[i - 1][:3] + curr_url, curr_offset, curr_length = items[i][:3] # Check if current item is a neighbor (same URL and contiguous) if curr_url == prev_url and curr_offset == prev_offset + prev_length + 4: diff --git a/tests/filter_warc/test_make_source.py b/tests/filter_warc/test_make_source.py new file mode 100644 index 0000000..d9a8508 --- /dev/null +++ b/tests/filter_warc/test_make_source.py @@ -0,0 +1,132 @@ +from argparse import Namespace +from unittest.mock import patch + +import pytest + +from cdx_toolkit.filter_warc.sources import make_source +from cdx_toolkit.filter_warc.sources.athena import AthenaSource +from cdx_toolkit.filter_warc.sources.cdx import CdxSource +from cdx_toolkit.filter_warc.sources.csv import CsvSource +from cdx_toolkit.filter_warc.sources.duckdb import DuckDbSource + + +def make_args(**kw): + defaults = dict( + target_source='cdx', + engine=None, + hostnames=None, + query=None, + query_file=None, + athena_database=None, + athena_s3_output='s3://commoncrawl-ci-temp/athena-results/', + duckdb_index_path='s3://commoncrawl/cc-index/table/cc-main/warc/', + csv_path=None, + cdx_path=None, + cdx_glob=None, + crawl=None, + ) + defaults.update(kw) + return Namespace(**defaults) + + +def build(**kw): + return make_source(make_args(**kw), warc_download_prefix='https://data.commoncrawl.org', record_limit=0) + + +# --- cdx / csv --- + +def test_cdx_source(): + src = build(target_source='cdx', cdx_path='/tmp/index.cdx.gz') + assert isinstance(src, CdxSource) + assert src.estimate_cost() is None + + +def test_csv_requires_path(): + with pytest.raises(ValueError): + build(target_source='csv') + + +def test_csv_source(): + src = build(target_source='csv', csv_path='/tmp/ranges.csv') + assert isinstance(src, CsvSource) + assert src.estimate_cost() is None + + +# --- sql: engine + query validation --- + +def test_sql_requires_engine(): + with pytest.raises(ValueError): + build(target_source='sql', hostnames=['example.com']) + + +def test_sql_hostnames_and_query_mutually_exclusive(): + with pytest.raises(ValueError): + build(target_source='sql', engine='athena', hostnames=['example.com'], query='SELECT 1') + + +def test_sql_query_and_query_file_mutually_exclusive(tmp_path): + f = tmp_path / 'q.sql' + f.write_text('SELECT 1') + with pytest.raises(ValueError): + build(target_source='sql', engine='athena', query='SELECT 1', query_file=str(f)) + + +def test_sql_neither_hostnames_nor_query(): + with pytest.raises(ValueError): + build(target_source='sql', engine='athena') + + +# --- athena --- + +def test_athena_requires_s3_output(): + with pytest.raises(ValueError): + build(target_source='sql', engine='athena', hostnames=['example.com'], athena_s3_output=None) + + +def test_athena_built_no_crawl_unbounded(): + src = build(target_source='sql', engine='athena', hostnames=['example.com']) + assert isinstance(src, AthenaSource) + est = src.estimate_cost() + assert est.engine == 'athena' and est.n_crawls is None + assert 'ccindex' in src.query and 'example.com' in src.query + + +def test_athena_raw_query_unbounded(): + src = build(target_source='sql', engine='athena', query='SELECT warc_filename FROM x') + assert src.query == 'SELECT warc_filename FROM x' + assert src.estimate_cost().n_crawls is None + + +def test_athena_with_crawls_counts(): + with patch( + 'cdx_toolkit.filter_warc.sources.factory.resolve_crawl_names', + return_value=['CC-MAIN-2025-33', 'CC-MAIN-2025-30'], + ): + src = build( + target_source='sql', engine='athena', hostnames=['example.com'], + crawl='CC-MAIN-2025-33,CC-MAIN-2025-30', + ) + assert src.estimate_cost().n_crawls == 2 + assert 'crawl IN' in src.query + + +# --- duckdb --- + +def test_duckdb_built_query_has_read_parquet_and_partition(): + with patch( + 'cdx_toolkit.filter_warc.sources.factory.resolve_crawl_names', + return_value=['CC-MAIN-2026-17'], + ): + src = build(target_source='sql', engine='duckdb', hostnames=['commoncrawl.org'], crawl='CC-MAIN-2026-17') + assert isinstance(src, DuckDbSource) + assert src.estimate_cost().n_crawls == 1 + query = src._build_query() + assert 'read_parquet' in query + assert 'crawl=CC-MAIN-2026-17' in query + assert 'commoncrawl.org' in query + + +def test_duckdb_no_crawl_unbounded(): + src = build(target_source='sql', engine='duckdb', hostnames=['commoncrawl.org']) + assert src.estimate_cost().n_crawls is None + assert 'crawl=*' in src._build_query() diff --git a/tests/filter_warc/test_producer.py b/tests/filter_warc/test_producer.py new file mode 100644 index 0000000..ba6efcd --- /dev/null +++ b/tests/filter_warc/test_producer.py @@ -0,0 +1,68 @@ +import asyncio +import csv + +from cdx_toolkit.filter_warc.data_classes import RangeJob +from cdx_toolkit.filter_warc.sources.base import RangeJobSource +from cdx_toolkit.filter_warc.warc_filter import WARCFilter, _STOP + + +class FakeSource(RangeJobSource): + def __init__(self, jobs, raise_after=None): + self.jobs = jobs + self.raise_after = raise_after + + def iter_range_jobs(self): + for i, job in enumerate(self.jobs): + if self.raise_after is not None and i == self.raise_after: + raise RuntimeError('boom') + yield job + + +def _jobs(n): + return [ + RangeJob(url=f'https://data.commoncrawl.org/{i}.warc.gz', offset=i, length=1, filename=f'{i}.warc.gz') + for i in range(n) + ] + + +def test_no_fetch_materializes_csv(tmp_path): + out = str(tmp_path / 'ranges.csv') + wf = WARCFilter( + source=FakeSource(_jobs(3)), + prefix_path=str(tmp_path / 'out'), + writer_info={'isPartOf': 'test'}, + range_jobs_output=out, + no_fetch=True, + ) + n = wf.filter() + assert n == 3 + with open(out, newline='') as f: + rows = list(csv.DictReader(f)) + assert len(rows) == 3 + assert set(rows[0].keys()) == {'filename', 'offset', 'length'} + + +def test_producer_emits_stops_even_when_source_raises(tmp_path): + """Regression: a source that raises mid-iteration must still release the readers.""" + wf = WARCFilter( + source=FakeSource(_jobs(2), raise_after=1), + prefix_path=str(tmp_path), + writer_info={}, + n_parallel=3, + ) + + async def run(): + queue: asyncio.Queue = asyncio.Queue() + try: + await wf._produce_range_jobs(queue, None) + except RuntimeError: + pass + items = [] + while not queue.empty(): + items.append(await queue.get()) + return items + + items = asyncio.run(run()) + stops = sum(1 for it in items if it is _STOP) + assert wf.num_readers == 3 + assert stops == 3 diff --git a/tests/filter_warc/test_sql_sources_gated.py b/tests/filter_warc/test_sql_sources_gated.py new file mode 100644 index 0000000..f90990b --- /dev/null +++ b/tests/filter_warc/test_sql_sources_gated.py @@ -0,0 +1,86 @@ +"""Gated end-to-end tests for the SQL sources (Athena, DuckDB). + +These query only a single crawl partition for a single host (cheap, partition-pruned +-- never an all-crawls scan), materialize a range-jobs CSV with --no-fetch, then +consume it and fetch the WARCs over HTTP. They require AWS credentials (and duckdb) +and are skipped in CI. +""" +import csv +import os + +import fsspec +from warcio.archiveiterator import ArchiveIterator + +from cdx_toolkit.cli import main +from tests.conftest import requires_aws_athena, requires_aws_s3, requires_duckdb, TEST_ATHENA_S3_LOCATION + + +CRAWL = 'CC-MAIN-2026-17' +HOST = 'commoncrawl.org' + + +def _produce_and_consume(tmpdir, produce_args): + csv_path = os.path.join(str(tmpdir), 'ranges.csv') + + # Produce: run the (cheap, single-crawl) SQL query, write range jobs, no WARC fetch. + main(args=produce_args + [f'--range-jobs-output={csv_path}', '--no-fetch']) + + with open(csv_path, newline='') as f: + rows = list(csv.DictReader(f)) + assert len(rows) > 0, 'expected at least one successful warc fetch for the host/crawl' + + # Consume: read the CSV, fetch the WARC records over HTTP, write a new WARC. + base_prefix = str(tmpdir) + main( + args=[ + 'repackage', + '--target-source=csv', + f'--csv-path={csv_path}', + f'--prefix={base_prefix}/TEST_sql', + '--warc-download-prefix=https://data.commoncrawl.org', + ] + ) + + warc_path = os.path.join(base_prefix, 'TEST_sql-000000-001.warc.gz') + response_count = 0 + with fsspec.open(warc_path, 'rb') as stream: + for record in ArchiveIterator(stream): + if record.rec_type == 'response': + response_count += 1 + target = record.rec_headers.get_header('WARC-Target-URI') or '' + assert HOST in target, f'unexpected target URI: {target}' + + assert response_count == len(rows), 'every range job should yield a response record' + + +@requires_aws_athena +def test_repackage_sql_athena_e2e(tmpdir): + _produce_and_consume( + tmpdir, + [ + '--crawl', CRAWL, + 'repackage', + '--target-source=sql', + '--engine=athena', + '--athena-database=ccindex', + f'--athena-s3-output={TEST_ATHENA_S3_LOCATION}', + '--hostnames', HOST, + '--confirm-cost', + ], + ) + + +@requires_aws_s3 +@requires_duckdb +def test_repackage_sql_duckdb_e2e(tmpdir): + _produce_and_consume( + tmpdir, + [ + '--crawl', CRAWL, + 'repackage', + '--target-source=sql', + '--engine=duckdb', + '--hostnames', HOST, + '--confirm-cost', + ], + ) diff --git a/tests/filter_warc/test_warc_filter.py b/tests/filter_warc/test_warc_filter.py index 7d8307b..b157f61 100644 --- a/tests/filter_warc/test_warc_filter.py +++ b/tests/filter_warc/test_warc_filter.py @@ -5,9 +5,13 @@ from tests.conftest import TEST_DATA_PATH from cdx_toolkit.filter_warc.warc_filter import WARCFilter +from cdx_toolkit.filter_warc.sources.cdx import CdxSource fixture_path = TEST_DATA_PATH / 'warc_by_cdx' +# A throwaway source for unit tests that only exercise reader/writer/rotate/log methods. +_FAKE_SOURCE = CdxSource(['/fake/path'], 'https://data.commoncrawl.org') + def test_filter_keyboard_interrupt_handling(caplog): """Test that KeyboardInterrupt is properly handled in the filter method.""" @@ -16,7 +20,7 @@ def test_filter_keyboard_interrupt_handling(caplog): # Set log level to capture WARNING messages caplog.set_level(logging.WARNING, logger='cdx_toolkit.filter_warc.warc_filter') - warc_filter = WARCFilter(cdx_paths=['/fake/path'], prefix_path='/fake/prefix', writer_info={'writer_id': 1}) + warc_filter = WARCFilter(source=_FAKE_SOURCE, prefix_path='/fake/prefix', writer_info={'writer_id': 1}) # Mock filter_async to raise KeyboardInterrupt with patch.object(warc_filter, 'filter_async', side_effect=KeyboardInterrupt('Simulated user interrupt')): @@ -35,7 +39,7 @@ def test_rotate_files_no_rotation_needed(): async def run_test(): warc_filter = WARCFilter( - cdx_paths=['/fake/path'], + source=_FAKE_SOURCE, prefix_path='/fake/prefix', writer_info={'writer_id': 1}, max_file_size=1000, # 1KB limit @@ -76,7 +80,7 @@ def test_rotate_files_rotation_needed_without_resource_records(): async def run_test(): warc_filter = WARCFilter( - cdx_paths=['/fake/path'], + source=_FAKE_SOURCE, prefix_path='/fake/prefix', writer_info={'writer_id': 1}, max_file_size=1000, # 1KB limit @@ -134,7 +138,7 @@ def test_rotate_files_rotation_needed_with_metadata_records(): async def run_test(): warc_filter = WARCFilter( - cdx_paths=['/fake/path'], + source=_FAKE_SOURCE, prefix_path='/fake/prefix', writer_info={'writer_id': 1}, max_file_size=1000, # 1KB limit @@ -189,7 +193,7 @@ def test_rotate_files_no_max_file_size_set(): async def run_test(): warc_filter = WARCFilter( - cdx_paths=['/fake/path'], + source=_FAKE_SOURCE, prefix_path='/fake/prefix', writer_info={'writer_id': 1}, max_file_size=None, # No limit @@ -230,7 +234,7 @@ def test_rotate_files_edge_case_exact_limit(): async def run_test(): warc_filter = WARCFilter( - cdx_paths=['/fake/path'], + source=_FAKE_SOURCE, prefix_path='/fake/prefix', writer_info={'writer_id': 1}, max_file_size=1000, # 1KB limit @@ -271,7 +275,7 @@ def test_rotate_files_edge_case_just_over_limit(): async def run_test(): warc_filter = WARCFilter( - cdx_paths=['/fake/path'], + source=_FAKE_SOURCE, prefix_path='/fake/prefix', writer_info={'writer_id': 1}, max_file_size=1000, # 1KB limit @@ -317,7 +321,7 @@ def test_rotate_files_kwargs_passed_through(): async def run_test(): warc_filter = WARCFilter( - cdx_paths=['/fake/path'], prefix_path='/fake/prefix', writer_info={'writer_id': 1}, max_file_size=1000 + source=_FAKE_SOURCE, prefix_path='/fake/prefix', writer_info={'writer_id': 1}, max_file_size=1000 ) mock_writer = AsyncMock() @@ -370,7 +374,7 @@ async def run_test(): caplog.set_level(logging.INFO, logger='cdx_toolkit.filter_warc.warc_filter') warc_filter = WARCFilter( - cdx_paths=['/fake/path'], prefix_path='/fake/prefix', writer_info={'writer_id': 1}, max_file_size=1000 + source=_FAKE_SOURCE, prefix_path='/fake/prefix', writer_info={'writer_id': 1}, max_file_size=1000 ) mock_writer = AsyncMock() @@ -402,9 +406,11 @@ async def run_test(): def test_log_writer(caplog): """Test log writer.""" + import logging + caplog.set_level(logging.INFO, logger='cdx_toolkit.filter_warc.warc_filter') warc_filter = WARCFilter( - cdx_paths=['/fake/path'], + source=_FAKE_SOURCE, prefix_path='/fake/prefix', writer_info={'writer_id': 1}, log_every_n=2, @@ -419,9 +425,11 @@ def test_log_writer(caplog): def test_log_reader(caplog): """Test log reader.""" + import logging + caplog.set_level(logging.INFO, logger='cdx_toolkit.filter_warc.warc_filter') warc_filter = WARCFilter( - cdx_paths=['/fake/path'], + source=_FAKE_SOURCE, prefix_path='/fake/prefix', writer_info={'writer_id': 1}, log_every_n=2, From 2f06c610891c5342494ff1dcd73d155996fd0c4c Mon Sep 17 00:00:00 2001 From: malteos Date: Fri, 12 Jun 2026 13:16:16 +0200 Subject: [PATCH 2/6] feat: add --domains (url_host_registered_domain) guided filter Extend the guided SQL filter (athena + duckdb) to match on url_host_registered_domain in addition to url_host_name. --hostnames and --domains can be combined (OR-ed); at least one (or a raw --query) is required. TLD optimizer hint is derived from both. Also make the DuckDB source resilient to transient S3 read timeouts (http_timeout/http_retries). Tests: unit coverage for domain-only / combined host+domain query building (athena + duckdb) and factory validation; gated DuckDB domain e2e (CC-MAIN-2026-17 / commoncrawl.org, --limit 10) verified live. --- cdx_toolkit/filter_warc/args.py | 16 ++++++-- cdx_toolkit/filter_warc/sources/duckdb.py | 13 +++++- cdx_toolkit/filter_warc/sources/factory.py | 29 +++++++------ cdx_toolkit/filter_warc/sources/sql_base.py | 41 +++++++++++++------ .../filter_warc/test_athena_query_builder.py | 20 +++++++++ tests/filter_warc/test_make_source.py | 31 +++++++++++++- tests/filter_warc/test_sql_sources_gated.py | 19 +++++++++ 7 files changed, 140 insertions(+), 29 deletions(-) diff --git a/cdx_toolkit/filter_warc/args.py b/cdx_toolkit/filter_warc/args.py index d238134..71afa90 100644 --- a/cdx_toolkit/filter_warc/args.py +++ b/cdx_toolkit/filter_warc/args.py @@ -32,9 +32,19 @@ def add_repackage_args(parser: argparse.ArgumentParser): type=str, nargs='+', default=None, - help=('Hostnames to filter for (whitelist) via the SQL index. Use this OR ' - '--query/--query-file (mutually exclusive). Combine with the global --crawl to ' - 'restrict the scan to specific crawls (strongly recommended for cost).'), + help=('Exact hostnames (url_host_name, e.g. www.example.com) to filter for via the SQL ' + 'index. Combine with --domains; mutually exclusive with --query/--query-file. ' + 'Combine with the global --crawl to restrict the scan to specific crawls ' + '(strongly recommended for cost).'), + ) + parser.add_argument( + '--domains', + type=str, + nargs='+', + default=None, + help=('Registered domains (url_host_registered_domain, e.g. example.com) to filter for via ' + 'the SQL index; also matches subdomains. Combine with --hostnames; mutually exclusive ' + 'with --query/--query-file.'), ) parser.add_argument( '--query', diff --git a/cdx_toolkit/filter_warc/sources/duckdb.py b/cdx_toolkit/filter_warc/sources/duckdb.py index f109cd9..2513bd2 100644 --- a/cdx_toolkit/filter_warc/sources/duckdb.py +++ b/cdx_toolkit/filter_warc/sources/duckdb.py @@ -40,6 +40,7 @@ def __init__( *, query: Optional[str] = None, hostnames: Optional[List[str]] = None, + domains: Optional[List[str]] = None, crawls: Optional[List[str]] = None, index_path: str, warc_download_prefix: Optional[str], @@ -48,6 +49,7 @@ def __init__( ): self.raw_query = query self.hostnames = hostnames + self.domains = domains self.crawls = crawls self.index_path = index_path self.warc_download_prefix = warc_download_prefix @@ -64,7 +66,10 @@ def _build_query(self) -> str: return self.raw_query from_clause = _build_from_clause(self.index_path, self.crawls) # crawl pruning is done in the FROM glob, so no crawl IN (...) in the WHERE - return build_sql(from_clause, self.hostnames, crawls=None, limit=self.limit) + return build_sql( + from_clause, self.hostnames, crawls=None, limit=self.limit, + url_host_registered_domains=self.domains, + ) def iter_range_jobs(self) -> Iterator[RangeJob]: if not _HAS_DUCKDB: @@ -79,6 +84,12 @@ def iter_range_jobs(self) -> Iterator[RangeJob]: try: con.execute('INSTALL httpfs; LOAD httpfs;') con.execute(f"SET s3_region='{self.region_name}';") + # Be resilient to transient S3 read timeouts on large parquet partitions. + for stmt in ('SET http_timeout=120000;', 'SET http_retries=5;'): + try: + con.execute(stmt) + except Exception as e: # pragma: no cover - setting unsupported on this duckdb + logger.debug('duckdb setting skipped: %s (%r)', stmt, e) cur = con.execute(query) col_names = [d[0] for d in cur.description] diff --git a/cdx_toolkit/filter_warc/sources/factory.py b/cdx_toolkit/filter_warc/sources/factory.py index 269ef3f..ba0e098 100644 --- a/cdx_toolkit/filter_warc/sources/factory.py +++ b/cdx_toolkit/filter_warc/sources/factory.py @@ -33,11 +33,12 @@ def make_source(args, *, warc_download_prefix: Optional[str], record_limit: int) raise ValueError(f'Invalid target source: {target} (available: cdx, sql, csv)') -def _resolve_sql_query_spec(args) -> Tuple[Optional[str], Optional[list], Optional[list]]: - """Validate the query-defining flags and return (raw_sql, hostnames, crawls). +def _resolve_sql_query_spec(args) -> Tuple[Optional[str], Optional[list], Optional[list], Optional[list]]: + """Validate the query-defining flags and return (raw_sql, hostnames, domains, crawls). - Exactly one of {--hostnames} / {--query|--query-file} must be given. For the - guided (hostnames) path, --crawl is resolved to concrete crawl names.""" + The guided path (--hostnames and/or --domains) and a raw query + (--query/--query-file) are mutually exclusive. For the guided path, --crawl is + resolved to concrete crawl names.""" raw_sql = args.query if args.query_file: if raw_sql: @@ -45,16 +46,17 @@ def _resolve_sql_query_spec(args) -> Tuple[Optional[str], Optional[list], Option with open(args.query_file) as f: raw_sql = f.read() - if raw_sql and args.hostnames: - raise ValueError('--query/--query-file are mutually exclusive with --hostnames') - if not raw_sql and not args.hostnames: - raise ValueError('the sql target requires either --hostnames or --query/--query-file') + has_guided = bool(args.hostnames) or bool(args.domains) + if raw_sql and has_guided: + raise ValueError('--query/--query-file are mutually exclusive with --hostnames/--domains') + if not raw_sql and not has_guided: + raise ValueError('the sql target requires --hostnames, --domains, or --query/--query-file') if raw_sql: - return raw_sql, None, None + return raw_sql, None, None, None crawls = resolve_crawl_names(args.crawl) if args.crawl else None - return None, args.hostnames, crawls + return None, args.hostnames, args.domains, crawls def _make_sql_source(args, warc_download_prefix, record_limit) -> RangeJobSource: @@ -62,7 +64,7 @@ def _make_sql_source(args, warc_download_prefix, record_limit) -> RangeJobSource if not engine: raise ValueError('--engine is required for --target-source sql (choices: athena, duckdb)') - raw_sql, hostnames, crawls = _resolve_sql_query_spec(args) + raw_sql, hostnames, domains, crawls = _resolve_sql_query_spec(args) limit = 0 if record_limit is None else record_limit if engine == 'athena': @@ -70,7 +72,9 @@ def _make_sql_source(args, warc_download_prefix, record_limit) -> RangeJobSource if not args.athena_s3_output: raise ValueError('--athena-s3-output is required for --engine athena') database = args.athena_database or 'ccindex' - query = raw_sql if raw_sql else build_athena_query(hostnames, crawls=crawls, limit=limit) + query = raw_sql if raw_sql else build_athena_query( + hostnames, crawls=crawls, limit=limit, url_host_registered_domains=domains, + ) n_crawls = None if raw_sql else (len(crawls) if crawls else None) return AthenaSource( query=query, @@ -85,6 +89,7 @@ def _make_sql_source(args, warc_download_prefix, record_limit) -> RangeJobSource return DuckDbSource( query=raw_sql, hostnames=hostnames, + domains=domains, crawls=crawls, index_path=args.duckdb_index_path, warc_download_prefix=warc_download_prefix, diff --git a/cdx_toolkit/filter_warc/sources/sql_base.py b/cdx_toolkit/filter_warc/sources/sql_base.py index c6f527e..e466682 100644 --- a/cdx_toolkit/filter_warc/sources/sql_base.py +++ b/cdx_toolkit/filter_warc/sources/sql_base.py @@ -31,18 +31,30 @@ def escape_sql_literal(value: str) -> str: return "'" + value + "'" -def build_where_sql(url_host_names: List[str], crawls: Optional[List[str]] = None) -> str: +def build_where_sql( + url_host_names: Optional[List[str]] = None, + crawls: Optional[List[str]] = None, + url_host_registered_domains: Optional[List[str]] = None, +) -> str: """Build the WHERE body (without the `WHERE` keyword) shared by all SQL engines. - If `crawls` is a non-empty list of crawl names (e.g. ['CC-MAIN-2025-33']), a - `crawl IN (...)` partition filter is added -- the main lever for reducing scan + The guided filter matches on `url_host_name` (exact host, e.g. www.example.com) + and/or `url_host_registered_domain` (e.g. example.com, which also covers its + subdomains); predicates are OR-ed together. At least one host or domain is + required. If `crawls` is a non-empty list of crawl names (e.g. ['CC-MAIN-2025-33']), + a `crawl IN (...)` partition filter is added -- the main lever for reducing scan cost. Engines differ only in their FROM clause (see build_sql).""" - if not url_host_names: - raise ValueError('an index query requires at least one hostname') + url_host_names = url_host_names or [] + domains = url_host_registered_domains or [] + if not url_host_names and not domains: + raise ValueError('an index query requires at least one hostname or registered domain') - tlds = sorted({h.split('.')[-1] for h in url_host_names}) + tlds = sorted({v.split('.')[-1] for v in (list(url_host_names) + list(domains))}) query_tlds = ' OR '.join(f'url_host_tld = {escape_sql_literal(t)}' for t in tlds) - query_hosts = ' OR '.join(f'url_host_name = {escape_sql_literal(h)}' for h in url_host_names) + + host_predicates = [f'url_host_name = {escape_sql_literal(h)}' for h in url_host_names] + host_predicates += [f'url_host_registered_domain = {escape_sql_literal(d)}' for d in domains] + query_hosts = ' OR '.join(host_predicates) clauses = [ "subset = 'warc'", @@ -59,15 +71,16 @@ def build_where_sql(url_host_names: List[str], crawls: Optional[List[str]] = Non def build_sql( from_clause: str, - url_host_names: List[str], + url_host_names: Optional[List[str]] = None, crawls: Optional[List[str]] = None, limit: int = 0, + url_host_registered_domains: Optional[List[str]] = None, ) -> str: """Assemble a full SELECT for the columnar index. `from_clause` is the text following FROM (e.g. `ccindex` for Athena, or a `read_parquet(...)` expression for DuckDB).""" - where_sql = build_where_sql(url_host_names, crawls) + where_sql = build_where_sql(url_host_names, crawls, url_host_registered_domains=url_host_registered_domains) limit_sql = f'\n LIMIT {limit}' if limit and limit > 0 else '' return f""" @@ -78,13 +91,17 @@ def build_sql( def build_athena_query( - url_host_names: List[str], + url_host_names: Optional[List[str]] = None, crawls: Optional[List[str]] = None, limit: int = 0, table: str = 'ccindex', + url_host_registered_domains: Optional[List[str]] = None, ) -> str: - """Athena flavour of build_sql (FROM
). Kept for the athena_job_generator shim.""" - return build_sql(table, url_host_names, crawls=crawls, limit=limit) + """Athena flavour of build_sql (FROM
).""" + return build_sql( + table, url_host_names, crawls=crawls, limit=limit, + url_host_registered_domains=url_host_registered_domains, + ) def validate_result_columns(column_names) -> None: diff --git a/tests/filter_warc/test_athena_query_builder.py b/tests/filter_warc/test_athena_query_builder.py index 8b2be57..304809e 100644 --- a/tests/filter_warc/test_athena_query_builder.py +++ b/tests/filter_warc/test_athena_query_builder.py @@ -43,6 +43,26 @@ def test_build_query_hostnames(): assert 'LIMIT' not in q +def test_build_query_domains_only(): + q = build_athena_query(url_host_registered_domains=['example.com']) + assert "url_host_registered_domain = 'example.com'" in q + assert "url_host_tld = 'com'" in q + assert 'url_host_name' not in q + + +def test_build_query_hostnames_and_domains(): + q = build_athena_query(['www.example.com'], url_host_registered_domains=['example.org']) + assert "url_host_name = 'www.example.com'" in q + assert "url_host_registered_domain = 'example.org'" in q + assert "url_host_tld = 'com'" in q + assert "url_host_tld = 'org'" in q + + +def test_build_query_requires_host_or_domain(): + with pytest.raises(ValueError): + build_athena_query() + + def test_build_query_with_crawls(): q = build_athena_query(['example.com'], crawls=['CC-MAIN-2025-33', 'CC-MAIN-2025-30']) assert "crawl IN ('CC-MAIN-2025-33', 'CC-MAIN-2025-30')" in q diff --git a/tests/filter_warc/test_make_source.py b/tests/filter_warc/test_make_source.py index d9a8508..ee60199 100644 --- a/tests/filter_warc/test_make_source.py +++ b/tests/filter_warc/test_make_source.py @@ -15,6 +15,7 @@ def make_args(**kw): target_source='cdx', engine=None, hostnames=None, + domains=None, query=None, query_file=None, athena_database=None, @@ -71,11 +72,39 @@ def test_sql_query_and_query_file_mutually_exclusive(tmp_path): build(target_source='sql', engine='athena', query='SELECT 1', query_file=str(f)) -def test_sql_neither_hostnames_nor_query(): +def test_sql_neither_hostnames_domains_nor_query(): with pytest.raises(ValueError): build(target_source='sql', engine='athena') +def test_sql_domains_and_query_mutually_exclusive(): + with pytest.raises(ValueError): + build(target_source='sql', engine='athena', domains=['example.com'], query='SELECT 1') + + +def test_athena_domains_only(): + src = build(target_source='sql', engine='athena', domains=['example.com']) + assert isinstance(src, AthenaSource) + assert 'url_host_registered_domain = \'example.com\'' in src.query + assert 'url_host_name' not in src.query + + +def test_athena_hostnames_and_domains_combined(): + src = build(target_source='sql', engine='athena', hostnames=['www.example.com'], domains=['example.org']) + assert "url_host_name = 'www.example.com'" in src.query + assert "url_host_registered_domain = 'example.org'" in src.query + # TLDs from both hostnames and domains + assert "url_host_tld = 'com'" in src.query + assert "url_host_tld = 'org'" in src.query + + +def test_duckdb_domains_only(): + src = build(target_source='sql', engine='duckdb', domains=['commoncrawl.org']) + q = src._build_query() + assert "url_host_registered_domain = 'commoncrawl.org'" in q + assert 'read_parquet' in q + + # --- athena --- def test_athena_requires_s3_output(): diff --git a/tests/filter_warc/test_sql_sources_gated.py b/tests/filter_warc/test_sql_sources_gated.py index f90990b..61fcde5 100644 --- a/tests/filter_warc/test_sql_sources_gated.py +++ b/tests/filter_warc/test_sql_sources_gated.py @@ -84,3 +84,22 @@ def test_repackage_sql_duckdb_e2e(tmpdir): '--confirm-cost', ], ) + + +@requires_aws_s3 +@requires_duckdb +def test_repackage_sql_duckdb_domain_e2e(tmpdir): + # Domain filtering (url_host_registered_domain) also matches subdomains; bound it + # with --limit to keep the live verification cheap. + _produce_and_consume( + tmpdir, + [ + '--crawl', CRAWL, + '--limit', '10', + 'repackage', + '--target-source=sql', + '--engine=duckdb', + '--domains', HOST, + '--confirm-cost', + ], + ) From db8e35ab8bc608865df388d198afded8dc36166c Mon Sep 17 00:00:00 2001 From: malteos Date: Wed, 17 Jun 2026 11:52:02 +0200 Subject: [PATCH 3/6] feat: carry extra SQL --query columns into range-jobs CSV A raw --query/--query-file can SELECT analysis columns (e.g. content_languages) beyond the required warc_filename/offset/length; those extra columns now flow through to the materialized range-jobs CSV (RangeJob.extra -> lazy CSV header). Rename the CSV fetch-job columns to the index's WARC names (warc_filename, warc_record_offset, warc_record_length; warc_url for self-contained) so they never collide with the index's own columns (e.g. the page-URL 'url' column). The CSV reader auto-detects mode from warc_url/warc_filename and warns+prefers warc_url if both are present. --- cdx_toolkit/filter_warc/data_classes.py | 6 +- cdx_toolkit/filter_warc/sources/athena.py | 8 +- cdx_toolkit/filter_warc/sources/csv.py | 85 ++++++++++++++----- cdx_toolkit/filter_warc/sources/duckdb.py | 13 ++- .../filter_warc/test_athena_query_builder.py | 52 +++++++++++- tests/filter_warc/test_command.py | 6 +- tests/filter_warc/test_csv_source.py | 84 +++++++++++++++++- tests/filter_warc/test_producer.py | 2 +- tests/filter_warc/test_sql_sources_gated.py | 32 +++++++ 9 files changed, 257 insertions(+), 31 deletions(-) diff --git a/cdx_toolkit/filter_warc/data_classes.py b/cdx_toolkit/filter_warc/data_classes.py index 40f190e..3c98955 100644 --- a/cdx_toolkit/filter_warc/data_classes.py +++ b/cdx_toolkit/filter_warc/data_classes.py @@ -1,5 +1,5 @@ import time -from dataclasses import dataclass +from dataclasses import dataclass, field from cdx_toolkit.filter_warc.s3_utils import is_s3_url, parse_s3_uri, with_retries from typing import Optional, Tuple @@ -51,6 +51,10 @@ class RangeJob: # used when materializing a non-self-contained range-jobs CSV. `url` stays # authoritative for fetching. filename: Optional[str] = None + # Extra columns from a raw SQL --query (e.g. content_languages), kept only for + # CSV materialization / analysis. Never used for fetching. Excluded from + # equality/hash so RangeJob stays hashable despite the dict value. + extra: Optional[dict] = field(default=None, compare=False) def is_s3(self): return is_s3_url(self.url) diff --git a/cdx_toolkit/filter_warc/sources/athena.py b/cdx_toolkit/filter_warc/sources/athena.py index 203048d..9425a3a 100644 --- a/cdx_toolkit/filter_warc/sources/athena.py +++ b/cdx_toolkit/filter_warc/sources/athena.py @@ -4,7 +4,11 @@ from cdx_toolkit.filter_warc.data_classes import RangeJob from cdx_toolkit.filter_warc.sources.base import RangeJobSource, CostEstimate -from cdx_toolkit.filter_warc.sources.sql_base import validate_result_columns, join_warc_url +from cdx_toolkit.filter_warc.sources.sql_base import ( + validate_result_columns, + join_warc_url, + REQUIRED_RESULT_COLUMNS, +) logger = logging.getLogger(__name__) @@ -162,12 +166,14 @@ def iter_range_jobs(client, query_execution_id: str, warc_download_prefix: Optio warc_filename = row['warc_filename'] warc_url = join_warc_url(warc_download_prefix, warc_filename) + extra = {k: v for k, v in row.items() if k not in REQUIRED_RESULT_COLUMNS} yield RangeJob( url=warc_url, offset=int(row['warc_record_offset']), length=int(row['warc_record_length']), filename=warc_filename, + extra=extra or None, ) diff --git a/cdx_toolkit/filter_warc/sources/csv.py b/cdx_toolkit/filter_warc/sources/csv.py index f9ac5aa..d7264df 100644 --- a/cdx_toolkit/filter_warc/sources/csv.py +++ b/cdx_toolkit/filter_warc/sources/csv.py @@ -11,39 +11,70 @@ logger = logging.getLogger(__name__) -FILENAME_FIELDS = ['filename', 'offset', 'length'] -URL_FIELDS = ['url', 'offset', 'length'] +# Fetch-job columns name the WARC *location* (mirroring the columnar index's column +# names) so they never collide with extra analysis columns a raw --query may SELECT +# -- e.g. the index's own `url` (the captured page URL) vs `warc_url` (download URL). +FILENAME_FIELDS = ['warc_filename', 'warc_record_offset', 'warc_record_length'] +URL_FIELDS = ['warc_url', 'warc_record_offset', 'warc_record_length'] + +# Columns the reader interprets as the fetch job; everything else round-trips as extra. +_KNOWN_FIELDS = {'warc_url', 'warc_filename', 'warc_record_offset', 'warc_record_length'} class RangeJobCsvWriter: """Write RangeJobs to a CSV (local or remote via fsspec). - Default mode writes the relative `filename` column (the consumer prepends the - WARC download prefix); `self_contained` mode writes the full `url` column.""" + Default mode writes the relative `warc_filename` column (the consumer prepends + the WARC download prefix); `self_contained` mode writes the full `warc_url` + column. Any extra columns carried on a RangeJob (from a raw SQL --query) are + appended after the fetch-job columns. The header is written lazily from the + first job so its extra columns can be discovered.""" def __init__(self, path: str, self_contained: bool = False): self.path = path self.self_contained = self_contained - self._fields = URL_FIELDS if self_contained else FILENAME_FIELDS + self._base_fields = URL_FIELDS if self_contained else FILENAME_FIELDS self._ctx = fsspec.open(path, 'wt', newline='') self._fh = self._ctx.__enter__() - self._writer = csv.DictWriter(self._fh, fieldnames=self._fields) + self._writer = None + + def _init_writer(self, fieldnames) -> None: + self._writer = csv.DictWriter( + self._fh, fieldnames=fieldnames, extrasaction='ignore', restval='' + ) self._writer.writeheader() def write(self, job: RangeJob) -> None: if self.self_contained: - row = {'url': job.url, 'offset': job.offset, 'length': job.length} + row = { + 'warc_url': job.url, + 'warc_record_offset': job.offset, + 'warc_record_length': job.length, + } else: if job.filename is None: raise ValueError( 'cannot write a non-self-contained range-jobs CSV: RangeJob.filename is ' 'missing; pass --csv-self-contained to write full URLs instead' ) - row = {'filename': job.filename, 'offset': job.offset, 'length': job.length} + row = { + 'warc_filename': job.filename, + 'warc_record_offset': job.offset, + 'warc_record_length': job.length, + } + if job.extra: + row.update(job.extra) + + if self._writer is None: + extra_fields = [k for k in (job.extra or {}) if k not in self._base_fields] + self._init_writer(self._base_fields + extra_fields) self._writer.writerow(row) def close(self) -> None: if self._ctx is not None: + # No rows written: still emit the base header for a valid (empty) file. + if self._writer is None: + self._init_writer(self._base_fields) self._ctx.__exit__(None, None, None) self._ctx = None self._fh = None @@ -52,9 +83,11 @@ def close(self) -> None: class CsvSource(RangeJobSource): """RangeJobs read from a CSV/TSV. - Mode is auto-detected from the header: a `url` column => self-contained (used - as-is); a `filename` column => the WARC download prefix is prepended. TSV is - detected from a `.tsv`/`.tsv.gz` extension; `.gz` inputs are decompressed.""" + Mode is auto-detected from the header: a `warc_url` column => self-contained + (used as-is); a `warc_filename` column => the WARC download prefix is prepended. + If both are present, a warning is logged and `warc_url` is used. Any other + columns round-trip onto RangeJob.extra. TSV is detected from a `.tsv`/`.tsv.gz` + extension; `.gz` inputs are decompressed.""" def __init__(self, path: str, warc_download_prefix: Optional[str]): self.path = path @@ -68,23 +101,35 @@ def iter_range_jobs(self) -> Iterator[RangeJob]: with fsspec.open(self.path, 'rt', newline='', compression=compression) as fh: reader = csv.DictReader(fh, delimiter=delimiter) fields = set(reader.fieldnames or []) - if 'url' in fields: + has_url = 'warc_url' in fields + has_filename = 'warc_filename' in fields + if has_url and has_filename: + logger.warning( + 'range-jobs CSV %s has both `warc_url` and `warc_filename` columns; ' + 'using `warc_url` for fetching', self.path + ) mode_url = True - elif 'filename' in fields: + elif has_url: + mode_url = True + elif has_filename: mode_url = False else: raise ValueError( - f'range-jobs CSV {self.path} must have a `url` or `filename` column ' + f'range-jobs CSV {self.path} must have a `warc_url` or `warc_filename` column ' f'(got header: {reader.fieldnames})' ) for row in reader: - offset = int(row['offset']) - length = int(row['length']) + offset = int(row['warc_record_offset']) + length = int(row['warc_record_length']) if mode_url: - url = row['url'] - filename = None + url = row['warc_url'] + filename = row.get('warc_filename') else: - filename = row['filename'] + filename = row['warc_filename'] url = join_warc_url(self.warc_download_prefix, filename) - yield RangeJob(url=url, offset=offset, length=length, filename=filename) + extra = {k: v for k, v in row.items() if k not in _KNOWN_FIELDS} + yield RangeJob( + url=url, offset=offset, length=length, + filename=filename, extra=extra or None, + ) diff --git a/cdx_toolkit/filter_warc/sources/duckdb.py b/cdx_toolkit/filter_warc/sources/duckdb.py index 2513bd2..b794614 100644 --- a/cdx_toolkit/filter_warc/sources/duckdb.py +++ b/cdx_toolkit/filter_warc/sources/duckdb.py @@ -3,7 +3,12 @@ from cdx_toolkit.filter_warc.data_classes import RangeJob from cdx_toolkit.filter_warc.sources.base import RangeJobSource, CostEstimate -from cdx_toolkit.filter_warc.sources.sql_base import build_sql, validate_result_columns, join_warc_url +from cdx_toolkit.filter_warc.sources.sql_base import ( + build_sql, + validate_result_columns, + join_warc_url, + REQUIRED_RESULT_COLUMNS, +) logger = logging.getLogger(__name__) @@ -103,11 +108,17 @@ def iter_range_jobs(self) -> Iterator[RangeJob]: for row in rows: warc_filename = row[idx['warc_filename']] warc_url = join_warc_url(self.warc_download_prefix, warc_filename) + extra = { + name: row[i] + for i, name in enumerate(col_names) + if name not in REQUIRED_RESULT_COLUMNS + } yield RangeJob( url=warc_url, offset=int(row[idx['warc_record_offset']]), length=int(row[idx['warc_record_length']]), filename=warc_filename, + extra=extra or None, ) finally: con.close() diff --git a/tests/filter_warc/test_athena_query_builder.py b/tests/filter_warc/test_athena_query_builder.py index 304809e..256af68 100644 --- a/tests/filter_warc/test_athena_query_builder.py +++ b/tests/filter_warc/test_athena_query_builder.py @@ -6,7 +6,7 @@ validate_result_columns, join_warc_url, ) -from cdx_toolkit.filter_warc.sources.athena import run_athena_query +from cdx_toolkit.filter_warc.sources.athena import run_athena_query, iter_range_jobs class _FakeAthenaClient: @@ -130,6 +130,56 @@ def test_join_warc_url_absolute_filename_passthrough(): assert join_warc_url('', 'https://host/x.warc.gz') == 'https://host/x.warc.gz' +class _FakePaginator: + """Yields Athena get_query_results-style pages (first row is the header).""" + + def __init__(self, pages): + self._pages = pages + + def paginate(self, **kwargs): + return iter(self._pages) + + +class _FakeResultsClient: + def __init__(self, pages): + self._pages = pages + + def get_paginator(self, name): + assert name == 'get_query_results' + return _FakePaginator(self._pages) + + +def _athena_page(columns, *value_rows): + header = {'Data': [{'VarCharValue': c} for c in columns]} + data = [{'Data': [{'VarCharValue': v} for v in row]} for row in value_rows] + return {'ResultSet': {'Rows': [header] + data}} + + +def test_iter_range_jobs_carries_extra_columns(): + pages = [_athena_page( + ['warc_filename', 'warc_record_offset', 'warc_record_length', 'content_languages', 'url'], + ['crawl-data/x.warc.gz', '100', '200', 'eng', 'https://example.com/page'], + )] + client = _FakeResultsClient(pages) + jobs = list(iter_range_jobs(client, 'qid', 'https://data.commoncrawl.org')) + assert len(jobs) == 1 + job = jobs[0] + assert job.url == 'https://data.commoncrawl.org/crawl-data/x.warc.gz' + assert job.offset == 100 and job.length == 200 + assert job.filename == 'crawl-data/x.warc.gz' + # extra carries only the non-required columns (page URL distinct from warc URL) + assert job.extra == {'content_languages': 'eng', 'url': 'https://example.com/page'} + + +def test_iter_range_jobs_no_extra_columns_is_none(): + pages = [_athena_page( + ['warc_filename', 'warc_record_offset', 'warc_record_length'], + ['crawl-data/x.warc.gz', '1', '2'], + )] + jobs = list(iter_range_jobs(_FakeResultsClient(pages), 'qid', 'https://data.commoncrawl.org')) + assert jobs[0].extra is None + + def test_run_athena_query_logs_sql(caplog): import logging client = _FakeAthenaClient() diff --git a/tests/filter_warc/test_command.py b/tests/filter_warc/test_command.py index 4da1233..ad2f164 100644 --- a/tests/filter_warc/test_command.py +++ b/tests/filter_warc/test_command.py @@ -171,7 +171,7 @@ def _assert_repackaged_warc(warc_path, metadata_record_path): def test_repackage_csv_materialize_filename(tmpdir): """--no-fetch produces a filename-based range-jobs CSV without fetching WARCs.""" csv_path, rows = _produce_range_jobs_csv(tmpdir, 'ranges.csv') - assert set(rows[0].keys()) == {'filename', 'offset', 'length'} + assert set(rows[0].keys()) == {'warc_filename', 'warc_record_offset', 'warc_record_length'} assert len(rows) == 10 # No WARC was written for the default --prefix assert not any(name.endswith('.warc.gz') for name in os.listdir(str(tmpdir))) @@ -180,9 +180,9 @@ def test_repackage_csv_materialize_filename(tmpdir): def test_repackage_csv_materialize_self_contained(tmpdir): """--csv-self-contained produces a url-based range-jobs CSV.""" csv_path, rows = _produce_range_jobs_csv(tmpdir, 'ranges_url.csv', self_contained=True) - assert set(rows[0].keys()) == {'url', 'offset', 'length'} + assert set(rows[0].keys()) == {'warc_url', 'warc_record_offset', 'warc_record_length'} assert len(rows) == 10 - assert rows[0]['url'].startswith('https://data.commoncrawl.org/') + assert rows[0]['warc_url'].startswith('https://data.commoncrawl.org/') def test_cli_repackage_csv_roundtrip(tmpdir): diff --git a/tests/filter_warc/test_csv_source.py b/tests/filter_warc/test_csv_source.py index d6a6a84..198c530 100644 --- a/tests/filter_warc/test_csv_source.py +++ b/tests/filter_warc/test_csv_source.py @@ -17,7 +17,11 @@ def test_writer_filename_mode_and_read(tmp_path): with open(path, newline='') as f: rows = list(csv.DictReader(f)) - assert rows[0] == {'filename': 'crawl-data/x.warc.gz', 'offset': '10', 'length': '20'} + assert rows[0] == { + 'warc_filename': 'crawl-data/x.warc.gz', + 'warc_record_offset': '10', + 'warc_record_length': '20', + } jobs = list(CsvSource(path, 'https://data.commoncrawl.org').iter_range_jobs()) assert jobs[0].url == 'https://data.commoncrawl.org/crawl-data/x.warc.gz' @@ -33,7 +37,11 @@ def test_writer_self_contained_mode_and_read(tmp_path): with open(path, newline='') as f: rows = list(csv.DictReader(f)) - assert rows[0] == {'url': 's3://commoncrawl/crawl-data/x.warc.gz', 'offset': '5', 'length': '7'} + assert rows[0] == { + 'warc_url': 's3://commoncrawl/crawl-data/x.warc.gz', + 'warc_record_offset': '5', + 'warc_record_length': '7', + } # url used as-is regardless of the (ignored) prefix jobs = list(CsvSource(path, 'https://ignored').iter_range_jobs()) @@ -59,8 +67,78 @@ def test_csv_source_missing_columns(tmp_path): def test_csv_source_tsv(tmp_path): path = str(tmp_path / 'ranges.tsv') with open(path, 'w') as f: - f.write('filename\toffset\tlength\n') + f.write('warc_filename\twarc_record_offset\twarc_record_length\n') f.write('crawl-data/x.warc.gz\t10\t20\n') jobs = list(CsvSource(path, 'https://data.commoncrawl.org').iter_range_jobs()) assert jobs[0].url == 'https://data.commoncrawl.org/crawl-data/x.warc.gz' assert jobs[0].offset == 10 and jobs[0].length == 20 + + +def test_writer_extra_columns_filename_mode(tmp_path): + path = str(tmp_path / 'ranges.csv') + w = RangeJobCsvWriter(path, self_contained=False) + w.write(RangeJob( + url='https://data.commoncrawl.org/crawl-data/x.warc.gz', + offset=10, length=20, filename='crawl-data/x.warc.gz', + extra={'content_languages': 'eng', 'url': 'https://example.com/page'}, + )) + w.close() + + with open(path, newline='') as f: + reader = csv.DictReader(f) + assert reader.fieldnames == [ + 'warc_filename', 'warc_record_offset', 'warc_record_length', + 'content_languages', 'url', + ] + rows = list(reader) + assert rows[0]['content_languages'] == 'eng' + assert rows[0]['url'] == 'https://example.com/page' + + # extras round-trip back onto RangeJob.extra + jobs = list(CsvSource(path, 'https://data.commoncrawl.org').iter_range_jobs()) + assert jobs[0].extra == {'content_languages': 'eng', 'url': 'https://example.com/page'} + assert jobs[0].filename == 'crawl-data/x.warc.gz' + + +def test_writer_extra_columns_self_contained_mode(tmp_path): + path = str(tmp_path / 'ranges_url.csv') + w = RangeJobCsvWriter(path, self_contained=True) + w.write(RangeJob( + url='https://data.commoncrawl.org/crawl-data/x.warc.gz', + offset=1, length=2, filename='crawl-data/x.warc.gz', + extra={'content_languages': 'fra'}, + )) + w.close() + + with open(path, newline='') as f: + reader = csv.DictReader(f) + assert reader.fieldnames == [ + 'warc_url', 'warc_record_offset', 'warc_record_length', 'content_languages', + ] + rows = list(reader) + assert rows[0]['content_languages'] == 'fra' + + +def test_writer_empty_writes_base_header(tmp_path): + path = str(tmp_path / 'empty.csv') + w = RangeJobCsvWriter(path, self_contained=False) + w.close() + with open(path, newline='') as f: + reader = csv.DictReader(f) + rows = list(reader) + assert reader.fieldnames == ['warc_filename', 'warc_record_offset', 'warc_record_length'] + assert rows == [] + + +def test_csv_source_both_url_and_filename_warns(tmp_path, caplog): + import logging + path = str(tmp_path / 'both.csv') + with open(path, 'w') as f: + f.write('warc_url,warc_filename,warc_record_offset,warc_record_length\n') + f.write('s3://commoncrawl/crawl-data/x.warc.gz,crawl-data/x.warc.gz,5,7\n') + with caplog.at_level(logging.WARNING): + jobs = list(CsvSource(path, 'https://ignored').iter_range_jobs()) + assert 'both' in caplog.text.lower() + # warc_url wins for fetching + assert jobs[0].url == 's3://commoncrawl/crawl-data/x.warc.gz' + assert jobs[0].filename == 'crawl-data/x.warc.gz' diff --git a/tests/filter_warc/test_producer.py b/tests/filter_warc/test_producer.py index ba6efcd..7384f41 100644 --- a/tests/filter_warc/test_producer.py +++ b/tests/filter_warc/test_producer.py @@ -39,7 +39,7 @@ def test_no_fetch_materializes_csv(tmp_path): with open(out, newline='') as f: rows = list(csv.DictReader(f)) assert len(rows) == 3 - assert set(rows[0].keys()) == {'filename', 'offset', 'length'} + assert set(rows[0].keys()) == {'warc_filename', 'warc_record_offset', 'warc_record_length'} def test_producer_emits_stops_even_when_source_raises(tmp_path): diff --git a/tests/filter_warc/test_sql_sources_gated.py b/tests/filter_warc/test_sql_sources_gated.py index 61fcde5..8993fcd 100644 --- a/tests/filter_warc/test_sql_sources_gated.py +++ b/tests/filter_warc/test_sql_sources_gated.py @@ -86,6 +86,38 @@ def test_repackage_sql_duckdb_e2e(tmpdir): ) +@requires_aws_s3 +@requires_duckdb +def test_repackage_sql_duckdb_extra_columns_e2e(tmpdir): + # A raw --query that SELECTs an extra analysis column (content_languages) should + # carry that column through to the materialized range-jobs CSV. Single crawl + + # LIMIT keeps the scan cheap. + csv_path = os.path.join(str(tmpdir), 'ranges.csv') + query = ( + 'SELECT warc_filename, warc_record_offset, warc_record_length, content_languages ' + "FROM read_parquet(" + f"'s3://commoncrawl/cc-index/table/cc-main/warc/crawl={CRAWL}/subset=warc/*.parquet', " + 'hive_partitioning=true) ' + f"WHERE url_host_registered_domain = '{HOST}' LIMIT 10" + ) + main(args=[ + 'repackage', + '--target-source=sql', + '--engine=duckdb', + f'--query={query}', + f'--range-jobs-output={csv_path}', + '--no-fetch', + '--confirm-cost', + ]) + + with open(csv_path, newline='') as f: + reader = csv.DictReader(f) + assert 'content_languages' in (reader.fieldnames or []) + rows = list(reader) + assert len(rows) > 0 + assert any(r.get('content_languages') for r in rows), 'expected a content_languages value' + + @requires_aws_s3 @requires_duckdb def test_repackage_sql_duckdb_domain_e2e(tmpdir): From 05b929c2b40d15560c48ce45d13c92ed4cfa8296 Mon Sep 17 00:00:00 2001 From: malteos Date: Wed, 17 Jun 2026 13:50:00 +0200 Subject: [PATCH 4/6] feat: sort range jobs by (warc_filename, warc_record_offset) for read locality Group records of the same WARC file with ascending offsets to improve S3 range-read locality (and to enable later coalescing). Applied as an ORDER BY on guided SQL queries (Athena + DuckDB) and as an in-memory sort when loading a CSV source. Raw --query and cdx sources keep their order. New --no-sort-ranges opts out (e.g. already-sorted / very large CSV, or to preserve original order). --- cdx_toolkit/filter_warc/args.py | 11 ++++- cdx_toolkit/filter_warc/sources/csv.py | 23 ++++++++-- cdx_toolkit/filter_warc/sources/duckdb.py | 4 +- cdx_toolkit/filter_warc/sources/factory.py | 14 +++++-- cdx_toolkit/filter_warc/sources/sql_base.py | 15 ++++++- .../filter_warc/test_athena_query_builder.py | 14 +++++++ tests/filter_warc/test_command.py | 6 ++- tests/filter_warc/test_csv_source.py | 42 +++++++++++++++++++ tests/filter_warc/test_make_source.py | 36 ++++++++++++++++ 9 files changed, 153 insertions(+), 12 deletions(-) diff --git a/cdx_toolkit/filter_warc/args.py b/cdx_toolkit/filter_warc/args.py index 71afa90..b40d731 100644 --- a/cdx_toolkit/filter_warc/args.py +++ b/cdx_toolkit/filter_warc/args.py @@ -106,7 +106,16 @@ def add_repackage_args(parser: argparse.ArgumentParser): parser.add_argument( '--csv-self-contained', action='store_true', - help='Write full URLs (url,offset,length) to --range-jobs-output instead of relative filenames.', + help='Write full URLs (warc_url,...) to --range-jobs-output instead of relative filenames.', + ) + parser.add_argument( + '--no-sort-ranges', + action='store_true', + help=('Do not sort range jobs by (warc_filename, warc_record_offset) before fetching. ' + 'Sorting (the default) groups records of the same WARC file with ascending offsets ' + 'for better S3 range-read locality; it adds an ORDER BY to a guided SQL query and ' + 'buffers a CSV source in memory. Disable for an already-sorted or very large CSV, or ' + 'to preserve a raw query/CSV order.'), ) parser.add_argument('--prefix', default='TEST', help='prefix for the output warc filename') parser.add_argument( diff --git a/cdx_toolkit/filter_warc/sources/csv.py b/cdx_toolkit/filter_warc/sources/csv.py index d7264df..5f6cfd2 100644 --- a/cdx_toolkit/filter_warc/sources/csv.py +++ b/cdx_toolkit/filter_warc/sources/csv.py @@ -87,17 +87,24 @@ class CsvSource(RangeJobSource): (used as-is); a `warc_filename` column => the WARC download prefix is prepended. If both are present, a warning is logged and `warc_url` is used. Any other columns round-trip onto RangeJob.extra. TSV is detected from a `.tsv`/`.tsv.gz` - extension; `.gz` inputs are decompressed.""" + extension; `.gz` inputs are decompressed. - def __init__(self, path: str, warc_download_prefix: Optional[str]): + With `sort` (the default), jobs are sorted by (WARC file, record offset) before + being emitted -- grouping same-file records with ascending offsets for better S3 + range-read locality. This buffers all rows in memory; pass sort=False to stream + an already-sorted file without buffering.""" + + def __init__(self, path: str, warc_download_prefix: Optional[str], sort: bool = True): self.path = path self.warc_download_prefix = warc_download_prefix + self.sort = sort def iter_range_jobs(self) -> Iterator[RangeJob]: path = str(self.path) delimiter = '\t' if path.endswith(('.tsv', '.tsv.gz')) else ',' compression = 'gzip' if path.endswith('.gz') else None + buffered = [] if self.sort else None with fsspec.open(self.path, 'rt', newline='', compression=compression) as fh: reader = csv.DictReader(fh, delimiter=delimiter) fields = set(reader.fieldnames or []) @@ -129,7 +136,17 @@ def iter_range_jobs(self) -> Iterator[RangeJob]: filename = row['warc_filename'] url = join_warc_url(self.warc_download_prefix, filename) extra = {k: v for k, v in row.items() if k not in _KNOWN_FIELDS} - yield RangeJob( + job = RangeJob( url=url, offset=offset, length=length, filename=filename, extra=extra or None, ) + if buffered is None: + yield job + else: + buffered.append(job) + + if buffered is not None: + # Group by WARC file (filename when available, else the full url) then offset. + buffered.sort(key=lambda j: (j.filename or j.url or '', j.offset)) + for job in buffered: + yield job diff --git a/cdx_toolkit/filter_warc/sources/duckdb.py b/cdx_toolkit/filter_warc/sources/duckdb.py index b794614..8886ef4 100644 --- a/cdx_toolkit/filter_warc/sources/duckdb.py +++ b/cdx_toolkit/filter_warc/sources/duckdb.py @@ -50,6 +50,7 @@ def __init__( index_path: str, warc_download_prefix: Optional[str], limit: int = 0, + sort: bool = True, region_name: str = 'us-east-1', ): self.raw_query = query @@ -59,6 +60,7 @@ def __init__( self.index_path = index_path self.warc_download_prefix = warc_download_prefix self.limit = limit + self.sort = sort self.region_name = region_name def estimate_cost(self) -> CostEstimate: @@ -73,7 +75,7 @@ def _build_query(self) -> str: # crawl pruning is done in the FROM glob, so no crawl IN (...) in the WHERE return build_sql( from_clause, self.hostnames, crawls=None, limit=self.limit, - url_host_registered_domains=self.domains, + url_host_registered_domains=self.domains, order_by=self.sort, ) def iter_range_jobs(self) -> Iterator[RangeJob]: diff --git a/cdx_toolkit/filter_warc/sources/factory.py b/cdx_toolkit/filter_warc/sources/factory.py index ba0e098..a7251ef 100644 --- a/cdx_toolkit/filter_warc/sources/factory.py +++ b/cdx_toolkit/filter_warc/sources/factory.py @@ -15,6 +15,10 @@ def make_source(args, *, warc_download_prefix: Optional[str], record_limit: int) Centralises all source/engine validation (engine required iff sql; hostnames/query/query-file mutual exclusivity; required connection options).""" target = args.target_source + # Sort range jobs by (warc_filename, warc_record_offset) for fetch-time read + # locality unless explicitly disabled. Applies to the guided SQL query (ORDER BY) + # and to CSV loading; cdx files are out of scope (already SURT-ordered). + sort = not getattr(args, 'no_sort_ranges', False) if target == 'cdx': from cdx_toolkit.filter_warc.sources.cdx import CdxSource @@ -25,10 +29,10 @@ def make_source(args, *, warc_download_prefix: Optional[str], record_limit: int) from cdx_toolkit.filter_warc.sources.csv import CsvSource if not args.csv_path: raise ValueError('--csv-path is required for --target-source csv') - return CsvSource(args.csv_path, warc_download_prefix) + return CsvSource(args.csv_path, warc_download_prefix, sort=sort) if target == 'sql': - return _make_sql_source(args, warc_download_prefix, record_limit) + return _make_sql_source(args, warc_download_prefix, record_limit, sort=sort) raise ValueError(f'Invalid target source: {target} (available: cdx, sql, csv)') @@ -59,7 +63,7 @@ def _resolve_sql_query_spec(args) -> Tuple[Optional[str], Optional[list], Option return None, args.hostnames, args.domains, crawls -def _make_sql_source(args, warc_download_prefix, record_limit) -> RangeJobSource: +def _make_sql_source(args, warc_download_prefix, record_limit, *, sort: bool = True) -> RangeJobSource: engine = args.engine if not engine: raise ValueError('--engine is required for --target-source sql (choices: athena, duckdb)') @@ -72,8 +76,11 @@ def _make_sql_source(args, warc_download_prefix, record_limit) -> RangeJobSource if not args.athena_s3_output: raise ValueError('--athena-s3-output is required for --engine athena') database = args.athena_database or 'ccindex' + # A raw --query is the user's responsibility to order; only the guided query + # gets the ORDER BY (warc_filename, warc_record_offset) for read locality. query = raw_sql if raw_sql else build_athena_query( hostnames, crawls=crawls, limit=limit, url_host_registered_domains=domains, + order_by=sort, ) n_crawls = None if raw_sql else (len(crawls) if crawls else None) return AthenaSource( @@ -94,6 +101,7 @@ def _make_sql_source(args, warc_download_prefix, record_limit) -> RangeJobSource index_path=args.duckdb_index_path, warc_download_prefix=warc_download_prefix, limit=limit, + sort=sort, ) raise ValueError(f'Invalid --engine: {engine} (choices: athena, duckdb)') diff --git a/cdx_toolkit/filter_warc/sources/sql_base.py b/cdx_toolkit/filter_warc/sources/sql_base.py index e466682..3b5dcb6 100644 --- a/cdx_toolkit/filter_warc/sources/sql_base.py +++ b/cdx_toolkit/filter_warc/sources/sql_base.py @@ -69,25 +69,34 @@ def build_where_sql( return '\n AND '.join(clauses) +# Sorting range jobs by (warc_filename, warc_record_offset) groups records of the +# same WARC file with ascending offsets, which improves S3 range-read locality (and +# enables coalescing adjacent ranges) at fetch time. +ORDER_BY_SQL = 'ORDER BY warc_filename, warc_record_offset' + + def build_sql( from_clause: str, url_host_names: Optional[List[str]] = None, crawls: Optional[List[str]] = None, limit: int = 0, url_host_registered_domains: Optional[List[str]] = None, + order_by: bool = True, ) -> str: """Assemble a full SELECT for the columnar index. `from_clause` is the text following FROM (e.g. `ccindex` for Athena, or a - `read_parquet(...)` expression for DuckDB).""" + `read_parquet(...)` expression for DuckDB). When `order_by` is set, results are + sorted by (warc_filename, warc_record_offset) for fetch-time read locality.""" where_sql = build_where_sql(url_host_names, crawls, url_host_registered_domains=url_host_registered_domains) + order_sql = f'\n {ORDER_BY_SQL}' if order_by else '' limit_sql = f'\n LIMIT {limit}' if limit and limit > 0 else '' return f""" SELECT warc_filename, warc_record_offset, warc_record_length FROM {from_clause} - WHERE {where_sql}{limit_sql}""" + WHERE {where_sql}{order_sql}{limit_sql}""" def build_athena_query( @@ -96,11 +105,13 @@ def build_athena_query( limit: int = 0, table: str = 'ccindex', url_host_registered_domains: Optional[List[str]] = None, + order_by: bool = True, ) -> str: """Athena flavour of build_sql (FROM
).""" return build_sql( table, url_host_names, crawls=crawls, limit=limit, url_host_registered_domains=url_host_registered_domains, + order_by=order_by, ) diff --git a/tests/filter_warc/test_athena_query_builder.py b/tests/filter_warc/test_athena_query_builder.py index 256af68..37bf6e3 100644 --- a/tests/filter_warc/test_athena_query_builder.py +++ b/tests/filter_warc/test_athena_query_builder.py @@ -73,6 +73,20 @@ def test_build_query_limit(): assert 'LIMIT' not in build_athena_query(['example.com'], limit=0) +def test_build_query_orders_by_default(): + q = build_athena_query(['example.com']) + assert 'ORDER BY warc_filename, warc_record_offset' in q + + +def test_build_query_order_by_false(): + assert 'ORDER BY' not in build_athena_query(['example.com'], order_by=False) + + +def test_build_query_order_by_precedes_limit(): + q = build_athena_query(['example.com'], limit=10) + assert q.index('ORDER BY') < q.index('LIMIT') + + def test_build_query_requires_hostnames(): with pytest.raises(ValueError): build_athena_query([]) diff --git a/tests/filter_warc/test_command.py b/tests/filter_warc/test_command.py index ad2f164..5648217 100644 --- a/tests/filter_warc/test_command.py +++ b/tests/filter_warc/test_command.py @@ -160,8 +160,10 @@ def _assert_repackaged_warc(warc_path, metadata_record_path): metadata_record_headers = record.rec_headers assert len(response_records) == 10, 'Invalid record count' - assert 'Catalogue en ligne Mission de France' in response_contents[0], 'Invalid response content' - assert 'dojo/dijit/themes/tundra/tundra' in response_contents[9], 'Invalid response content' + # CsvSource sorts by (warc_filename, offset) by default, so assert content + # presence independent of record order. + assert any('Catalogue en ligne Mission de France' in c for c in response_contents), 'Invalid response content' + assert any('dojo/dijit/themes/tundra/tundra' in c for c in response_contents), 'Invalid response content' assert metadata_record is not None, 'Metadata record not set' assert metadata_record_headers.get('WARC-Payload-Digest') == 'sha1:VXA2A5YUS3TAY36AUO6MACRMNOH5RXG2', ( 'Invalid metadata block digest' diff --git a/tests/filter_warc/test_csv_source.py b/tests/filter_warc/test_csv_source.py index 198c530..b48aa63 100644 --- a/tests/filter_warc/test_csv_source.py +++ b/tests/filter_warc/test_csv_source.py @@ -130,6 +130,48 @@ def test_writer_empty_writes_base_header(tmp_path): assert rows == [] +def test_csv_source_sorts_by_filename_then_offset(tmp_path): + path = str(tmp_path / 'ranges.csv') + with open(path, 'w') as f: + f.write('warc_filename,warc_record_offset,warc_record_length\n') + # deliberately out of order: file b before a, and offsets descending + f.write('b.warc.gz,50,1\n') + f.write('a.warc.gz,200,1\n') + f.write('a.warc.gz,10,1\n') + jobs = list(CsvSource(path, 'https://data.commoncrawl.org').iter_range_jobs()) + assert [(j.filename, j.offset) for j in jobs] == [ + ('a.warc.gz', 10), ('a.warc.gz', 200), ('b.warc.gz', 50), + ] + + +def test_csv_source_no_sort_preserves_order(tmp_path): + path = str(tmp_path / 'ranges.csv') + with open(path, 'w') as f: + f.write('warc_filename,warc_record_offset,warc_record_length\n') + f.write('b.warc.gz,50,1\n') + f.write('a.warc.gz,200,1\n') + f.write('a.warc.gz,10,1\n') + jobs = list(CsvSource(path, 'https://data.commoncrawl.org', sort=False).iter_range_jobs()) + assert [(j.filename, j.offset) for j in jobs] == [ + ('b.warc.gz', 50), ('a.warc.gz', 200), ('a.warc.gz', 10), + ] + + +def test_csv_source_sorts_self_contained_by_url(tmp_path): + path = str(tmp_path / 'ranges.csv') + with open(path, 'w') as f: + f.write('warc_url,warc_record_offset,warc_record_length\n') + f.write('s3://commoncrawl/b.warc.gz,5,1\n') + f.write('s3://commoncrawl/a.warc.gz,9,1\n') + f.write('s3://commoncrawl/a.warc.gz,2,1\n') + jobs = list(CsvSource(path, 'https://ignored').iter_range_jobs()) + assert [(j.url, j.offset) for j in jobs] == [ + ('s3://commoncrawl/a.warc.gz', 2), + ('s3://commoncrawl/a.warc.gz', 9), + ('s3://commoncrawl/b.warc.gz', 5), + ] + + def test_csv_source_both_url_and_filename_warns(tmp_path, caplog): import logging path = str(tmp_path / 'both.csv') diff --git a/tests/filter_warc/test_make_source.py b/tests/filter_warc/test_make_source.py index ee60199..0dba91d 100644 --- a/tests/filter_warc/test_make_source.py +++ b/tests/filter_warc/test_make_source.py @@ -25,6 +25,7 @@ def make_args(**kw): cdx_path=None, cdx_glob=None, crawl=None, + no_sort_ranges=False, ) defaults.update(kw) return Namespace(**defaults) @@ -159,3 +160,38 @@ def test_duckdb_no_crawl_unbounded(): src = build(target_source='sql', engine='duckdb', hostnames=['commoncrawl.org']) assert src.estimate_cost().n_crawls is None assert 'crawl=*' in src._build_query() + + +# --- sort by (warc_filename, warc_record_offset) --- + +def test_athena_built_query_orders_by_default(): + src = build(target_source='sql', engine='athena', hostnames=['example.com']) + assert 'ORDER BY warc_filename, warc_record_offset' in src.query + + +def test_athena_built_query_no_sort(): + src = build(target_source='sql', engine='athena', hostnames=['example.com'], no_sort_ranges=True) + assert 'ORDER BY' not in src.query + + +def test_athena_raw_query_not_reordered(): + # a raw query is the user's responsibility; we must not inject ORDER BY + src = build(target_source='sql', engine='athena', query='SELECT warc_filename FROM x') + assert 'ORDER BY' not in src.query + + +def test_duckdb_built_query_orders_by_default(): + src = build(target_source='sql', engine='duckdb', hostnames=['commoncrawl.org']) + assert 'ORDER BY warc_filename, warc_record_offset' in src._build_query() + + +def test_duckdb_built_query_no_sort(): + src = build(target_source='sql', engine='duckdb', hostnames=['commoncrawl.org'], no_sort_ranges=True) + assert 'ORDER BY' not in src._build_query() + + +def test_csv_source_sort_flag(): + src = build(target_source='csv', csv_path='/tmp/ranges.csv') + assert src.sort is True + src = build(target_source='csv', csv_path='/tmp/ranges.csv', no_sort_ranges=True) + assert src.sort is False From 53a7ef76cc26d5557e52a1abb8c3aff7b8213a8e Mon Sep 17 00:00:00 2001 From: malteos Date: Fri, 19 Jun 2026 09:36:31 +0200 Subject: [PATCH 5/6] add notes for warc fetcher performance --- docs/notes/warc-fetcher-performance.md | 182 +++++++++++++++++++++++++ 1 file changed, 182 insertions(+) create mode 100644 docs/notes/warc-fetcher-performance.md diff --git a/docs/notes/warc-fetcher-performance.md b/docs/notes/warc-fetcher-performance.md new file mode 100644 index 0000000..6f06db2 --- /dev/null +++ b/docs/notes/warc-fetcher-performance.md @@ -0,0 +1,182 @@ +# WARC fetcher performance: bottlenecks, analysis & recommendations + +Notes on optimizing the `cdxt repackage` fetch pipeline for large jobs (e.g. a +range-jobs file with ~1M rows), with a focus on running on EC2 in-region against +`s3://commoncrawl`. + +Status legend: **[done]** implemented on this branch · **[proposed]** not yet implemented. + +--- + +## How the fetcher works today + +Three async stages on a single `asyncio` event loop (`filter_warc/warc_filter.py`): + +1. **Produce** — the selected `RangeJobSource` yields `RangeJob`s (one per WARC record: + url + byte offset + length). The source is a *sync* generator run in a worker + thread (`asyncio.to_thread(drain)`, `warc_filter.py:~276`), pushing each job onto a + bounded `range_jobs_queue` via `run_coroutine_threadsafe(...).result()`. +2. **Read** — `num_readers` reader coroutines (`= --parallel`) each pull a job and do + **one S3 `GetObject` (or HTTP GET) per job** (`read_warc_records` → + `RangeJob.ranged_get_bytes`, `warc_filter.py:~446`, `data_classes.py:64`). The + payload (raw, already-gzipped WARC record bytes) is enqueued. +3. **Write** — `num_writers` (`= num_readers/6`) writer coroutines each own one output + shard and append records to it (`write_warc_records`, `warc_filter.py:~497`). + +Clients/config: + +- One shared `aioboto3` S3 read client and one write client (`get_aws_clients`, + `warc_filter.py:~179`); read pool sized `num_readers*3`. +- Retries/backoff for throttling (503 SlowDown) in `with_retries` (`s3_utils.py:27`). +- Output writers: `S3ShardWriter` (multipart upload, ≥5 MiB parts) or `LocalFileWriter` + (`aiofiles`); shard files keyed by `writer_id` + `sequence` + (`generate_warc_filename`, `warc_utils.py:67`). Records are written **verbatim** — no + recompression. + +--- + +## Key insight: concurrency vs. cores + +`aioboto3`/`asyncio` give **I/O concurrency**, not **multi-core parallelism**: + +- With `--parallel=30` there are up to 30 S3 requests *in flight at once*; while most + are blocked on the network, others run. For network-bound work this is exactly right. +- But asyncio is single-threaded: **one event loop, one OS thread, one CPU core.** All + the *CPU* work — TLS encrypt/decrypt, HTTP framing, aiobotocore object churn, byte + copies — is serialized on that one core. The readers **and** writers are all + coroutines on the same loop (`asyncio.create_task(...)`, `warc_filter.py:~321` and + `~334`). +- The GIL means *threads* wouldn't help CPU-bound work either. Only **multiple + processes** use multiple cores. + +Corollary: **sharded output ≠ multi-core.** Having N writers means N concurrent MPU +streams on one core, not N cores. Adding writers (or readers) past the point where the +single core saturates buys nothing but scheduling overhead and memory. + +--- + +## What actually limits a 1M-row job in-region + +**Bandwidth is rarely the constraint.** 1M records at ~30 KB compressed ≈ ~30 GB; over +even 5 Gbps that is ~50 s, at 25 Gbps ~10 s. A c5n NIC is nowhere near the bottleneck. + +**Request rate on a single core usually is.** 1M independent range GETs, each with +per-request TLS + HTTP + Python overhead, all on one event-loop core. A single loop +typically sustains on the order of a few thousand small GET/s before that core +saturates → ~3–8 min CPU-bound on one core, with the NIC mostly idle and the box's other +vCPUs unused. + +Two regimes, and the optimizations differ: + +- **Many small requests** (default): single-core CPU/TLS bound → cut request count + (coalescing), lift the single-core ceiling (uvloop), or use more cores + (multi-process). +- **Fewer, larger reads** (after coalescing, or large records): network-bound → one core + drives multiple Gbps fine; more processes buy little. + +### How to tell which regime you're in + +Run a sample and watch `htop`/`mpstat` during the fetch: + +- One core pinned ~100 %, NIC well below capacity → **single-core CPU-bound** → + coalescing / uvloop / multi-process. +- All cores low, throughput still short → **not CPU-bound**: the limiter is concurrency, + request latency, or S3 throttling — process/writer count is irrelevant. + +--- + +## Writer review (why more output shards won't fix a read-bound core) + +- **No recompression** — records arrive already gzipped and are written verbatim + (`writer.write(item.data)`, `warc_filter.py:~577`). The writer adds no gzip CPU. +- **`S3ShardWriter`** (`s3_writer.py:84`) buffers into a `bytearray`; on each ≥5 MiB + boundary it does `buffer[:n]` (copy), `del buffer[:n]` (an O(remaining) memmove), + `bytes(chunk)` (another copy), then `await upload_part`. Real but modest byte-copy cost. +- **`LocalFileWriter`** (`local_writer.py`) is `aiofiles` with an 8 KiB buffer and a + `flush()` per write — on a fast read stream the per-flush overhead + EBS limits are the + more likely bottleneck (prefer writing to `s3://` on a c5n). +- Writers are **few** (`num_readers/6`, e.g. 5) and batch into 5 MiB parts, so the CPU + hot spot is the **read side** (many small GETs + TLS), not the write side. More output + shards cannot relieve a read/TLS-bound core. + +The flip side: the shard model makes **multi-process essentially free on the output +side** — shards are independent files with no cross-shard state and no merge step, so N +worker processes can each write their own shards into the same prefix. + +--- + +## Recommendations (ranked) + +Assume parallelism is already configured (`--parallel` high) and reads go to +`s3://commoncrawl` in-region. + +1. **[done] Sort by `(warc_filename, warc_record_offset)`.** Groups records of the same + WARC file with ascending offsets → read locality, connection reuse, and the + prerequisite for coalescing and clean per-file process sharding. Implemented as an + `ORDER BY` on guided SQL queries (Athena/DuckDB) and an in-memory sort when loading a + CSV source; `--no-sort-ranges` opts out. Raw `--query` and `cdx` sources keep their + order. + +2. **[proposed] Coalesce adjacent ranges (highest remaining value).** After sorting, + merge nearby same-file ranges into one GET (fetch the superrange, slice locally by + offset). Safe because each WARC record is an independent gzip member, so concatenated + members remain valid and byte-exact slicing reconstructs each record. Use a gap + threshold (e.g. ≤256 KB–1 MB) to swallow the small request/metadata records between + indexed responses. Effect: collapses ~1M tiny latency/CPU-bound GETs into ~10–100k + larger sequential reads — directly attacks the request-rate ceiling; bandwidth easily + absorbs the gap bytes. Extreme case: a **density heuristic** — when a large fraction of + a ~1 GB file is needed, GET the whole file once and slice. + +3. **[proposed] Use all cores via multi-process sharding** — *only if a sample shows + single-core CPU-bound.* Split the sorted job file by `warc_filename` into N worker + processes, each with its own event loop, S3 client, and output shards. Multiplies the + request-rate ceiling ~linearly with cores. Coalescing (2) may keep you network-bound + and make this unnecessary — measure first. + +4. **[proposed] uvloop.** Cheap single-core lift for the request-rate-bound regime; + stacks under multi-process. + +5. **[proposed] Bound coalesced read size / watch RAM.** A c5n.xlarge has 10.5 GiB; + whole-file (~1 GB) reads × many readers will OOM. Cap superrange size and gap + threshold; revisit `warc_records_queue_size` (`warc_filter.py:~62`, default 200) since + each buffered payload is larger after coalescing. + +6. **[proposed] Write output to `s3://` in-region, not local EBS.** c5n.xlarge is + EBS-only; gp3 baseline (~125 MB/s) can bottleneck writers at multi-Gbps read rates. + The `S3ShardWriter` MPU path keeps everything in-region and off EBS. Also revisit the + `writers = readers/6` heuristic for in-region fast reads. + +7. **[proposed] Instance selection follows the constraint.** Because the binding limit is + request-rate/CPU (not bandwidth), **more/faster cores beat more Gbps**. c5n.xlarge's + "up to 25 Gbps" is burstable and depletes to baseline on long jobs anyway — but that + barely matters when 30 GB is seconds of transfer. Scale vCPUs to feed (3); reserve + sustained-25 Gbps types for genuinely bandwidth-bound (huge-record) jobs. + +8. **[proposed] Dedup identical `(filename, offset, length)` rows** before fetching — + free; fold into the sort/coalesce pass. + +### Latent issues worth noting + +- **HTTP path blocks the event loop.** `ranged_get_bytes`'s HTTP branch calls the + *synchronous* `myrequests_get(...)` without `await`/`to_thread` (`data_classes.py:~91`). + Since the default `--warc-download-prefix` is `https://data.commoncrawl.org`, the + default fetch path serializes the "parallel" readers. On EC2, prefer `s3://commoncrawl` + (the S3 branch is properly async); otherwise the HTTP fetch should be offloaded to a + thread pool / async HTTP client. **[proposed]** +- **Per-row cross-thread enqueue.** The producer does one + `run_coroutine_threadsafe(...).result()` per job (`warc_filter.py:~268`) — 1M + loop↔thread handoffs. Batching enqueues removes a real per-row overhead at this scale. + **[proposed]** + +--- + +## TL;DR + +1. **[done]** Sort by `(warc_filename, offset)` — locality + enables coalescing. +2. **[proposed]** Coalesce adjacent ranges — the big request-rate win in-region. +3. **[proposed]** Multi-process by filename — only after confirming single-core CPU-bound; + the output shard model already supports it. +4. **[proposed]** uvloop, bounded read size, `s3://` output, dedup — supporting wins. + +Always **measure the regime first** (`htop` during a sample run): in-region the limiter +is almost always request-rate on one core, not bandwidth. From a15b9bfb5211141439f500a9e4cf8461028f7173 Mon Sep 17 00:00:00 2001 From: malteos Date: Sat, 20 Jun 2026 10:55:42 +0000 Subject: [PATCH 6/6] feat(repackage): multi-process fetching with single merged WARC output Add `cdxt repackage --processes N`: shard range jobs by warc_filename across N worker processes (one asyncio event loop per CPU core), then merge the per-process shards into a single .warc.gz with one warcinfo record (server-side S3 UploadPartCopy, or fsspec streaming locally). Reads from s3://commoncrawl in-region. The work is many small independent range-GETs; the limiter is request-rate on one core. Multi-process gives ~2.6x on a 4-vCPU c5n.xlarge (~457 -> ~1130 rec/s). - one writer per process; removed the multi-writer-per-process fan-out (num_writers / fetcher_to_consumer_ratio / --parallel_writers) and the now-vestigial writer_id segment from output filenames - only the first shard writes the warcinfo; all shards share one canonical WARC-Record-ID and the warcinfo filename names the merged file - optional uvloop event loop via CDXT_UVLOOP=1 (~+8% single-core) - new modules: filter_warc/multiprocess.py, filter_warc/merge.py - docs: README repackage section, CHANGELOG, docs/notes/warc-fetcher-performance.md - tests updated for the new filename scheme and single-writer model Claude-Session: https://claude.ai/code/session_011XtVGuu26tiQpAbdPhnixk --- CHANGELOG.md | 7 + README.md | 98 ++++++------ cdx_toolkit/filter_warc/args.py | 27 +++- cdx_toolkit/filter_warc/command.py | 28 ++++ cdx_toolkit/filter_warc/merge.py | 136 ++++++++++++++++ cdx_toolkit/filter_warc/multiprocess.py | 163 ++++++++++++++++++++ cdx_toolkit/filter_warc/warc_filter.py | 112 +++++++------- cdx_toolkit/filter_warc/warc_utils.py | 44 ++++-- docs/notes/warc-fetcher-performance.md | 50 +++--- tests/filter_warc/test_command.py | 12 +- tests/filter_warc/test_sql_sources_gated.py | 2 +- tests/filter_warc/test_warc_filter.py | 18 +-- 12 files changed, 533 insertions(+), 164 deletions(-) create mode 100644 cdx_toolkit/filter_warc/merge.py create mode 100644 cdx_toolkit/filter_warc/multiprocess.py diff --git a/CHANGELOG.md b/CHANGELOG.md index 9cfe21c..c0e3f54 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,3 +1,10 @@ +- unreleased (feat/warc-range-sources) + + new `repackage` command: extract a WARC subset from pluggable sources (cdx / sql{athena,duckdb} / csv), reading source WARCs over HTTP or S3 and writing verbatim + + `--processes N` multi-process fetching (one event loop per core); range jobs are sharded by WARC filename and the shards are merged into a single `.warc.gz` with one warcinfo record (server-side on S3 / streamed locally). `--keep-shards` keeps the shards + + one writer per process (removed the multi-writer-per-process fan-out); `--parallel_readers` sets async readers per process + + optional uvloop event loop via `CDXT_UVLOOP=1` + + see docs/notes/warc-fetcher-performance.md for benchmarks and tuning + - 0.9.38 + deprecated support for py3.7 and py.3.8 + added support for py3.13 and py3.14 diff --git a/README.md b/README.md index 970c3ad..e53aa1e 100644 --- a/README.md +++ b/README.md @@ -305,68 +305,68 @@ Filtering throughput depends on your machine. For reference, on an AWS EC2 c5n.xlarge instance filtering all 300 CDX files from CC-MAIN-2024-30 takes ~1.4 hours with 100k URLs in the whitelist. -## WARC extraction using CDX files +## WARC extraction / repackaging -You can extract parts of WARC files using the cdxt command line script. -The WARC extraction can read CDX files from local and remote file -systems, like S3 buckets. Multiple CDX files can be defined -using a glob pattern. For downloading WARC parts from HTTP or S3, you can -define the download prefix, e.g., `s3://commoncrawl` for S3 download. +The `cdxt repackage` command extracts a subset of WARC records (selected by a +columnar/CDX index) and writes them into one or more new WARC files. Records are +copied **verbatim** (no recompression). It can read the source WARCs over HTTP or, +much faster in-region, directly from S3 (`--warc-download-prefix=s3://commoncrawl`), +and write the output to the local filesystem or S3. -``` -$ cdxt -v --cc warc_by_cdx \ - [--cdx-glob ] \ - --prefix \ - --warc-download-prefix= \ - --creator \ - --operator \ - [--implementation ] - [--write-paths-as-resource-records ] - [--write-paths-as-resource-records-metadata ] -``` +The records to extract come from a pluggable source, chosen with `--target-source`: -By default, we use a [fsspec](https://filesystem-spec.readthedocs.io/en/latest/index.html) -implementation to write and read to local or remote file systems. -For better throughput for S3 read/write, we have also a specific implementation -using [aioboto3](https://github.com/terricain/aioboto3) that you can enable with -the `--implementation=aioboto3` argument. With aioboto3, we achieved ~ 80 requests / second -on an AWS EC2 c5n.xlarge instance. +- `cdx` — CDX index file(s) (`--cdx-path`, `--cdx-glob`); +- `sql` — the CC columnar index via `--engine athena` or `--engine duckdb` + (filter with `--hostnames` / `--domains`, or raw `--query` / `--query-file`); +- `csv` — a range-jobs CSV (`--csv-path`) of `warc_filename,warc_record_offset,warc_record_length`. -You can add one or multiple files with metadata as resource records to -the extracted WARC. For instance, this is useful to maintain the CDX filter -inputs, e.g., the whitelist list. To do this, you need to provide the -corresponding file paths as arguments `--write-paths-as-resource-records=s3:///my-s3-bucket/path/to/my-url-whitelist.txt` -and `--write-paths-as-resource-records-metadata=s3:///my-s3-bucket/path/to/metadata.json`. -The metadata file is optional and can have the following optional fields: +You can materialize the selected ranges to a CSV without fetching +(`--range-jobs-output FILE --no-fetch`) and fetch them later with `--target-source csv`. -```json -{ - "warc_content_type": "str", - "uri": "str", - "http_headers": {"k": "v"}, - "warc_headers_dict": {"k": "v"} -} -``` +### Multi-core fetching and a single output file -This in one example for a metadata JSON file: +The fetcher is asyncio-based: many concurrent range reads, but a single event loop +runs on one CPU core. For large jobs in-region the limiter is request-rate on that +core. Use `--processes N` (set it to the vCPU count) to run N worker processes, each +with its own event loop; `--parallel_readers R` sets the async readers per process. + +The range jobs are sharded by WARC filename across the processes, each process writes +one shard, and the shards are **merged into a single `.warc.gz`** (server-side +on S3, or streamed locally) with a single `warcinfo` record. Pass `--keep-shards` to +keep the per-process shards instead. -```json -{ - "uri": "filter_cdx.gz", - "warc_content_type": "application/cdx", -} ``` +$ CDXT_UVLOOP=1 cdxt -v repackage \ + --target-source csv --csv-path range-jobs.csv \ + --prefix s3://my-bucket/path/homepages \ + --warc-download-prefix=s3://commoncrawl \ + --processes 4 --parallel_readers 48 \ + --creator "..." --operator "..." --is-part-of "CC-MAIN-2026-21" +# -> writes a single s3://my-bucket/path/homepages.warc.gz +``` + +Setting `CDXT_UVLOOP=1` uses the [uvloop](https://github.com/MagicStack/uvloop) event +loop (install `uvloop` separately) for a small single-core speedup. + +On an AWS EC2 c5n.xlarge (4 vCPU) in us-east-1, reading from `s3://commoncrawl`, a +~1 GiB / ~35k-record homepages job runs at ~1100 records/s (~450/s per core, scaling +~linearly with `--processes`). See +[docs/notes/warc-fetcher-performance.md](docs/notes/warc-fetcher-performance.md) for the +full analysis and tuning guidance. + +### Metadata records -The full WARC extraction command could look like this: +You can add one or more files as WARC `metadata` records at the top of the output +(after the `warcinfo` record) with `--write-paths-as-metadata-records`. This is useful +to carry, e.g., the filter/whitelist inputs alongside the extracted records: ``` -$ cdxt -v --cc warc_by_cdx \ - s3://my-s3-bucket/filtered-cdxs --cdx-glob "*.gz" \ - --prefix /local/path/filtered-warcs/ \ +$ cdxt -v repackage \ + --target-source cdx s3://my-bucket/filtered-cdxs --cdx-glob "*.gz" \ + --prefix /local/path/filtered-warcs \ --warc-download-prefix=s3://commoncrawl \ --creator foo --operator bob \ - --write-paths-as-resource-records=s3:///my-s3-bucket/path/to/my-url-whitelist.txt \ - --write-paths-as-resource-records-metadata=s3:///my-s3-bucket/path/to/metadata.json + --write-paths-as-metadata-records /path/to/url-whitelist.txt /path/to/metadata.json ``` ## TODO diff --git a/cdx_toolkit/filter_warc/args.py b/cdx_toolkit/filter_warc/args.py index b40d731..c2872ce 100644 --- a/cdx_toolkit/filter_warc/args.py +++ b/cdx_toolkit/filter_warc/args.py @@ -158,22 +158,35 @@ def add_repackage_args(parser: argparse.ArgumentParser): help='Paths to multiple files. File content is written to as a metadata record to each the WARC file', ) parser.add_argument( - '--parallel', + '--processes', type=int, default=1, - help='Number of parallel workers for reading and writing WARC records (default: 1, sequential processing)', + help=('Number of worker processes for fetching (default: 1). A single asyncio loop ' + 'saturates one CPU core on many small range reads; set this to the vCPU count ' + 'to use all cores. The range jobs are sharded by WARC filename across processes ' + 'and the per-process output shards are merged into a single .warc.gz ' + '(one warcinfo record). Each process uses --parallel_readers async readers.'), ) parser.add_argument( - '--parallel_readers', + '--keep-shards', + action='store_true', + help='In multi-process mode, keep the intermediate per-process shard WARCs instead of ' + 'deleting them after the merge.', + ) + parser.add_argument( + '--parallel', type=int, - default=None, - help='Number of parallel workers for reading WARC records (default: same as `parallel`)', + default=1, + help='Number of async readers per process for fetching WARC records (default: 1). Each ' + 'process has a single writer; combine with --processes to use multiple cores.', ) parser.add_argument( - '--parallel_writers', + '--parallel_readers', type=int, default=None, - help='Number of parallel workers for writing WARC records (default: same as `parallel`)', + help='Number of async readers per process for reading WARC records (default: same as ' + '`parallel`). Each process has a single writer; use --processes for multi-core ' + 'scaling and output sharding.', ) parser.add_argument( '--log_every_n', diff --git a/cdx_toolkit/filter_warc/command.py b/cdx_toolkit/filter_warc/command.py index 8a2d187..6350780 100644 --- a/cdx_toolkit/filter_warc/command.py +++ b/cdx_toolkit/filter_warc/command.py @@ -6,6 +6,7 @@ import fsspec +import os import sys import time import logging @@ -101,6 +102,32 @@ def run_repackage(args, cmdline): prefix_fs, prefix_fs_path = fsspec.url_to_fs(prefix_path) prefix_fs.makedirs(prefix_fs._parent(prefix_fs_path), exist_ok=True) + # Multi-process mode: shard the range jobs across processes (one event loop per core) + # and merge the per-process shards into a single .warc.gz. + if getattr(args, 'processes', 1) and args.processes > 1 and not args.no_fetch: + from cdx_toolkit.filter_warc.multiprocess import run_multiprocess_repackage + + readers = args.parallel_readers if args.parallel_readers is not None else args.parallel + records_n = run_multiprocess_repackage( + source=source, + n_processes=args.processes, + readers_per_process=readers, + prefix_path=prefix_path, + writer_info=info, + writer_subprefix=args.subprefix, + write_paths_as_metadata_records=write_paths_as_metadata_records, + log_every_n=log_every_n, + aws_region_name='us-east-1', + max_attempts=5, + record_limit=limit, + uvloop=os.environ.get('CDXT_UVLOOP') == '1', + warc_download_prefix=args.warc_download_prefix, + keep_shards=args.keep_shards, + ) + logger.info('WARC records extracted: %i', records_n) + logger.info('Script execution time: %.3f seconds', time.time() - start_time) + return + warc_filter = WARCFilter( source=source, range_jobs_output=args.range_jobs_output, @@ -114,6 +141,7 @@ def run_repackage(args, cmdline): log_every_n=log_every_n, warc_download_prefix=args.warc_download_prefix, n_parallel=n_parallel, + n_parallel_readers=args.parallel_readers, max_file_size=args.size, ) records_n = warc_filter.filter() diff --git a/cdx_toolkit/filter_warc/merge.py b/cdx_toolkit/filter_warc/merge.py new file mode 100644 index 0000000..9113b1c --- /dev/null +++ b/cdx_toolkit/filter_warc/merge.py @@ -0,0 +1,136 @@ +"""Merge several WARC shard objects into one, preserving order. + +A WARC file is a concatenation of independent gzip members, so byte-concatenating +shard ``*.warc.gz`` files in order yields a single valid multi-member ``.warc.gz``. +For S3 destinations this is done server-side with multipart ``UploadPartCopy`` (the +bytes never leave S3); for local destinations the shard files are streamed together. + +The caller is responsible for ensuring exactly one shard carries the ``warcinfo`` +record (the first one) so the merged file has a single canonical warcinfo. +""" +import logging +import shutil + +import boto3 +import fsspec + +from cdx_toolkit.filter_warc.s3_utils import is_s3_url, parse_s3_uri + +logger = logging.getLogger(__name__) + +# S3 multipart copy requires every part except the last to be >= 5 MiB. +_MIN_PART = 5 * 1024 * 1024 + + +def merge_objects(dest: str, sources: list, aws_region_name: str = 'us-east-1') -> int: + """Concatenate ``sources`` (in the given order) into ``dest``. Returns dest size. + + - all-S3 (dest and every source on S3): merged server-side via multipart + ``UploadPartCopy`` -- the bytes never leave S3 (fast, in-region). + - otherwise: streamed via fsspec, which supports the local filesystem and any + fsspec-writable backend. (Note: repackage output itself is currently limited to + S3 and local by the WARC writer layer, so in practice sources are S3 or local.) + """ + if is_s3_url(dest) and all(is_s3_url(s) for s in sources): + return _merge_s3(dest, sources, aws_region_name) + return _merge_fsspec(dest, sources) + + +def _merge_s3(dest: str, sources: list, aws_region_name: str) -> int: + s3 = boto3.client('s3', region_name=aws_region_name) + dest_bucket, dest_key = parse_s3_uri(dest) + + # Verify every non-final part is large enough for UploadPartCopy; if any small + # shard would violate the 5 MiB rule, fall back to a download+reupload merge. + heads = [s3.head_object(Bucket=b, Key=k) for b, k in (parse_s3_uri(s) for s in sources)] + sizes = [h['ContentLength'] for h in heads] + if any(sz < _MIN_PART for sz in sizes[:-1]): + logger.info('A shard is < 5 MiB; merging via download+reupload instead of UploadPartCopy') + return _merge_s3_streaming(s3, dest_bucket, dest_key, sources) + + mpu = s3.create_multipart_upload(Bucket=dest_bucket, Key=dest_key) + upload_id = mpu['UploadId'] + parts = [] + try: + for i, src in enumerate(sources, start=1): + sb, sk = parse_s3_uri(src) + resp = s3.upload_part_copy( + Bucket=dest_bucket, Key=dest_key, UploadId=upload_id, + PartNumber=i, CopySource={'Bucket': sb, 'Key': sk}, + ) + parts.append({'PartNumber': i, 'ETag': resp['CopyPartResult']['ETag']}) + logger.info('Merged part %d/%d: %s', i, len(sources), src) + s3.complete_multipart_upload( + Bucket=dest_bucket, Key=dest_key, UploadId=upload_id, + MultipartUpload={'Parts': parts}, + ) + except Exception: + logger.exception('Merge failed; aborting multipart upload for %s', dest) + s3.abort_multipart_upload(Bucket=dest_bucket, Key=dest_key, UploadId=upload_id) + raise + size = s3.head_object(Bucket=dest_bucket, Key=dest_key)['ContentLength'] + logger.info('Merged %d shards -> %s (%d bytes)', len(sources), dest, size) + return size + + +def _merge_s3_streaming(s3, dest_bucket, dest_key, sources) -> int: + mpu = s3.create_multipart_upload(Bucket=dest_bucket, Key=dest_key) + upload_id = mpu['UploadId'] + parts = [] + buf = bytearray() + part_no = 1 + try: + def flush(final=False): + nonlocal buf, part_no + while len(buf) >= _MIN_PART or (final and buf): + take = len(buf) if final else _MIN_PART + chunk = bytes(buf[:take]) + del buf[:take] + r = s3.upload_part(Bucket=dest_bucket, Key=dest_key, UploadId=upload_id, + PartNumber=part_no, Body=chunk) + parts.append({'PartNumber': part_no, 'ETag': r['ETag']}) + part_no += 1 + if final and not buf: + break + for src in sources: + sb, sk = parse_s3_uri(src) + body = s3.get_object(Bucket=sb, Key=sk)['Body'].read() + buf.extend(body) + flush() + flush(final=True) + s3.complete_multipart_upload(Bucket=dest_bucket, Key=dest_key, UploadId=upload_id, + MultipartUpload={'Parts': parts}) + except Exception: + s3.abort_multipart_upload(Bucket=dest_bucket, Key=dest_key, UploadId=upload_id) + raise + return s3.head_object(Bucket=dest_bucket, Key=dest_key)['ContentLength'] + + +def _merge_fsspec(dest: str, sources: list) -> int: + """Stream sources into dest via fsspec (local filesystem or any fsspec backend).""" + out_fs, _ = fsspec.core.url_to_fs(dest) + with fsspec.open(dest, 'wb') as out: + for src in sources: + with fsspec.open(src, 'rb') as fh: + shutil.copyfileobj(fh, out, length=8 * 1024 * 1024) + logger.info('Merged %s', src) + size = out_fs.size(dest) + logger.info('Merged %d shards -> %s (%s bytes)', len(sources), dest, size) + return size or 0 + + +def delete_objects(objs: list, aws_region_name: str = 'us-east-1') -> None: + """Delete shard objects after a successful merge (S3, local, or any fsspec backend).""" + s3 = None + for o in objs: + if is_s3_url(o): + if s3 is None: + s3 = boto3.client('s3', region_name=aws_region_name) + b, k = parse_s3_uri(o) + s3.delete_object(Bucket=b, Key=k) + else: + fs, path = fsspec.core.url_to_fs(o) + try: + fs.rm(path) + except FileNotFoundError: + pass diff --git a/cdx_toolkit/filter_warc/multiprocess.py b/cdx_toolkit/filter_warc/multiprocess.py new file mode 100644 index 0000000..ccf26d3 --- /dev/null +++ b/cdx_toolkit/filter_warc/multiprocess.py @@ -0,0 +1,163 @@ +"""Multi-process repackage: shard the range jobs, fetch each shard in its own process +(one asyncio event loop per CPU core), then merge the shards into a single WARC. + +A single asyncio loop saturates one CPU core on many small range-GETs, so reaching +all cores needs multiple processes. This orchestrator keeps that entirely inside +``cdxt repackage`` (no external driver): + +1. Drain the configured source once and split it into N self-contained shard CSVs, + hashed by WARC filename so every record of a file stays in one shard. +2. Run one worker process per shard concurrently. Each writes exactly one output WARC + (``-shNN-000000-001.warc.gz``). Only shard 0 writes the ``warcinfo`` record; + all shards share one canonical ``WARC-Record-ID`` and the warcinfo ``filename`` field + names the final merged file. +3. Merge the shard objects in order into ``.warc.gz`` (server-side on S3), then + delete the shards. +""" +import hashlib +import logging +import os +import shutil +import tempfile +import time +import uuid +from concurrent.futures import ProcessPoolExecutor +from typing import Dict, List, Optional + +from cdx_toolkit.filter_warc.sources.csv import CsvSource, RangeJobCsvWriter +from cdx_toolkit.filter_warc.warc_filter import WARCFilter +from cdx_toolkit.filter_warc.warc_utils import generate_warc_filename +from cdx_toolkit.filter_warc.s3_utils import is_s3_url, parse_s3_uri +from cdx_toolkit.filter_warc import merge as merge_mod + +logger = logging.getLogger(__name__) + + +def _shard_filename(prefix_path: str, subprefix: str, gzip: bool = True) -> str: + """Object path a worker writes (sequence 1, no rotation).""" + if is_s3_url(prefix_path): + bucket, key_prefix = parse_s3_uri(prefix_path) + key = generate_warc_filename(key_prefix, sequence=1, writer_subprefix=subprefix, gzip=gzip) + return f's3://{bucket}/{key}' + return generate_warc_filename(prefix_path, sequence=1, writer_subprefix=subprefix, gzip=gzip) + + +def _run_worker(cfg: Dict) -> int: + """Worker entrypoint (runs in a child process): fetch one shard into one WARC.""" + if cfg.get('uvloop'): + os.environ['CDXT_UVLOOP'] = '1' + source = CsvSource(cfg['shard_csv'], warc_download_prefix=None, sort=True) + wf = WARCFilter( + source=source, + prefix_path=cfg['prefix_path'], + writer_info=cfg['writer_info'], + writer_subprefix=cfg['subprefix'], + write_paths_as_metadata_records=cfg['write_paths_as_metadata_records'], + log_every_n=cfg['log_every_n'], + # Shard CSV is self-contained (full warc_url), so CsvSource ignores this; it is + # still needed so WARCFilter.needs_aws() knows reads come from S3. + warc_download_prefix=cfg['warc_download_prefix'], + n_parallel_readers=cfg['readers'], + aws_region_name=cfg['aws_region_name'], + max_attempts=cfg['max_attempts'], + max_file_size=None, # never rotate: exactly one file per worker + warcinfo_record_id=cfg['warcinfo_record_id'], + warcinfo_filename=cfg['warcinfo_filename'], + write_warcinfo=cfg['write_warcinfo'], + ) + return wf.filter() + + +def _split_source_to_shards(source, n: int, out_dir: str, record_limit: int = 0) -> List[str]: + """Drain the source once into N self-contained shard CSVs, hashed by WARC file.""" + paths = [os.path.join(out_dir, f'shard.{i}.csv') for i in range(n)] + writers = [RangeJobCsvWriter(p, self_contained=True) for p in paths] + counts = [0] * n + total = 0 + try: + for job in source.iter_range_jobs(): + key = (job.filename or job.url or '').encode() + idx = int(hashlib.md5(key).hexdigest(), 16) % n + writers[idx].write(job) + counts[idx] += 1 + total += 1 + if record_limit and total >= record_limit: + break + finally: + for w in writers: + w.close() + logger.info('Sharded %d range jobs into %d shards: %s', total, n, counts) + return paths + + +def run_multiprocess_repackage( + *, + source, + n_processes: int, + readers_per_process: int, + prefix_path: str, + writer_info: Dict, + writer_subprefix: Optional[str], + write_paths_as_metadata_records: Optional[List[str]], + log_every_n: int, + aws_region_name: str, + max_attempts: int, + record_limit: int, + uvloop: bool, + warc_download_prefix: Optional[str] = None, + keep_shards: bool = False, +) -> int: + """Shard -> N worker processes -> merge into a single .warc.gz. Returns count.""" + if writer_subprefix: + logger.warning('--subprefix is ignored in multi-process mode (shards use shNN)') + + final_dest = prefix_path + '.warc.gz' + final_name = (parse_s3_uri(final_dest)[1] if is_s3_url(final_dest) else final_dest).split('/')[-1] + warcinfo_id = f'' + + tmp_dir = tempfile.mkdtemp(prefix='cdxt_shards_') + start = time.time() + try: + shard_csvs = _split_source_to_shards(source, n_processes, tmp_dir, record_limit) + + configs = [] + shard_outputs = [] + for i in range(n_processes): + subprefix = f'sh{i:02d}' + shard_outputs.append(_shard_filename(prefix_path, subprefix)) + configs.append(dict( + shard_csv=shard_csvs[i], + prefix_path=prefix_path, + subprefix=subprefix, + writer_info=writer_info, + write_paths_as_metadata_records=write_paths_as_metadata_records, + log_every_n=log_every_n, + readers=readers_per_process, + aws_region_name=aws_region_name, + max_attempts=max_attempts, + warcinfo_record_id=warcinfo_id, + warcinfo_filename=final_name, + write_warcinfo=(i == 0), # only the first shard carries the warcinfo + uvloop=uvloop, + warc_download_prefix=warc_download_prefix, + )) + + logger.info('Launching %d worker processes x %d readers', n_processes, readers_per_process) + with ProcessPoolExecutor(max_workers=n_processes) as ex: + results = list(ex.map(_run_worker, configs)) + total = sum(r for r in results if r and r > 0) + fetch_elapsed = time.time() - start + logger.info('All %d workers done: %d records in %.1fs (%.0f rec/s)', + n_processes, total, fetch_elapsed, total / fetch_elapsed if fetch_elapsed else 0) + + logger.info('Merging %d shards -> %s', n_processes, final_dest) + merge_mod.merge_objects(final_dest, shard_outputs, aws_region_name=aws_region_name) + + if not keep_shards: + merge_mod.delete_objects(shard_outputs, aws_region_name=aws_region_name) + logger.info('Deleted %d intermediate shards', n_processes) + + logger.info('Repackage complete -> %s (%.1fs total)', final_dest, time.time() - start) + return total + finally: + shutil.rmtree(tmp_dir, ignore_errors=True) diff --git a/cdx_toolkit/filter_warc/warc_filter.py b/cdx_toolkit/filter_warc/warc_filter.py index 8399286..b6afa33 100644 --- a/cdx_toolkit/filter_warc/warc_filter.py +++ b/cdx_toolkit/filter_warc/warc_filter.py @@ -1,5 +1,6 @@ import asyncio import logging +import os import statistics import sys from typing import List, Optional, Dict @@ -54,18 +55,19 @@ def __init__( warc_download_prefix: Optional[str] = None, n_parallel: int = 1, n_parallel_readers: Optional[int] = None, - n_parallel_writers: Optional[int] = None, max_attempts: int = 5, base_backoff_seconds: float = 0.5, # writer_kwargs: Optional[Dict] = None, range_jobs_queue_size: int = 1000, warc_records_queue_size: int = 200, - fetcher_to_consumer_ratio: int = 6, aws_region_name: str = 'us-east-1', warc_version: str = '1.0', content_type: Optional[str] = None, min_part_size: int = 5 * 1024 * 1024, # 5 MiB (for upload) max_file_size: Optional[int] = 1 * 1024 * 1024 * 1024, # 1 GiB (for WARC outputs) + warcinfo_record_id: Optional[str] = None, + warcinfo_filename: Optional[str] = None, + write_warcinfo: bool = True, ): """Initialize the WARC filter. @@ -84,15 +86,13 @@ def __init__( record_limit: Maximum number of records to process (0 for unlimited). log_every_n: Log progress every N records. warc_download_prefix: Optional prefix to prepend to WARC URLs. - n_parallel: Number of parallel workers (default for readers/writers). - n_parallel_readers: Number of parallel reader tasks (overrides n_parallel). - n_parallel_writers: Number of parallel writer tasks (overrides n_parallel). + n_parallel: Number of async readers per process (one writer per process). + n_parallel_readers: Number of async reader tasks (overrides n_parallel). max_attempts: Maximum retry attempts for failed operations. base_backoff_seconds: Base backoff time in seconds for retries. writer_kwargs: Optional additional kwargs for writers. range_jobs_queue_size: Maximum size of range jobs queue. warc_records_queue_size: Maximum size of WARC records queue. - fetcher_to_consumer_ratio: Ratio of readers to writers for auto-scaling. aws_region_name: AWS region name for S3 operations. warc_version: WARC format version (e.g., '1.0' or '1.1'). content_type: Optional content type for WARC output. @@ -115,17 +115,15 @@ def __init__( self.range_jobs_queue_size = range_jobs_queue_size self.warc_records_queue_size = warc_records_queue_size self.aws_region_name = aws_region_name - self.fetcher_to_consumer_ratio = fetcher_to_consumer_ratio self.max_attempts = max_attempts self.base_backoff_seconds = base_backoff_seconds + # Many async readers feed a single writer. Writing is a cheap verbatim byte copy + # (+ >=5 MiB MPU parts) that one coroutine sustains easily; multi-core scaling and + # output sharding are handled by running multiple *processes* (see + # filter_warc/multiprocess.py), not multiple writers on one event loop. self.n_parallel = n_parallel self.num_readers = n_parallel_readers if n_parallel_readers is not None else n_parallel - self.num_writers = ( - n_parallel_writers - if n_parallel_writers is not None - else max(int(self.num_readers / self.fetcher_to_consumer_ratio), 1) - ) self.gzip = True @@ -133,6 +131,9 @@ def __init__( self.content_type = content_type self.min_part_size = min_part_size self.max_file_size = max_file_size + self.warcinfo_record_id = warcinfo_record_id + self.warcinfo_filename = warcinfo_filename + self.write_warcinfo = write_warcinfo def filter(self) -> int: """Perform the filtering process (calls async method via asyncio.run). @@ -140,8 +141,17 @@ def filter(self) -> int: Returns: int: Number of records written, or -1 if interrupted. """ + runner = asyncio.run + if os.environ.get('CDXT_UVLOOP') == '1': + try: + import uvloop + + runner = uvloop.run + logger.info('Using uvloop event loop') + except ImportError: + logger.warning('CDXT_UVLOOP=1 but uvloop is not installed; using default asyncio loop') try: - return asyncio.run(self.filter_async()) + return runner(self.filter_async()) except KeyboardInterrupt: logger.warning('Interrupted by user.') @@ -202,9 +212,9 @@ async def get_aws_clients(self) -> Optional[Dict]: **self.get_boto3_base_config(), ) - # Optimized config for multipart uploads + # Optimized config for multipart uploads (single writer) write_config = Config( - max_pool_connections=self.num_writers * 4, + max_pool_connections=8, read_timeout=120, connect_timeout=10, **self.get_boto3_base_config(), @@ -311,7 +321,7 @@ async def _run_filter_pipeline( Returns: int: Number of records written. """ - logger.info('Starting job generator, %d WARC readers, %d WARC writers', self.num_readers, self.num_writers) + logger.info('Starting job generator, %d WARC readers, 1 WARC writer', self.num_readers) # Generate range jobs from the configured source (bridged sync->async in a thread). csv_writer = self._make_csv_writer() @@ -330,27 +340,23 @@ async def _run_filter_pipeline( for i in range(self.num_readers) ] - # Write WARC records - warc_writers = [ - asyncio.create_task( - self.write_warc_records( - writer_id=i, - warc_records_queue=warc_records_queue, - s3_client=write_s3_client, - ) + # Write WARC records (a single writer owns the output shard for this process) + warc_writer = asyncio.create_task( + self.write_warc_records( + warc_records_queue=warc_records_queue, + s3_client=write_s3_client, ) - for i in range(self.num_writers) - ] + ) # Start writer coordination task writer_coordinator = asyncio.create_task(self._coordinate_writer_shutdown(warc_readers, warc_records_queue)) await job_generators - logger.info('Range jobs submitted, monitoring readers and writers') + logger.info('Range jobs submitted, monitoring readers and writer') # Wait for all tasks to complete readers_results = await asyncio.gather(*warc_readers) - writers_results = await asyncio.gather(*warc_writers) + writer_result = await warc_writer await writer_coordinator readers_records = sum([result['stats']['total_records'] for result in readers_results]) @@ -364,18 +370,14 @@ async def _run_filter_pipeline( logger.info(f'All WARC readers completed: {readers_records} records') logger.info(f'Total reader throughput: {readers_mb_per_sec:.2f} MB/s; {readers_records_per_sec:.2f} rec/s') - writers_records = sum([result['stats']['total_records'] for result in writers_results]) - writers_mb_per_sec = self.num_writers * statistics.mean( - [result['stats']['mb_per_sec'] for result in writers_results] + writer_stats = writer_result['stats'] + logger.info(f"WARC writer completed: {writer_stats['total_records']} records") + logger.info( + f"Total writer throughput: {writer_stats['mb_per_sec']:.2f} MB/s; " + f"{writer_stats['records_per_sec']:.2f} rec/s" ) - writers_records_per_sec = self.num_writers * statistics.mean( - [result['stats']['records_per_sec'] for result in writers_results] - ) - - logger.info(f'All WARC writers completed: {writers_records} records') - logger.info(f'Total writer throughput: {writers_mb_per_sec:.2f} MB/s; {writers_records_per_sec:.2f} rec/s') - return writers_records + return writer_stats['total_records'] async def _coordinate_writer_shutdown(self, warc_readers: List[asyncio.Task], warc_records_queue: asyncio.Queue): """Coordinate efficient shutdown of writers as readers complete. @@ -398,12 +400,9 @@ async def _coordinate_writer_shutdown(self, warc_readers: List[asyncio.Task], wa completed_readers = len(warc_readers) - len(pending) logger.debug(f'Readers completed: {completed_readers}/{len(warc_readers)}') - # All readers completed - signal writers to stop - logger.info('All readers completed, signaling writers to stop') - - # Send stop signals to all writers - for _ in range(self.num_writers): - await warc_records_queue.put(_STOP) + # All readers completed - signal the writer to stop + logger.info('All readers completed, signaling writer to stop') + await warc_records_queue.put(_STOP) async def read_warc_records( self, @@ -496,19 +495,18 @@ async def write_metadata_records(self, writer, warcinfo_id: str) -> int: async def write_warc_records( self, - writer_id: int, warc_records_queue: asyncio.Queue, s3_client=None, ) -> dict: - """Write WARC records. Each writer owns ONE shard MPU and appends ranges to it. + """Write WARC records. The single writer owns the output WARC (one shard per + process) and appends ranges to it, rotating by --size when set. Args: - writer_id: Unique identifier for this writer task. warc_records_queue: Queue to read RangePayload objects from. s3_client: Optional S3 client for writing WARC files to S3. Returns: - dict: Statistics dictionary with writer_id and throughput stats. + dict: Statistics dictionary with throughput stats. """ # File rotation tracking current_file_sequence = 1 @@ -516,7 +514,6 @@ async def write_warc_records( new_writer_kwargs = dict( s3_client=s3_client, - writer_id=writer_id, output_path_prefix=self.prefix_path, max_attempts=self.max_attempts, base_backoff_seconds=self.base_backoff_seconds, @@ -526,6 +523,9 @@ async def write_warc_records( gzip=self.gzip, content_type=self.content_type, min_part_size=self.min_part_size, + warcinfo_record_id=self.warcinfo_record_id, + warcinfo_filename=self.warcinfo_filename, + write_warcinfo=self.write_warcinfo, ) # Initialize first writer with header @@ -552,8 +552,7 @@ async def write_warc_records( if item is _STOP: stats = tracker.get_stats() logger.info( - 'WARC writer %d stopping. Stats: %.1fs, %d items, %.1f MB written, %.2f MB/s write speed', - writer_id, + 'WARC writer stopping. Stats: %.1fs, %d items, %.1f MB written, %.2f MB/s write speed', stats['elapsed'], stats['total_requests'], stats['total_bytes'] / (1024 * 1024), @@ -579,10 +578,10 @@ async def write_warc_records( tracker.add(bytes_count=len(item.data), records_count=item.job.records_count) # Log progress every N items - self.log_writer(writer_id=writer_id, counter=counter, tracker=tracker) + self.log_writer(counter=counter, tracker=tracker) except Exception: - logger.exception('WARC writer %d failed on %s', writer_id, getattr(item, 'job', None)) + logger.exception('WARC writer failed on %s', getattr(item, 'job', None)) should_stop = False finally: warc_records_queue.task_done() @@ -592,7 +591,7 @@ async def write_warc_records( finally: await writer.close() - return {'writer_id': writer_id, 'stats': tracker.get_stats()} + return {'stats': tracker.get_stats()} def log_reader(self, reader_id: int, counter: int, tracker: ThroughputTracker): """Log progress every N items.""" @@ -607,13 +606,12 @@ def log_reader(self, reader_id: int, counter: int, tracker: ThroughputTracker): stats['requests_per_sec'], ) - def log_writer(self, writer_id: int, counter: int, tracker: ThroughputTracker): + def log_writer(self, counter: int, tracker: ThroughputTracker): """Log progress every N items.""" if self.log_every_n > 0 and counter % self.log_every_n == 0: stats = tracker.get_stats() logger.info( - 'WARC Writer %d: %d items, %.1f MB written, %.2f MB/s', - writer_id, + 'WARC Writer: %d items, %.1f MB written, %.2f MB/s', counter, stats['total_bytes'] / (1024 * 1024), stats['mb_per_sec'], diff --git a/cdx_toolkit/filter_warc/warc_utils.py b/cdx_toolkit/filter_warc/warc_utils.py index 495114b..1ba5cf4 100644 --- a/cdx_toolkit/filter_warc/warc_utils.py +++ b/cdx_toolkit/filter_warc/warc_utils.py @@ -66,16 +66,16 @@ def get_metadata_record_from_path( def generate_warc_filename( dest_prefix: str, - writer_id: int, sequence: int, writer_subprefix: Optional[str] = None, gzip: bool = False, ) -> str: - """Generate a WARC file name based a on prefix (can be a full path), write ID and sequence index.""" + """Generate a WARC file name from a prefix (can be a full path), an optional subprefix + (used to distinguish per-process shards), and a sequence index (file-rotation counter).""" file_name = dest_prefix + '-' if writer_subprefix is not None: file_name += writer_subprefix + '-' - file_name += f'{writer_id:06d}-{sequence:03d}.warc' # TODO default warc command uses ".extracted.warc" + file_name += f'{sequence:03d}.warc' # TODO default warc command uses ".extracted.warc" if gzip: file_name += '.gz' @@ -83,7 +83,6 @@ def generate_warc_filename( async def create_new_writer_with_header( - writer_id: int, sequence: int, output_path_prefix: str, max_attempts: int, @@ -95,14 +94,27 @@ async def create_new_writer_with_header( gzip: bool = False, content_type: Optional[str] = None, s3_client=None, -) -> Tuple[Union[S3ShardWriter, LocalFileWriter], int, str]: - """Create a new WARC writer (local or S3) including file header.""" + warcinfo_record_id: Optional[str] = None, + warcinfo_filename: Optional[str] = None, + write_warcinfo: bool = True, +) -> Tuple[Union[S3ShardWriter, LocalFileWriter], int, Optional[str]]: + """Create a new WARC writer (local or S3) including file header. + + warcinfo controls (used when output is sharded across processes and merged into one + file afterwards): + - write_warcinfo: if False, no warcinfo record is written (the file starts directly + with response records). The shared warcinfo_record_id is still returned so any + metadata records link to the single warcinfo that survives the merge. + - warcinfo_record_id: force this WARC-Record-ID on the warcinfo record so it is + identical across shard processes (merged file then has one canonical warcinfo id). + - warcinfo_filename: value of the warcinfo `filename` field; set to the final merged + filename when shards are concatenated afterwards (defaults to this file's name). + """ if is_s3_url(output_path_prefix): dest_bucket, dest_prefix = parse_s3_uri(output_path_prefix) warc_file_path = generate_warc_filename( dest_prefix=dest_prefix, - writer_id=writer_id, sequence=sequence, writer_subprefix=writer_subprefix, gzip=gzip, @@ -122,7 +134,6 @@ async def create_new_writer_with_header( # local file system warc_file_path = generate_warc_filename( dest_prefix=output_path_prefix, - writer_id=writer_id, sequence=sequence, writer_subprefix=writer_subprefix, gzip=gzip, @@ -137,14 +148,25 @@ async def create_new_writer_with_header( # Initialize writer await new_writer.start() - # Write WARC header - warc_file_name = warc_file_path.split('/')[-1] + if not write_warcinfo: + # Shard that intentionally omits the warcinfo record (a non-first shard that + # will be concatenated after the shard which already carries it). Return the + # shared id so metadata records still link to the single surviving warcinfo. + return new_writer, 0, warcinfo_record_id + + # Write WARC header. The `filename` field names the final WARC file, which may + # differ from this shard's object name when shards are merged afterwards. + warc_file_name = warcinfo_filename or warc_file_path.split('/')[-1] buffer = BytesIO() warc_writer = WARCWriter(buffer, gzip=gzip, warc_version=warc_version) warcinfo = warc_writer.create_warcinfo_record( - filename=warc_file_name, # only the file name and not the full path + filename=warc_file_name, info=writer_info, ) + if warcinfo_record_id: + # Force a caller-supplied, stable WARC-Record-ID so it is identical across shard + # processes -- the merged file then has exactly one canonical warcinfo id. + warcinfo.rec_headers.replace_header('WARC-Record-ID', warcinfo_record_id) warc_writer.write_record(warcinfo) header_data = buffer.getvalue() await new_writer.write(header_data) diff --git a/docs/notes/warc-fetcher-performance.md b/docs/notes/warc-fetcher-performance.md index 6f06db2..409eca0 100644 --- a/docs/notes/warc-fetcher-performance.md +++ b/docs/notes/warc-fetcher-performance.md @@ -95,9 +95,12 @@ Run a sample and watch `htop`/`mpstat` during the fetch: - **`LocalFileWriter`** (`local_writer.py`) is `aiofiles` with an 8 KiB buffer and a `flush()` per write — on a fast read stream the per-flush overhead + EBS limits are the more likely bottleneck (prefer writing to `s3://` on a c5n). -- Writers are **few** (`num_readers/6`, e.g. 5) and batch into 5 MiB parts, so the CPU - hot spot is the **read side** (many small GETs + TLS), not the write side. More output - shards cannot relieve a read/TLS-bound core. +- **[done] One writer per process.** Because writing is a cheap verbatim copy that one + coroutine sustains, the old `num_writers = num_readers/6` multi-writer fan-out (multiple + shards on one event loop) bought nothing but scheduling overhead and was removed. Each + process now has exactly one writer = one output shard; multi-core scaling and sharding + come from `--processes` (multiple event loops), and shards are merged into one file. The + CPU hot spot is the **read side** (many small GETs + TLS), not the write side. The flip side: the shard model makes **multi-process essentially free on the output side** — shards are independent files with no cross-shard state and no merge step, so N @@ -127,24 +130,33 @@ Assume parallelism is already configured (`--parallel` high) and reads go to absorbs the gap bytes. Extreme case: a **density heuristic** — when a large fraction of a ~1 GB file is needed, GET the whole file once and slice. -3. **[proposed] Use all cores via multi-process sharding** — *only if a sample shows - single-core CPU-bound.* Split the sorted job file by `warc_filename` into N worker - processes, each with its own event loop, S3 client, and output shards. Multiplies the - request-rate ceiling ~linearly with cores. Coalescing (2) may keep you network-bound - and make this unnecessary — measure first. - -4. **[proposed] uvloop.** Cheap single-core lift for the request-rate-bound regime; - stacks under multi-process. +3. **[done] Use all cores via multi-process sharding** — confirmed single-core CPU-bound + (a single loop saturates one core at ~450 small GET/s; one process pins ~1 core at + `--parallel 96`). Implemented natively as `cdxt repackage --processes N` + (`filter_warc/multiprocess.py`): the source is drained once and sharded by + `warc_filename`, one worker process per core fetches its shard + (`--parallel_readers R`, one writer each → one shard/proc), and the shards are merged + into a single `.warc.gz` (`filter_warc/merge.py`: server-side S3 + `UploadPartCopy`, or fsspec streaming locally). Measured on c5n.xlarge (4 vCPU): + ~457 → ~1130 rec/s fetch (2.6×, cores ~97 %) for a ~1 GiB / 35.7 k-record homepages job; + ~60 s end-to-end including source sharding and the merge. Sublinear because each + process's producer thread (CSV sort + per-row cross-thread enqueue) competes with its + event loop for the GIL — see latent issues. + +4. **[done] uvloop.** Gated by `CDXT_UVLOOP=1` in `WARCFilter.filter` (`warc_filter.py`). + Measured ~+8 % single-core; stacks under multi-process. 5. **[proposed] Bound coalesced read size / watch RAM.** A c5n.xlarge has 10.5 GiB; whole-file (~1 GB) reads × many readers will OOM. Cap superrange size and gap threshold; revisit `warc_records_queue_size` (`warc_filter.py:~62`, default 200) since each buffered payload is larger after coalescing. -6. **[proposed] Write output to `s3://` in-region, not local EBS.** c5n.xlarge is - EBS-only; gp3 baseline (~125 MB/s) can bottleneck writers at multi-Gbps read rates. - The `S3ShardWriter` MPU path keeps everything in-region and off EBS. Also revisit the - `writers = readers/6` heuristic for in-region fast reads. +6. **[done] Write output to `s3://` in-region, not local EBS.** Measured: S3-direct MPU + output is within noise of local EBS at this write rate (~35 MB/s total) and skips the + EBS round-trip + separate upload, so prefer `--prefix s3://…`. On the writer count: at + in-region read rates a *single* writer per process keeps up easily (verbatim copy, + 5 MiB parts), so `--parallel_writers 1` is both fastest-enough and gives one shard per + process — better than the `writers = readers/6` default for this workload. 7. **[proposed] Instance selection follows the constraint.** Because the binding limit is request-rate/CPU (not bandwidth), **more/faster cores beat more Gbps**. c5n.xlarge's @@ -173,10 +185,10 @@ Assume parallelism is already configured (`--parallel` high) and reads go to ## TL;DR 1. **[done]** Sort by `(warc_filename, offset)` — locality + enables coalescing. -2. **[proposed]** Coalesce adjacent ranges — the big request-rate win in-region. -3. **[proposed]** Multi-process by filename — only after confirming single-core CPU-bound; - the output shard model already supports it. -4. **[proposed]** uvloop, bounded read size, `s3://` output, dedup — supporting wins. +2. **[done]** Multi-process by filename (`--processes N`) — ~2.6× on 4 cores; shards merged + into one file. The big confirmed win in-region. +3. **[done]** uvloop (`CDXT_UVLOOP=1`), one writer per process, `s3://` output — supporting wins. +4. **[proposed]** Coalesce adjacent ranges, bounded read size, dedup — remaining ideas. Always **measure the regime first** (`htop` during a sample run): in-region the limiter is almost always request-rate on one core, not bandwidth. diff --git a/tests/filter_warc/test_command.py b/tests/filter_warc/test_command.py index 5648217..808fdee 100644 --- a/tests/filter_warc/test_command.py +++ b/tests/filter_warc/test_command.py @@ -19,7 +19,7 @@ def assert_cli_warc_by_cdx( caplog, extra_args: Optional[List[str]] = None, # warc_filename: str = 'TEST_warc_by_index-000000.warc.gz', - warc_filename: str = 'TEST_warc_by_index-000000-001.warc.gz', # due to parallel writer + warc_filename: str = 'TEST_warc_by_index-001.warc.gz', ): # test cli and check output index_path = fixture_path / 'filtered_CC-MAIN-2024-30_cdx-00187.gz' @@ -208,7 +208,7 @@ def test_cli_repackage_csv_roundtrip(tmpdir): ] ) - warc_path = os.path.join(base_prefix, 'TEST_warc_by_index-000000-001.warc.gz') + warc_path = os.path.join(base_prefix, 'TEST_warc_by_index-001.warc.gz') _assert_repackaged_warc(warc_path, metadata_record_path) @@ -232,7 +232,7 @@ def test_cli_repackage_csv_roundtrip_self_contained(tmpdir): ] ) - warc_path = os.path.join(base_prefix, 'TEST_warc_by_index-000000-001.warc.gz') + warc_path = os.path.join(base_prefix, 'TEST_warc_by_index-001.warc.gz') _assert_repackaged_warc(warc_path, metadata_record_path) @@ -328,7 +328,7 @@ def test_warc_by_cdx_subprefix_and_metadata(tmpdir): ) # Check that WARC file was created with subprefix - warc_path = os.path.join(tmpdir, 'TEST-SUB-000000-001.warc.gz') + warc_path = os.path.join(tmpdir, 'TEST-SUB-001.warc.gz') assert os.path.exists(warc_path) # Validate metadata in warcinfo record @@ -359,7 +359,7 @@ def test_warc_by_cdx_without_creator_operator(tmpdir): ) # Check that WARC file was created - warc_path = os.path.join(tmpdir, 'TEST_NO_META-000000-001.warc.gz') + warc_path = os.path.join(tmpdir, 'TEST_NO_META-001.warc.gz') assert os.path.exists(warc_path) # Validate that creator/operator are not in warcinfo record @@ -383,7 +383,7 @@ def test_cli_warc_by_athena( base_prefix = tmpdir warc_download_prefix = 's3://commoncrawl' extra_args: Optional[List[str]] = None - warc_filename: str = 'TEST_warc_by_index-000000-001.extracted.warc.gz' # due to parallel writer + warc_filename: str = 'TEST_warc_by_index-001.warc.gz' base_prefix = str(base_prefix) if extra_args is None: diff --git a/tests/filter_warc/test_sql_sources_gated.py b/tests/filter_warc/test_sql_sources_gated.py index 8993fcd..13c4f70 100644 --- a/tests/filter_warc/test_sql_sources_gated.py +++ b/tests/filter_warc/test_sql_sources_gated.py @@ -41,7 +41,7 @@ def _produce_and_consume(tmpdir, produce_args): ] ) - warc_path = os.path.join(base_prefix, 'TEST_sql-000000-001.warc.gz') + warc_path = os.path.join(base_prefix, 'TEST_sql-001.warc.gz') response_count = 0 with fsspec.open(warc_path, 'rb') as stream: for record in ArchiveIterator(stream): diff --git a/tests/filter_warc/test_warc_filter.py b/tests/filter_warc/test_warc_filter.py index b157f61..a9e9cf9 100644 --- a/tests/filter_warc/test_warc_filter.py +++ b/tests/filter_warc/test_warc_filter.py @@ -56,7 +56,6 @@ async def run_test(): current_file_sequence=current_file_sequence, current_file_size=current_file_size, added_byte_size=added_byte_size, - writer_id=1, output_path_prefix='/fake/output', max_attempts=3, base_backoff_seconds=1.0, @@ -103,7 +102,6 @@ async def run_test(): current_file_sequence=current_file_sequence, current_file_size=current_file_size, added_byte_size=added_byte_size, - writer_id=1, output_path_prefix='/fake/output', max_attempts=3, base_backoff_seconds=1.0, @@ -122,7 +120,6 @@ async def run_test(): # New writer should be created mock_create.assert_called_once_with( sequence=current_file_sequence + 1, - writer_id=1, output_path_prefix='/fake/output', max_attempts=3, base_backoff_seconds=1.0, @@ -163,7 +160,6 @@ async def run_test(): current_file_sequence=current_file_sequence, current_file_size=current_file_size, added_byte_size=added_byte_size, - writer_id=1, output_path_prefix='/fake/output', max_attempts=3, base_backoff_seconds=1.0, @@ -210,7 +206,6 @@ async def run_test(): current_file_sequence=current_file_sequence, current_file_size=current_file_size, added_byte_size=added_byte_size, - writer_id=1, output_path_prefix='/fake/output', max_attempts=3, base_backoff_seconds=1.0, @@ -251,7 +246,6 @@ async def run_test(): current_file_sequence=current_file_sequence, current_file_size=current_file_size, added_byte_size=added_byte_size, - writer_id=1, output_path_prefix='/fake/output', max_attempts=3, base_backoff_seconds=1.0, @@ -297,7 +291,6 @@ async def run_test(): current_file_sequence=current_file_sequence, current_file_size=current_file_size, added_byte_size=added_byte_size, - writer_id=1, output_path_prefix='/fake/output', max_attempts=3, base_backoff_seconds=1.0, @@ -337,7 +330,6 @@ async def run_test(): current_file_sequence=1, current_file_size=800, added_byte_size=300, - writer_id=99, output_path_prefix='/custom/output', max_attempts=5, base_backoff_seconds=2.5, @@ -351,7 +343,6 @@ async def run_test(): # Verify all kwargs are passed through mock_create.assert_called_once_with( sequence=2, # incremented from 1 - writer_id=99, output_path_prefix='/custom/output', max_attempts=5, base_backoff_seconds=2.5, @@ -390,7 +381,6 @@ async def run_test(): current_file_sequence=5, current_file_size=800, added_byte_size=300, - writer_id=1, output_path_prefix='/fake/output', max_attempts=3, base_backoff_seconds=1.0, @@ -416,11 +406,11 @@ def test_log_writer(caplog): log_every_n=2, ) tracker = ThroughputTracker() - warc_filter.log_writer(1, 0, tracker) - warc_filter.log_writer(1, 1, tracker) - warc_filter.log_writer(1, 2, tracker) + warc_filter.log_writer(0, tracker) + warc_filter.log_writer(1, tracker) + warc_filter.log_writer(2, tracker) - assert caplog.text.count('WARC Writer 1') == 2 + assert caplog.text.count('WARC Writer') == 2 def test_log_reader(caplog):