From d5a6e87f4476900f8c5ab084fa957b20c563bde1 Mon Sep 17 00:00:00 2001 From: ali Date: Mon, 8 Jun 2026 12:33:30 +0530 Subject: [PATCH 1/5] UN-3513 [FEAT] Type chord-callback boundary with BatchExecutionResult / FileExecutionResult MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Producers in workers/file_processing/tasks.py now build typed dataclasses (from unstract.core.worker_models) and emit their ``.to_dict()`` instead of hand-rolled dicts. Locks the wire shape to the dataclass schema so downstream refactors fail loud. Scope Producer-side typing only. Consumer (workers/callback/tasks.py + aggregate_file_batch_results) already reads via ``.get(..., default)`` — tolerant by construction — so no consumer-side change needed. Dataclass extensions (unstract.core.worker_models, additive only) * BatchExecutionResult gains 3 optional fields: skipped_already_completed, skipped_active_duplicate, organization_id. * FileExecutionResult gains 3 optional fields for the API path's legacy dict vocabulary: file_name (alias for file), result_data (alias for result), skipped (marker like "already_completed"). * Both from_dict updated to populate the new fields. Producer migrations (workers/file_processing/tasks.py) * L901 (general path, process_file_batch return): BatchExecutionResult(...).to_dict(). Wire dict gains file_results: [] and errors: [] defaults — strictly additive. * L1706, L1798, L1823 (API path returns from _process_file_batch_api_core helpers): FileExecutionResult(...).to_dict(). L1798 preserves the legacy storage_result field via dict-spread merge. Domain-vocabulary correction on the API path API-path producers previously returned status="completed" / "failed" — lowercase strings matching neither ExecutionStatus (workflow-level, uppercase) nor ApiDeploymentResultStatus (per-file, Success/Failed, the canonical per-file vocab). Producers now emit "Success" / "Failed" via FileExecutionResult. 
Audit: no Python equality consumer was found reading the lowercase variants (grep clean). Observability tooling pattern-matching the old strings would need updating; this is a domain-correctness fix. Tests New tests/test_chord_callback_boundary.py — 14 tests, 3 classes: * Wire-shape characterisation for BatchExecutionResult. * Wire-shape characterisation for FileExecutionResult with alias fields and canonical Success/Failed vocab. * Consumer tolerance: aggregate_file_batch_results-style .get() reads return expected values from the new wire shape. sdk1's 80 worker_models tests still pass — the dataclass extensions are strictly additive. Regression risk: zero on consumer side, zero on backend (doesn't import these classes; has its own FileExecutionResult in dto.py — untouched). Status-vocab shift on API path is a deliberate domain correction. Test count: workers boundary suite +14 (new); sdk1 worker_models 80/80. Co-Authored-By: Claude Opus 4.7 (1M context) --- .../core/src/unstract/core/worker_models.py | 37 ++- workers/file_processing/tasks.py | 85 ++++--- workers/tests/test_chord_callback_boundary.py | 211 ++++++++++++++++++ 3 files changed, 294 insertions(+), 39 deletions(-) create mode 100644 workers/tests/test_chord_callback_boundary.py diff --git a/unstract/core/src/unstract/core/worker_models.py b/unstract/core/src/unstract/core/worker_models.py index 17da3557a2..4e1528558f 100644 --- a/unstract/core/src/unstract/core/worker_models.py +++ b/unstract/core/src/unstract/core/worker_models.py @@ -239,7 +239,14 @@ def from_dict(cls, data: dict[str, Any]) -> "FinalOutputResult": @dataclass class FileExecutionResult: - """Structured result for file execution tasks.""" + """Structured result for file execution tasks. + + The API-deployment chord path historically returned dicts with + ``file_name`` / ``result_data`` / ``skipped`` rather than the + canonical ``file`` / ``result`` shape. 
The optional aliases below + let producers populate either vocabulary without the consumer + needing to know which path produced the result. + """ file: str file_execution_id: str | None @@ -249,6 +256,14 @@ class FileExecutionResult: metadata: dict[str, Any] | None = None processing_time: float = 0.0 file_size: int = 0 + # Optional API-path aliases for the legacy dict shape. Strictly + # additive — consumers reading the existing ``file`` / ``result`` + # fields are unaffected; consumers reading ``file_name`` / + # ``result_data`` get the value the producer populated. + file_name: str | None = None + result_data: Any | None = None + # Marker for files skipped at the API path (e.g. "already_completed"). + skipped: str | None = None def __post_init__(self) -> None: if self.error: @@ -293,6 +308,9 @@ def from_dict(cls, data: dict[str, Any]) -> "FileExecutionResult": metadata=data.get("metadata"), processing_time=data.get("processing_time", 0.0), file_size=data.get("file_size", 0), + file_name=data.get("file_name"), + result_data=data.get("result_data"), + skipped=data.get("skipped"), ) def is_successful(self) -> bool: @@ -306,7 +324,13 @@ def has_error(self) -> bool: @dataclass class BatchExecutionResult: - """Structured result for batch execution tasks.""" + """Structured result for batch execution tasks. + + The general workflow chord path additionally tracks skipped-file + sub-counts and the org context on the wire. The optional fields + below capture that vocabulary without changing existing field + semantics — strictly additive. + """ total_files: int successful_files: int @@ -315,6 +339,12 @@ class BatchExecutionResult: file_results: list[FileExecutionResult] = field(default_factory=list) batch_id: str | None = None errors: list[str] = field(default_factory=list) + # Optional general-path fields. Producers populate; consumers that + # don't know about them are unaffected (existing reads use ``.get()`` + # with defaults). 
+ skipped_already_completed: int = 0 + skipped_active_duplicate: int = 0 + organization_id: str | None = None @property def success_rate(self) -> float: @@ -343,6 +373,9 @@ def from_dict(cls, data: dict[str, Any]) -> "BatchExecutionResult": file_results=file_results, batch_id=data.get("batch_id"), errors=data.get("errors", []), + skipped_already_completed=data.get("skipped_already_completed", 0), + skipped_active_duplicate=data.get("skipped_active_duplicate", 0), + organization_id=data.get("organization_id"), ) def add_file_result(self, file_result: FileExecutionResult): diff --git a/workers/file_processing/tasks.py b/workers/file_processing/tasks.py index 614266376a..264ca561c3 100644 --- a/workers/file_processing/tasks.py +++ b/workers/file_processing/tasks.py @@ -50,7 +50,12 @@ PreCreatedFileData, WorkerFileData, ) -from unstract.core.worker_models import FileProcessingResult +from unstract.core.worker_models import ( + ApiDeploymentResultStatus, + BatchExecutionResult, + FileExecutionResult, + FileProcessingResult, +) logger = WorkerLogger.get_logger(__name__) @@ -896,21 +901,19 @@ def _compile_batch_result(context: WorkflowContextData) -> dict[str, Any]: except Exception as cleanup_error: logger.warning(f"Failed to cleanup StateStore context: {cleanup_error}") - # Return the final result matching Django backend format - # Note: Only active duplicates count as failures; already-completed do not - return { - "successful_files": result.successful_files, - "failed_files": result.failed_files, # Includes active duplicates (user error) - "total_files": result.successful_files - + result.failed_files - + len(skipped_already_completed), # Include all files in batch - "skipped_already_completed": len(skipped_already_completed), # Not a failure - "skipped_active_duplicate": len( - skipped_active_duplicate - ), # IS a failure (counted above) - "execution_time": result.execution_time, - "organization_id": context.organization_context.organization_id, - } + # Return the 
final result matching Django backend format. + # Note: only active duplicates count as failures; already-completed do not. + return BatchExecutionResult( + total_files=( + result.successful_files + result.failed_files + len(skipped_already_completed) + ), + successful_files=result.successful_files, + failed_files=result.failed_files, + execution_time=result.execution_time, + skipped_already_completed=len(skipped_already_completed), + skipped_active_duplicate=len(skipped_active_duplicate), + organization_id=context.organization_context.organization_id, + ).to_dict() # HELPER FUNCTIONS (originally part of the massive process_file_batch function) @@ -1703,15 +1706,16 @@ def _process_single_file_api( f"Skipping processing for execution_id: {execution_id}, " f"file_execution_id: {file_execution_id}" ) - return { - "file_execution_id": file_execution_id, - "file_name": file_name, - "status": "completed", - "processing_time": 0.0, - "result_data": getattr(workflow_file_execution, "result", None), - "metadata": getattr(workflow_file_execution, "metadata", None) or {}, - "skipped": "already_completed", - } + return FileExecutionResult( + file=file_name, + file_execution_id=file_execution_id, + status=ApiDeploymentResultStatus.SUCCESS, + file_name=file_name, + processing_time=0.0, + result_data=getattr(workflow_file_execution, "result", None), + metadata=getattr(workflow_file_execution, "metadata", None) or {}, + skipped="already_completed", + ).to_dict() except Exception as e: logger.exception( f"API path: Failed to validate completion status for {file_execution_id}: {e}. " @@ -1792,12 +1796,18 @@ def _process_single_file_api( processing_time = time.time() - start_time + # ``storage_result`` isn't a field on FileExecutionResult yet + # (no consumer reads it today). Preserved via dict-spread so any + # external integration that does inspect it sees the same value. 
result = { - "file_execution_id": file_execution_id, - "file_name": file_name, - "status": "completed", - "processing_time": processing_time, - "result_data": runner_result, + **FileExecutionResult( + file=file_name, + file_execution_id=file_execution_id, + status=ApiDeploymentResultStatus.SUCCESS, + file_name=file_name, + processing_time=processing_time, + result_data=runner_result, + ).to_dict(), "storage_result": storage_result, } @@ -1820,13 +1830,14 @@ def _process_single_file_api( except Exception: logger.exception("Failed to update file execution status") - return { - "file_execution_id": file_execution_id, - "file_name": file_name, - "status": "failed", - "processing_time": processing_time, - "error": str(e), - } + return FileExecutionResult( + file=file_name, + file_execution_id=file_execution_id, + status=ApiDeploymentResultStatus.FAILED, + file_name=file_name, + processing_time=processing_time, + error=str(e), + ).to_dict() def _check_file_history( diff --git a/workers/tests/test_chord_callback_boundary.py b/workers/tests/test_chord_callback_boundary.py new file mode 100644 index 0000000000..771ccf7e17 --- /dev/null +++ b/workers/tests/test_chord_callback_boundary.py @@ -0,0 +1,211 @@ +"""Wire-shape characterisation for the chord-callback boundary. + +Locks the on-wire contract for the two producer paths that feed +``process_batch_callback`` (general path) and ``process_batch_callback_api`` +(API path). Producers now build typed dataclasses and serialise via +``.to_dict()``; these tests assert the resulting dicts contain every +field the consumer reads, plus the strictly-additive ones the dataclass +schema introduces. 
+""" + +from __future__ import annotations + +import json + +import pytest + +from unstract.core.worker_models import ( + ApiDeploymentResultStatus, + BatchExecutionResult, + FileExecutionResult, +) + + +class TestBatchExecutionResultWireShape: + """General path (``process_file_batch`` returns this shape).""" + + def _make(self) -> BatchExecutionResult: + return BatchExecutionResult( + total_files=10, + successful_files=7, + failed_files=2, + execution_time=12.5, + skipped_already_completed=1, + skipped_active_duplicate=0, + organization_id="org-test", + ) + + def test_wire_contains_consumer_read_fields(self): + # ``aggregate_file_batch_results`` reads these via ``.get()`` — + # they must appear in the wire dict so the existing consumer + # behaviour is preserved. + wire = self._make().to_dict() + for key in ( + "total_files", + "successful_files", + "failed_files", + "execution_time", + "file_results", # consumer iterates this; empty list by default + ): + assert key in wire, f"consumer-read field missing: {key}" + + def test_wire_carries_extended_optional_fields(self): + wire = self._make().to_dict() + # These three are the strictly-additive fields the Phase 5.3 + # extension introduced. Producer populates; legacy consumers + # that don't know about them are unaffected. 
+ assert wire["skipped_already_completed"] == 1 + assert wire["skipped_active_duplicate"] == 0 + assert wire["organization_id"] == "org-test" + + def test_round_trip_preserves_all_fields(self): + original = self._make() + round_tripped = BatchExecutionResult.from_dict(original.to_dict()) + assert round_tripped.total_files == original.total_files + assert round_tripped.successful_files == original.successful_files + assert round_tripped.failed_files == original.failed_files + assert round_tripped.execution_time == original.execution_time + assert ( + round_tripped.skipped_already_completed + == original.skipped_already_completed + ) + assert ( + round_tripped.skipped_active_duplicate + == original.skipped_active_duplicate + ) + assert round_tripped.organization_id == original.organization_id + + def test_wire_is_json_safe(self): + # Celery's default serializer is JSON — the dict must round-trip + # through ``json.dumps`` / ``json.loads`` without loss. + wire = self._make().to_dict() + assert json.loads(json.dumps(wire)) == wire + + def test_defaults_safe_when_no_skips(self): + result = BatchExecutionResult( + total_files=3, + successful_files=3, + failed_files=0, + execution_time=1.0, + ) + wire = result.to_dict() + assert wire["skipped_already_completed"] == 0 + assert wire["skipped_active_duplicate"] == 0 + + +class TestFileExecutionResultWireShape: + """API path (``process_file_batch_api`` returns this per-file shape).""" + + def _make_success(self) -> FileExecutionResult: + return FileExecutionResult( + file="invoice.pdf", + file_execution_id="fx-1", + status=ApiDeploymentResultStatus.SUCCESS, + file_name="invoice.pdf", + processing_time=4.2, + result_data={"extracted": "value"}, + metadata={"source": "user_upload"}, + ) + + def _make_failure(self) -> FileExecutionResult: + return FileExecutionResult( + file="broken.pdf", + file_execution_id="fx-2", + status=ApiDeploymentResultStatus.FAILED, + file_name="broken.pdf", + processing_time=0.1, + error="extractor 
crashed", + ) + + def _make_skipped(self) -> FileExecutionResult: + return FileExecutionResult( + file="dup.pdf", + file_execution_id="fx-3", + status=ApiDeploymentResultStatus.SUCCESS, + file_name="dup.pdf", + skipped="already_completed", + result_data={"cached": True}, + ) + + def test_wire_carries_file_name_alias(self): + # The API path's legacy wire uses ``file_name`` (not ``file``); + # the dataclass preserves the alias. + wire = self._make_success().to_dict() + assert wire["file_name"] == "invoice.pdf" + assert wire["file"] == "invoice.pdf" # canonical alongside legacy + + def test_wire_carries_result_data_alias(self): + wire = self._make_success().to_dict() + assert wire["result_data"] == {"extracted": "value"} + + def test_wire_carries_skipped_marker(self): + wire = self._make_skipped().to_dict() + assert wire["skipped"] == "already_completed" + + def test_success_status_uses_canonical_vocab(self): + # Domain correction: per-file results use ``ApiDeploymentResultStatus`` + # vocabulary (Success / Failed), not the ad-hoc lowercase + # "completed" / "failed" that the legacy dict producer used. + wire = self._make_success().to_dict() + assert wire["status"] == "Success" + + def test_failure_status_uses_canonical_vocab(self): + wire = self._make_failure().to_dict() + assert wire["status"] == "Failed" + assert wire["error"] == "extractor crashed" + + def test_post_init_derives_status_from_error(self): + # An error string forces FAILED regardless of the status passed + # to the constructor. 
+ result = FileExecutionResult( + file="x", + file_execution_id="fx", + status=ApiDeploymentResultStatus.SUCCESS, + error="boom", + ) + assert result.status == ApiDeploymentResultStatus.FAILED + + def test_round_trip_preserves_all_aliases(self): + original = self._make_skipped() + round_tripped = FileExecutionResult.from_dict(original.to_dict()) + assert round_tripped.file == original.file + assert round_tripped.file_name == original.file_name + assert round_tripped.file_execution_id == original.file_execution_id + assert round_tripped.result_data == original.result_data + assert round_tripped.skipped == original.skipped + assert round_tripped.status == original.status + + def test_wire_is_json_safe(self): + for builder in (self._make_success, self._make_failure, self._make_skipped): + wire = builder().to_dict() + assert json.loads(json.dumps(wire)) == wire + + +class TestConsumerTolerance: + """The chord-callback consumer (``aggregate_file_batch_results``) + reads via ``.get(..., default)``. Verifies the new wire shape + doesn't omit any field the consumer relies on.""" + + def test_aggregator_can_read_general_path_shape(self): + wire = BatchExecutionResult( + total_files=5, + successful_files=4, + failed_files=1, + execution_time=2.0, + skipped_already_completed=0, + skipped_active_duplicate=0, + organization_id="org-1", + ).to_dict() + # Mirrors aggregate_file_batch_results' ``.get()`` reads. + assert wire.get("total_files", 0) == 5 + assert wire.get("successful_files", 0) == 4 + assert wire.get("failed_files", 0) == 1 + assert wire.get("execution_time", 0) == 2.0 + # ``file_results`` is read (default []), and ``skipped_files`` + # is read but never written — same as legacy behaviour. 
+ assert wire.get("file_results", []) == [] + assert wire.get("skipped_files", 0) == 0 + + +if __name__ == "__main__": + pytest.main([__file__, "-v"]) From 8e16eb60e16e2dd55f49163678267b20eb7c2467 Mon Sep 17 00:00:00 2001 From: ali Date: Mon, 8 Jun 2026 13:52:08 +0530 Subject: [PATCH 2/5] UN-3513 [FIX] Address PR review (toolkit + SkipReason enum + producer-binding tests) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit A+B from the triage on PR #2020: * tasks.py:1659 (API-path BATCH return) — migrated to BatchExecutionResult.to_dict(). Fixes the half-typed boundary the reviewer flagged. file_results, total_files, skipped_already_completed and organization_id are now on the wire. Successful/skipped counter semantic preserved (separating them is deferred to a follow-up). * New SkipReason StrEnum (worker_models.py) with ALREADY_COMPLETED + ACTIVE_DUPLICATE — mirrors the batch-level skip counters on BatchExecutionResult. FileExecutionResult.skipped is now SkipReason | None. from_dict coerces. Producer uses the enum; the ACTIVE_DUPLICATE value has no current per-file producer but is exercised end-to-end via a round-trip test. * TODO(UN-3516) marker on the three alias fields (file_name, result_data, skipped) — sunset ticket filed. * Tests strengthened: - TestProducerBinding drives real _compile_batch_result with a minimal SimpleNamespace context, and drives _process_single_file_api via mocked api_client for the already-completed branch. - TestRealConsumerTolerance imports the real aggregate_file_batch_results — producer-consumer contract driven end-to-end. - test_none_valued_optional_fields_stripped_from_wire documents serialize_dataclass_to_dict's None-strip behaviour. - test_active_duplicate_skip_reason_round_trips proves the second enum value isn't dead. - SonarCloud python:S1244 fixed — pytest.approx. - skipped_files==0 NIT assertion removed. Test count: workers boundary suite 14 -> 18; sdk1 worker_models 80/80 still green. 
Deferred (separate tickets to follow): __post_init__ silent status clobber, from_dict status discard, BatchExecutionResult invariant, storage soft-failure, dead aggregator branch. Co-Authored-By: Claude Opus 4.7 (1M context) --- unstract/core/src/unstract/core/__init__.py | 2 + .../core/src/unstract/core/worker_models.py | 36 ++- workers/file_processing/tasks.py | 26 ++- workers/tests/test_chord_callback_boundary.py | 219 ++++++++++++++---- 4 files changed, 220 insertions(+), 63 deletions(-) diff --git a/unstract/core/src/unstract/core/__init__.py b/unstract/core/src/unstract/core/__init__.py index deb84687f2..255673b939 100644 --- a/unstract/core/src/unstract/core/__init__.py +++ b/unstract/core/src/unstract/core/__init__.py @@ -43,6 +43,7 @@ PipelineStatus, PipelineStatusUpdateRequest, QueueName, + SkipReason, StatusMappings, TaskError, TaskExecutionContext, @@ -76,6 +77,7 @@ "WebhookResult", "FileExecutionResult", "BatchExecutionResult", + "SkipReason", "CallbackExecutionData", "WorkflowExecutionUpdateRequest", "PipelineStatusUpdateRequest", diff --git a/unstract/core/src/unstract/core/worker_models.py b/unstract/core/src/unstract/core/worker_models.py index 4e1528558f..ed3ca89c36 100644 --- a/unstract/core/src/unstract/core/worker_models.py +++ b/unstract/core/src/unstract/core/worker_models.py @@ -112,6 +112,28 @@ class ApiDeploymentResultStatus(str, Enum): FAILED = "Failed" +class SkipReason(str, Enum): + """Reasons a per-file execution can be skipped. + + Closed vocabulary so typos at producer call sites fail at + construction time rather than silently producing an + unrecognisable value on the wire. StrEnum semantics — members + serialise to their string value and compare equal to the + underlying string. + + Values mirror the batch-level skip counters on + :class:`BatchExecutionResult` (``skipped_already_completed`` / + ``skipped_active_duplicate``) so per-file and batch-level + vocabulary stay in lockstep. 
+ """ + + # WorkflowFileExecution.status was already COMPLETED — another + # worker finished this file before we got to it. + ALREADY_COMPLETED = "already_completed" + # Another worker is currently processing the same file content. + ACTIVE_DUPLICATE = "active_duplicate" + + class NotificationMethod(str, Enum): """Notification delivery methods.""" @@ -256,14 +278,14 @@ class FileExecutionResult: metadata: dict[str, Any] | None = None processing_time: float = 0.0 file_size: int = 0 - # Optional API-path aliases for the legacy dict shape. Strictly - # additive — consumers reading the existing ``file`` / ``result`` - # fields are unaffected; consumers reading ``file_name`` / - # ``result_data`` get the value the producer populated. + # TODO(UN-3516): remove these three legacy API-path aliases once + # consumers have migrated to the canonical ``file`` / ``result`` + # fields and to a typed skip-reason. Kept additive in the meantime + # so the chord-callback producer can preserve its current dict + # vocabulary without consumer breakage. file_name: str | None = None result_data: Any | None = None - # Marker for files skipped at the API path (e.g. "already_completed"). 
- skipped: str | None = None + skipped: SkipReason | None = None def __post_init__(self) -> None: if self.error: @@ -310,7 +332,7 @@ def from_dict(cls, data: dict[str, Any]) -> "FileExecutionResult": file_size=data.get("file_size", 0), file_name=data.get("file_name"), result_data=data.get("result_data"), - skipped=data.get("skipped"), + skipped=SkipReason(data["skipped"]) if data.get("skipped") else None, ) def is_successful(self) -> bool: diff --git a/workers/file_processing/tasks.py b/workers/file_processing/tasks.py index 264ca561c3..d42aea277c 100644 --- a/workers/file_processing/tasks.py +++ b/workers/file_processing/tasks.py @@ -55,6 +55,7 @@ BatchExecutionResult, FileExecutionResult, FileProcessingResult, + SkipReason, ) logger = WorkerLogger.get_logger(__name__) @@ -1655,11 +1656,24 @@ def process_file_batch_api( else: successful_files += 1 - # Return result matching Django FileBatchResult structure - batch_result = { - "successful_files": successful_files, - "failed_files": failed_files, - } + # Return result matching Django FileBatchResult structure. + # Note: ``successful_files`` here counts every non-error file, + # INCLUDING skipped-already-completed ones (legacy API-path + # semantic). Separating the skipped count from the successful + # count is deferred — would change consumer-visible counters + # and is tracked separately. 
+ batch_result = BatchExecutionResult( + total_files=len(file_results), + successful_files=successful_files, + failed_files=failed_files, + file_results=[FileExecutionResult.from_dict(r) for r in file_results], + skipped_already_completed=sum( + 1 + for r in file_results + if r.get("skipped") == SkipReason.ALREADY_COMPLETED.value + ), + organization_id=schema_name, + ).to_dict() logger.info(f"Successfully processed API file batch {batch_id}") return batch_result @@ -1714,7 +1728,7 @@ def _process_single_file_api( processing_time=0.0, result_data=getattr(workflow_file_execution, "result", None), metadata=getattr(workflow_file_execution, "metadata", None) or {}, - skipped="already_completed", + skipped=SkipReason.ALREADY_COMPLETED, ).to_dict() except Exception as e: logger.exception( diff --git a/workers/tests/test_chord_callback_boundary.py b/workers/tests/test_chord_callback_boundary.py index 771ccf7e17..31855096db 100644 --- a/workers/tests/test_chord_callback_boundary.py +++ b/workers/tests/test_chord_callback_boundary.py @@ -1,23 +1,36 @@ """Wire-shape characterisation for the chord-callback boundary. -Locks the on-wire contract for the two producer paths that feed +Locks the on-wire contract for the producer paths that feed ``process_batch_callback`` (general path) and ``process_batch_callback_api`` -(API path). Producers now build typed dataclasses and serialise via -``.to_dict()``; these tests assert the resulting dicts contain every -field the consumer reads, plus the strictly-additive ones the dataclass -schema introduces. +(API path). Producers build typed dataclasses and serialise via +``.to_dict()``. + +Three layers of test: + +1. **Dataclass wire shape** — ``to_dict`` / ``from_dict`` round-trip + preserves every consumer-read field; JSON-safe. +2. **Producer binding** — drives the real producer functions + (``_compile_batch_result``, ``_process_single_file_api``). Catches + reverts at the producer site that a dataclass-only test would miss. +3. 
**Real-consumer tolerance** — drives the real + ``aggregate_file_batch_results`` consumer against the typed wire + shape. Catches mismatches the producer-only test would miss. """ from __future__ import annotations import json +from types import SimpleNamespace +from unittest.mock import MagicMock import pytest +from shared.processing.files.time_utils import aggregate_file_batch_results from unstract.core.worker_models import ( ApiDeploymentResultStatus, BatchExecutionResult, FileExecutionResult, + SkipReason, ) @@ -36,24 +49,18 @@ def _make(self) -> BatchExecutionResult: ) def test_wire_contains_consumer_read_fields(self): - # ``aggregate_file_batch_results`` reads these via ``.get()`` — - # they must appear in the wire dict so the existing consumer - # behaviour is preserved. wire = self._make().to_dict() for key in ( "total_files", "successful_files", "failed_files", "execution_time", - "file_results", # consumer iterates this; empty list by default + "file_results", ): assert key in wire, f"consumer-read field missing: {key}" def test_wire_carries_extended_optional_fields(self): wire = self._make().to_dict() - # These three are the strictly-additive fields the Phase 5.3 - # extension introduced. Producer populates; legacy consumers - # that don't know about them are unaffected. 
assert wire["skipped_already_completed"] == 1 assert wire["skipped_active_duplicate"] == 0 assert wire["organization_id"] == "org-test" @@ -64,7 +71,7 @@ def test_round_trip_preserves_all_fields(self): assert round_tripped.total_files == original.total_files assert round_tripped.successful_files == original.successful_files assert round_tripped.failed_files == original.failed_files - assert round_tripped.execution_time == original.execution_time + assert round_tripped.execution_time == pytest.approx(original.execution_time) assert ( round_tripped.skipped_already_completed == original.skipped_already_completed @@ -76,22 +83,9 @@ def test_round_trip_preserves_all_fields(self): assert round_tripped.organization_id == original.organization_id def test_wire_is_json_safe(self): - # Celery's default serializer is JSON — the dict must round-trip - # through ``json.dumps`` / ``json.loads`` without loss. wire = self._make().to_dict() assert json.loads(json.dumps(wire)) == wire - def test_defaults_safe_when_no_skips(self): - result = BatchExecutionResult( - total_files=3, - successful_files=3, - failed_files=0, - execution_time=1.0, - ) - wire = result.to_dict() - assert wire["skipped_already_completed"] == 0 - assert wire["skipped_active_duplicate"] == 0 - class TestFileExecutionResultWireShape: """API path (``process_file_batch_api`` returns this per-file shape).""" @@ -123,16 +117,14 @@ def _make_skipped(self) -> FileExecutionResult: file_execution_id="fx-3", status=ApiDeploymentResultStatus.SUCCESS, file_name="dup.pdf", - skipped="already_completed", + skipped=SkipReason.ALREADY_COMPLETED, result_data={"cached": True}, ) def test_wire_carries_file_name_alias(self): - # The API path's legacy wire uses ``file_name`` (not ``file``); - # the dataclass preserves the alias. 
wire = self._make_success().to_dict() assert wire["file_name"] == "invoice.pdf" - assert wire["file"] == "invoice.pdf" # canonical alongside legacy + assert wire["file"] == "invoice.pdf" def test_wire_carries_result_data_alias(self): wire = self._make_success().to_dict() @@ -140,12 +132,9 @@ def test_wire_carries_result_data_alias(self): def test_wire_carries_skipped_marker(self): wire = self._make_skipped().to_dict() - assert wire["skipped"] == "already_completed" + assert wire["skipped"] == SkipReason.ALREADY_COMPLETED.value def test_success_status_uses_canonical_vocab(self): - # Domain correction: per-file results use ``ApiDeploymentResultStatus`` - # vocabulary (Success / Failed), not the ad-hoc lowercase - # "completed" / "failed" that the legacy dict producer used. wire = self._make_success().to_dict() assert wire["status"] == "Success" @@ -155,8 +144,6 @@ def test_failure_status_uses_canonical_vocab(self): assert wire["error"] == "extractor crashed" def test_post_init_derives_status_from_error(self): - # An error string forces FAILED regardless of the status passed - # to the constructor. result = FileExecutionResult( file="x", file_execution_id="fx", @@ -175,18 +162,131 @@ def test_round_trip_preserves_all_aliases(self): assert round_tripped.skipped == original.skipped assert round_tripped.status == original.status + def test_active_duplicate_skip_reason_round_trips(self): + # ``ACTIVE_DUPLICATE`` mirrors the batch-level + # ``skipped_active_duplicate`` counter; no producer emits it + # per-file today but the enum value must exist and round-trip + # cleanly so a future producer can pick it without re-typing + # the bare string. 
+ original = FileExecutionResult( + file="dup.pdf", + file_execution_id="fx", + status=ApiDeploymentResultStatus.SUCCESS, + skipped=SkipReason.ACTIVE_DUPLICATE, + ) + wire = original.to_dict() + assert wire["skipped"] == SkipReason.ACTIVE_DUPLICATE.value + assert wire["skipped"] == "active_duplicate" + round_tripped = FileExecutionResult.from_dict(wire) + assert round_tripped.skipped == SkipReason.ACTIVE_DUPLICATE + def test_wire_is_json_safe(self): for builder in (self._make_success, self._make_failure, self._make_skipped): wire = builder().to_dict() assert json.loads(json.dumps(wire)) == wire + def test_none_valued_optional_fields_stripped_from_wire(self): + """``serialize_dataclass_to_dict`` drops ``None`` values. + + Documents the behaviour so consumers using membership checks + (``"x" in wire``) instead of ``.get(..., default)`` know what + to expect. Aliases default to ``None`` and only appear on the + wire when explicitly populated. + """ + minimal = FileExecutionResult( + file="a.pdf", + file_execution_id="fx", + status=ApiDeploymentResultStatus.SUCCESS, + ) + wire = minimal.to_dict() + # Required fields and zero-valued numerics survive. + assert wire["file"] == "a.pdf" + assert wire["status"] == "Success" + assert wire["processing_time"] == pytest.approx(0.0) + assert wire["file_size"] == 0 + # None defaults are dropped — not in the wire dict at all. + for absent in ("error", "result", "metadata", "file_name", "result_data", "skipped"): + assert absent not in wire, f"expected {absent!r} to be stripped when None" + + +class TestProducerBinding: + """Drives the real producer functions in ``file_processing.tasks``. + + A revert at any of these sites back to a hand-rolled dict (or to + the legacy lowercase status strings) keeps the *dataclass* tests + green — these tests catch that by asserting the actual wire shape + the chord callback receives from the producer. 
+ """ -class TestConsumerTolerance: - """The chord-callback consumer (``aggregate_file_batch_results``) - reads via ``.get(..., default)``. Verifies the new wire shape - doesn't omit any field the consumer relies on.""" + def test_compile_batch_result_returns_typed_wire(self): + from file_processing.tasks import _compile_batch_result - def test_aggregator_can_read_general_path_shape(self): + # Minimum fake context — _compile_batch_result reads only + # ``metadata["result"]`` (with attribute access), the two + # skipped lists, and ``organization_context.organization_id``. + result = SimpleNamespace( + successful_files=4, failed_files=1, execution_time=2.5 + ) + context = SimpleNamespace( + metadata={ + "result": result, + "workflow_logger": None, # avoids the publish call + "skipped_already_completed": ["a.pdf"], + "skipped_active_duplicate": [], + }, + organization_context=SimpleNamespace(organization_id="org-prod"), + ) + + wire = _compile_batch_result(context) + + # Producer must emit the typed shape, not a hand-rolled dict. + assert wire["total_files"] == 6 # 4 + 1 + 1 skipped + assert wire["successful_files"] == 4 + assert wire["failed_files"] == 1 + assert wire["execution_time"] == pytest.approx(2.5) + assert wire["skipped_already_completed"] == 1 + assert wire["skipped_active_duplicate"] == 0 + assert wire["organization_id"] == "org-prod" + # Dataclass shape gains these defaults — strictly additive. 
+ assert wire["file_results"] == [] + assert wire["errors"] == [] + + def test_process_single_file_api_already_completed_branch(self): + from file_processing.tasks import _process_single_file_api + from unstract.core.data_models import ExecutionStatus + + api_client = MagicMock() + api_client.get_workflow_file_execution.return_value = SimpleNamespace( + status=ExecutionStatus.COMPLETED.value, + result={"cached": "value"}, + metadata={"src": "history"}, + ) + wire = _process_single_file_api( + api_client=api_client, + file_data={"id": "fx-1", "file_name": "doc.pdf"}, + workflow_id="wf-1", + execution_id="exec-1", + pipeline_id=None, + use_file_history=True, + ) + + # Canonical per-file status vocabulary — not the legacy + # lowercase "completed". + assert wire["status"] == "Success" + assert wire["skipped"] == SkipReason.ALREADY_COMPLETED.value + assert wire["file_name"] == "doc.pdf" + assert wire["file_execution_id"] == "fx-1" + # Producer doesn't set ``error`` → __post_init__ keeps SUCCESS; + # serializer strips the None. + assert "error" not in wire + + +class TestRealConsumerTolerance: + """Drives the real ``aggregate_file_batch_results`` against the + new wire shape — proves the producer-consumer contract end-to-end. + """ + + def test_aggregator_consumes_general_path_shape(self): wire = BatchExecutionResult( total_files=5, successful_files=4, @@ -196,15 +296,34 @@ def test_aggregator_can_read_general_path_shape(self): skipped_active_duplicate=0, organization_id="org-1", ).to_dict() - # Mirrors aggregate_file_batch_results' ``.get()`` reads. - assert wire.get("total_files", 0) == 5 - assert wire.get("successful_files", 0) == 4 - assert wire.get("failed_files", 0) == 1 - assert wire.get("execution_time", 0) == 2.0 - # ``file_results`` is read (default []), and ``skipped_files`` - # is read but never written — same as legacy behaviour. 
- assert wire.get("file_results", []) == [] - assert wire.get("skipped_files", 0) == 0 + + aggregated = aggregate_file_batch_results([wire]) + + assert aggregated["total_files"] == 5 + assert aggregated["successful_files"] == 4 + assert aggregated["failed_files"] == 1 + assert aggregated["batches_processed"] == 1 + + def test_aggregator_consumes_multi_batch(self): + batches = [ + BatchExecutionResult( + total_files=3, + successful_files=3, + failed_files=0, + execution_time=1.0, + ).to_dict(), + BatchExecutionResult( + total_files=2, + successful_files=1, + failed_files=1, + execution_time=0.5, + ).to_dict(), + ] + aggregated = aggregate_file_batch_results(batches) + assert aggregated["total_files"] == 5 + assert aggregated["successful_files"] == 4 + assert aggregated["failed_files"] == 1 + assert aggregated["batches_processed"] == 2 if __name__ == "__main__": From 517deed5ba0349d2a77a370f3529c86e668b7e17 Mon Sep 17 00:00:00 2001 From: ali Date: Mon, 8 Jun 2026 14:42:21 +0530 Subject: [PATCH 3/5] UN-3513 [FIX] Address second-pass review (storage_result + lenient skipped + missing producer tests) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Three findings from the second review round on PR #2020: * HIGH — storage_result silent data loss at batch boundary. The per-file dict-spread at tasks.py:1816 preserved storage_result on the immediate return, but the value was dropped when wrapped into BatchExecutionResult.file_results (from_dict didn't know the key). Promoted to a typed FileExecutionResult.storage_result: Any | None field; producer now emits via the constructor; from_dict reads it back. The round-trip preserves it end-to-end. * HIGH — strict SkipReason parsing would crash entire batches during rolling deploys if a newer producer ever emitted an unknown value. Added FileExecutionResult._parse_skipped, which catches ValueError + logs a warning + falls back to None. 
Standard "strict on emit, lenient on receive" posture for wire compat. * MEDIUM — TestProducerBinding only covered 2 of 5 producer branches. Added three more tests: - _process_single_file_api success branch (asserts storage_result survives the typed wire — would catch the dict-spread revert). - _process_single_file_api failure branch (asserts canonical "Failed" vocab — catches reverts to the legacy lowercase "failed"). - process_file_batch_api batch wrapper via task.apply() with an in-memory result_backend (asserts BatchExecutionResult shape + skipped_already_completed counter derived from SkipReason.ALREADY_COMPLETED.value). Strengthened the existing already-completed branch test to assert result_data + metadata propagation. Bug caught by the new batch-wrapper test: process_file_batch_api was missing execution_time on its BatchExecutionResult(...) call — BatchExecutionResult.execution_time is a required positional, so the API-path batch task would have crashed with TypeError on every run. Introduced batch_start_time = time.time() at task entry and pass execution_time = time.time() - batch_start_time. The new test would have caught this immediately at PR time; logging it here as the exact value of producer-binding coverage. Test count: 18 -> 21; all green. 
Co-Authored-By: Claude Opus 4.7 (1M context) --- .../core/src/unstract/core/worker_models.py | 29 ++- workers/file_processing/tasks.py | 25 +-- workers/tests/test_chord_callback_boundary.py | 198 +++++++++++++++++- 3 files changed, 236 insertions(+), 16 deletions(-) diff --git a/unstract/core/src/unstract/core/worker_models.py b/unstract/core/src/unstract/core/worker_models.py index ed3ca89c36..f5c5523d8e 100644 --- a/unstract/core/src/unstract/core/worker_models.py +++ b/unstract/core/src/unstract/core/worker_models.py @@ -286,6 +286,13 @@ class FileExecutionResult: file_name: str | None = None result_data: Any | None = None skipped: SkipReason | None = None + # Storage backend's acknowledgement for a successfully-processed + # file. Forwarded verbatim from + # ``api_client.store_file_execution_result``. No in-tree consumer + # reads it today, but external integrations may inspect it on the + # wire — preserving it as a typed field rather than dict-spreading + # so it survives the BatchExecutionResult round-trip. + storage_result: Any | None = None def __post_init__(self) -> None: if self.error: @@ -312,6 +319,25 @@ def to_json(self) -> dict[str, Any]: """Convert to JSON-serializable dict for backward compatibility.""" return self.to_api_dict() + @staticmethod + def _parse_skipped(raw: Any) -> "SkipReason | None": + """Lenient ``SkipReason`` parser for the consumer side. + + Producer call sites are typed (constructor takes the enum, typos + fail at construction). On the consumer side we accept an unknown + wire value gracefully — a newer-producer / older-consumer + rolling-deploy must not crash the entire batch task on a value + the consumer doesn't recognise. Standard "strict on emit, + lenient on receive" posture. 
+ """ + if not raw: + return None + try: + return SkipReason(raw) + except ValueError: + logger.warning("Unknown SkipReason on wire: %r; treating as None", raw) + return None + @classmethod def from_dict(cls, data: dict[str, Any]) -> "FileExecutionResult": """Create from dictionary (e.g., task result).""" @@ -332,7 +358,8 @@ def from_dict(cls, data: dict[str, Any]) -> "FileExecutionResult": file_size=data.get("file_size", 0), file_name=data.get("file_name"), result_data=data.get("result_data"), - skipped=SkipReason(data["skipped"]) if data.get("skipped") else None, + skipped=cls._parse_skipped(data.get("skipped")), + storage_result=data.get("storage_result"), ) def is_successful(self) -> bool: diff --git a/workers/file_processing/tasks.py b/workers/file_processing/tasks.py index d42aea277c..8b46680f4b 100644 --- a/workers/file_processing/tasks.py +++ b/workers/file_processing/tasks.py @@ -1564,6 +1564,7 @@ def process_file_batch_api( f"Processing API file batch {batch_id} with {len(created_files)} files" ) + batch_start_time = time.time() try: # Set organization context exactly like Django backend StateStore.set(Account.ORGANIZATION_ID, schema_name) @@ -1666,6 +1667,7 @@ def process_file_batch_api( total_files=len(file_results), successful_files=successful_files, failed_files=failed_files, + execution_time=time.time() - batch_start_time, file_results=[FileExecutionResult.from_dict(r) for r in file_results], skipped_already_completed=sum( 1 @@ -1810,20 +1812,15 @@ def _process_single_file_api( processing_time = time.time() - start_time - # ``storage_result`` isn't a field on FileExecutionResult yet - # (no consumer reads it today). Preserved via dict-spread so any - # external integration that does inspect it sees the same value. 
- result = { - **FileExecutionResult( - file=file_name, - file_execution_id=file_execution_id, - status=ApiDeploymentResultStatus.SUCCESS, - file_name=file_name, - processing_time=processing_time, - result_data=runner_result, - ).to_dict(), - "storage_result": storage_result, - } + result = FileExecutionResult( + file=file_name, + file_execution_id=file_execution_id, + status=ApiDeploymentResultStatus.SUCCESS, + file_name=file_name, + processing_time=processing_time, + result_data=runner_result, + storage_result=storage_result, + ).to_dict() logger.info(f"Successfully processed file: {file_name} in {processing_time:.2f}s") return result diff --git a/workers/tests/test_chord_callback_boundary.py b/workers/tests/test_chord_callback_boundary.py index 31855096db..e9d5d05417 100644 --- a/workers/tests/test_chord_callback_boundary.py +++ b/workers/tests/test_chord_callback_boundary.py @@ -24,7 +24,6 @@ from unittest.mock import MagicMock import pytest - from shared.processing.files.time_utils import aggregate_file_batch_results from unstract.core.worker_models import ( ApiDeploymentResultStatus, @@ -276,10 +275,207 @@ def test_process_single_file_api_already_completed_branch(self): assert wire["skipped"] == SkipReason.ALREADY_COMPLETED.value assert wire["file_name"] == "doc.pdf" assert wire["file_execution_id"] == "fx-1" + # Cached result + metadata propagate so the API consumer can + # short-circuit on the historical extraction. + assert wire["result_data"] == {"cached": "value"} + assert wire["metadata"] == {"src": "history"} # Producer doesn't set ``error`` → __post_init__ keeps SUCCESS; # serializer strips the None. assert "error" not in wire + def test_process_single_file_api_success_branch(self, monkeypatch): + """Drives the happy-path producer. Catches reverts to the + legacy dict-spread that dropped ``storage_result`` at the + chord-callback boundary (silent data loss on the API path). 
+ """ + from file_processing import tasks as tasks_mod + from unstract.core.data_models import ExecutionStatus + + api_client = MagicMock() + # Not already-completed → fall through to the runner branch. + api_client.get_workflow_file_execution.return_value = SimpleNamespace( + status=ExecutionStatus.PENDING.value + ) + api_client.get_workflow_definition.return_value = {"id": "wf-1"} + api_client.get_file_content.return_value = b"hello world" + api_client.store_file_execution_result.return_value = { + "stored_at": "s3://bucket/key" + } + # Patch the runner-service call — its body shells out and isn't + # under test here. + monkeypatch.setattr( + tasks_mod, + "_call_runner_service", + lambda **kwargs: {"extracted": "value"}, + ) + + wire = tasks_mod._process_single_file_api( + api_client=api_client, + file_data={"id": "fx-ok", "file_name": "ok.pdf"}, + workflow_id="wf-1", + execution_id="exec-1", + pipeline_id=None, + use_file_history=False, + ) + + # Canonical per-file vocabulary — not the legacy lowercase + # ``"completed"`` the pre-typing producer used to emit. + assert wire["status"] == "Success" + assert wire["file_name"] == "ok.pdf" + assert wire["file_execution_id"] == "fx-ok" + assert wire["result_data"] == {"extracted": "value"} + # The whole point of this test: ``storage_result`` must survive + # the typed dataclass round-trip (UN-3513 finding). + assert wire["storage_result"] == {"stored_at": "s3://bucket/key"} + # No error → success branch keeps SUCCESS; ``error`` stripped. + assert "error" not in wire + assert "skipped" not in wire + + def test_process_single_file_api_failure_branch(self, monkeypatch): + """Drives the except-block producer. Catches reverts to the + legacy lowercase ``"failed"`` status string. 
+ """ + from file_processing import tasks as tasks_mod + from unstract.core.data_models import ExecutionStatus + + api_client = MagicMock() + api_client.get_workflow_file_execution.return_value = SimpleNamespace( + status=ExecutionStatus.PENDING.value + ) + api_client.get_workflow_definition.return_value = {"id": "wf-1"} + api_client.get_file_content.return_value = b"payload" + # Runner blows up → producer falls into the except branch. + monkeypatch.setattr( + tasks_mod, + "_call_runner_service", + MagicMock(side_effect=RuntimeError("runner crashed")), + ) + + wire = tasks_mod._process_single_file_api( + api_client=api_client, + file_data={"id": "fx-bad", "file_name": "bad.pdf"}, + workflow_id="wf-1", + execution_id="exec-1", + pipeline_id=None, + use_file_history=False, + ) + + assert wire["status"] == "Failed" + assert wire["file_name"] == "bad.pdf" + assert wire["file_execution_id"] == "fx-bad" + assert wire["error"] == "runner crashed" + # Failure branch carries no result/storage payload. + assert "result_data" not in wire + assert "storage_result" not in wire + + def test_process_file_batch_api_batch_wrapper(self, monkeypatch): + """Drives the API-path batch wrapper. Catches reverts at the + ``BatchExecutionResult(...).to_dict()`` producer site + (file_processing/tasks.py around L1665). + """ + from file_processing import tasks as tasks_mod + from file_processing.worker import app as celery_app + + # Swap postgres result backend for in-memory so ``.apply()`` + # doesn't try to persist the eager task result. + original = { + "task_always_eager": celery_app.conf.task_always_eager, + "task_eager_propagates": celery_app.conf.task_eager_propagates, + "result_backend": celery_app.conf.result_backend, + } + celery_app.conf.update( + task_always_eager=True, + task_eager_propagates=True, + result_backend="cache+memory://", + ) + + # Stub per-file producer with two typed file results — one + # already-completed skip, one fresh success. 
+ skipped_wire = FileExecutionResult( + file="a.pdf", + file_execution_id="fx-a", + status=ApiDeploymentResultStatus.SUCCESS, + file_name="a.pdf", + result_data={"cached": True}, + skipped=SkipReason.ALREADY_COMPLETED, + ).to_dict() + ok_wire = FileExecutionResult( + file="b.pdf", + file_execution_id="fx-b", + status=ApiDeploymentResultStatus.SUCCESS, + file_name="b.pdf", + result_data={"extracted": "value"}, + storage_result={"stored": "s3://k"}, + ).to_dict() + per_file_outputs = iter([skipped_wire, ok_wire]) + + monkeypatch.setattr( + tasks_mod, + "_process_single_file_api", + lambda **kwargs: next(per_file_outputs), + ) + # Neutralise organisation-level side effects. + monkeypatch.setattr( + tasks_mod.StateStore, "set", lambda *a, **k: None + ) + api_client_stub = MagicMock() + api_client_stub.get_workflow_execution.return_value = SimpleNamespace( + success=True, data={"execution": {"execution_log_id": None}} + ) + monkeypatch.setattr( + tasks_mod, "create_api_client", lambda schema_name: api_client_stub + ) + # ``WorkerWorkflowExecutionService`` is imported inline inside + # the batch task; patch the lazy import path. + cache_service = MagicMock() + cache_service.return_value.cache_api_result = MagicMock() + import shared.workflow.execution.service as service_mod + + monkeypatch.setattr( + service_mod, "WorkerWorkflowExecutionService", cache_service + ) + + try: + wire = tasks_mod.process_file_batch_api.apply( + args=[ + "org-1", # schema_name + "wf-1", # workflow_id + "exec-1", # execution_id + "batch-1", # batch_id + [ + {"id": "fx-a", "file_name": "a.pdf"}, + {"id": "fx-b", "file_name": "b.pdf"}, + ], # created_files + None, # pipeline_id + None, # execution_mode + False, # use_file_history + ] + ).get() + finally: + celery_app.conf.update(original) + + # Producer must emit the typed BatchExecutionResult shape. + assert wire["total_files"] == 2 + # Legacy API-path semantic: skipped files count as successful. 
+ assert wire["successful_files"] == 2 + assert wire["failed_files"] == 0 + # Skip counter derived from SkipReason.ALREADY_COMPLETED.value + # — a typo here would silently zero the counter. + assert wire["skipped_already_completed"] == 1 + assert wire["organization_id"] == "org-1" + # ``execution_time`` is a required positional on the dataclass; + # omitting it would crash the task at the producer site. + assert "execution_time" in wire + assert wire["execution_time"] >= 0.0 + # file_results round-trips through FileExecutionResult, so + # ``storage_result`` survives to the batch boundary. + stored = [ + fr + for fr in wire["file_results"] + if fr.get("file_execution_id") == "fx-b" + ] + assert stored and stored[0]["storage_result"] == {"stored": "s3://k"} + class TestRealConsumerTolerance: """Drives the real ``aggregate_file_batch_results`` against the From 8559e39e076c2a7989eecc54077fbe8fa9d6c891 Mon Sep 17 00:00:00 2001 From: ali Date: Mon, 8 Jun 2026 15:51:29 +0530 Subject: [PATCH 4/5] UN-3513 [FIX] Symmetric None-stripping for nested file_results + deterministic callback healthcheck picker MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Greptile P2 #2 — None-stripping was asymmetric for nested FileExecutionResult objects. ``serialize_dataclass_to_dict`` only filters None at the outermost level, so a standalone ``FileExecutionResult.to_dict()`` would omit unset optional fields while ``batch.to_dict()["file_results"][i]`` would carry explicit ``"file_name": None`` etc. for the same input. A consumer doing ``"x" in result`` membership checks would behave differently depending on whether it read the standalone wire or the nested-in- batch wire — a real contract divergence. Fixed locally on ``BatchExecutionResult.to_dict()`` (not by touching the shared ``serialize_dataclass_to_dict`` infra): post-process ``wire["file_results"]`` to drop None-valued keys, mirroring the top-level strip. 
``BatchExecutionResult.from_dict`` was already tolerant via ``.get(...)`` so the round-trip stays clean. Greptile P2 #1 (``status`` constructor parameter clobbered by ``__post_init__``) is the same pathology I flagged as BLOCKER #1 in the first review round — deferred to a separate ticket with the shared-infra dataclass redesign. Test coverage: extended the existing ``test_none_valued_optional_fields_stripped_from_wire`` to also assert nested symmetry — same test method, no new method added. This keeps the pytest collection profile stable (a separate test method would perturb celery's shared task-registry insertion order during pytest collection and amplify a pre-existing flake in ``test_callback_sanity.py``). Test infra fix (bundled because it would have flaked CI on this PR's HEAD): ``test_callback_sanity.TestEagerHealthcheckRoundTrip`` selected the healthcheck task via ``endswith(".healthcheck")`` against ``eager_app.tasks``. That registry is a shared celery global with at least 5 worker modules registering ``healthcheck`` (callback, executor, file_processing, log_consumer, scheduler). ``next(...)`` returned whichever was inserted first, which depends on pytest module-collection order across the whole suite. The test would assert ``worker_type == "callback"`` and intermittently get ``"executor"`` or ``"file_processing"`` instead — empirically a ~10% flake rate on this branch's HEAD, climbing to ~90% with any test-collection perturbation. Replaced with an exact-name lookup (``name == "callback.worker.healthcheck"``); 30/30 green across deterministic + randomised probes. 
Co-Authored-By: Claude Opus 4.7 (1M context) --- .../core/src/unstract/core/worker_models.py | 19 +++++++++- workers/tests/test_callback_sanity.py | 12 ++++-- workers/tests/test_chord_callback_boundary.py | 38 +++++++++++++++---- 3 files changed, 57 insertions(+), 12 deletions(-) diff --git a/unstract/core/src/unstract/core/worker_models.py b/unstract/core/src/unstract/core/worker_models.py index f5c5523d8e..c1850051cc 100644 --- a/unstract/core/src/unstract/core/worker_models.py +++ b/unstract/core/src/unstract/core/worker_models.py @@ -403,8 +403,23 @@ def success_rate(self) -> float: return (self.successful_files / self.total_files) * 100 def to_dict(self) -> dict[str, Any]: - """Convert to dictionary for API response.""" - return serialize_dataclass_to_dict(self) + """Convert to dictionary for API response. + + Strips ``None`` from nested ``file_results`` so a per-file dict + read from ``batch["file_results"][i]`` has the same wire shape + as a standalone ``FileExecutionResult.to_dict()`` — both omit + unset optional fields. ``serialize_dataclass_to_dict`` only + strips ``None`` at the outer level, so without this fixup the + per-file dicts nested in the batch wire would carry explicit + ``"file_name": None`` etc. entries and break consumers that + rely on membership checks. + """ + wire = serialize_dataclass_to_dict(self) + wire["file_results"] = [ + {k: v for k, v in fr.items() if v is not None} + for fr in wire.get("file_results", []) + ] + return wire @classmethod def from_dict(cls, data: dict[str, Any]) -> "BatchExecutionResult": diff --git a/workers/tests/test_callback_sanity.py b/workers/tests/test_callback_sanity.py index 5d2b11fe75..0efd7519f2 100644 --- a/workers/tests/test_callback_sanity.py +++ b/workers/tests/test_callback_sanity.py @@ -142,10 +142,16 @@ class TestEagerHealthcheckRoundTrip: """ def test_eager_healthcheck_round_trip(self, eager_app): - # Find the healthcheck task; its module-qualified name varies. 
+ # Find the callback worker's healthcheck task specifically. + # ``eager_app.tasks`` is a SHARED celery registry: every worker + # module that registers a ``healthcheck`` (executor, + # file_processing, log_consumer, ...) lands in the same dict. + # A bare ``endswith(".healthcheck")`` match returns whichever + # was inserted first, which depends on pytest collection / + # import order and produces flaky cross-worker results. healthcheck = next( t for name, t in eager_app.tasks.items() - if name.endswith(".healthcheck") or name == "healthcheck" + if name == "callback.worker.healthcheck" ) result = healthcheck.apply() @@ -163,7 +169,7 @@ def test_healthcheck_result_is_json_serializable(self, eager_app): healthcheck = next( t for name, t in eager_app.tasks.items() - if name.endswith(".healthcheck") or name == "healthcheck" + if name == "callback.worker.healthcheck" ) result = healthcheck.apply() diff --git a/workers/tests/test_chord_callback_boundary.py b/workers/tests/test_chord_callback_boundary.py index e9d5d05417..39fed5347e 100644 --- a/workers/tests/test_chord_callback_boundary.py +++ b/workers/tests/test_chord_callback_boundary.py @@ -185,18 +185,23 @@ def test_wire_is_json_safe(self): assert json.loads(json.dumps(wire)) == wire def test_none_valued_optional_fields_stripped_from_wire(self): - """``serialize_dataclass_to_dict`` drops ``None`` values. - - Documents the behaviour so consumers using membership checks - (``"x" in wire``) instead of ``.get(..., default)`` know what - to expect. Aliases default to ``None`` and only appear on the - wire when explicitly populated. + """``None`` defaults are stripped both standalone AND when + nested inside ``BatchExecutionResult.file_results``. + + ``serialize_dataclass_to_dict`` only filters ``None`` at the + outermost level; ``BatchExecutionResult.to_dict`` adds a + secondary filter so the per-file shape stays symmetric. 
Without + that fixup a consumer doing ``"x" in result`` membership checks + would behave differently for a standalone wire vs one read out + of ``batch["file_results"][i]``. """ minimal = FileExecutionResult( file="a.pdf", file_execution_id="fx", status=ApiDeploymentResultStatus.SUCCESS, ) + + # ---- Standalone wire ---- wire = minimal.to_dict() # Required fields and zero-valued numerics survive. assert wire["file"] == "a.pdf" @@ -204,9 +209,28 @@ def test_none_valued_optional_fields_stripped_from_wire(self): assert wire["processing_time"] == pytest.approx(0.0) assert wire["file_size"] == 0 # None defaults are dropped — not in the wire dict at all. - for absent in ("error", "result", "metadata", "file_name", "result_data", "skipped"): + absent_keys = ( + "error", "result", "metadata", "file_name", + "result_data", "skipped", "storage_result", + ) + for absent in absent_keys: assert absent not in wire, f"expected {absent!r} to be stripped when None" + # ---- Same shape when nested inside a batch ---- + batch_wire = BatchExecutionResult( + total_files=1, + successful_files=1, + failed_files=0, + execution_time=0.0, + file_results=[minimal], + ).to_dict() + nested_wire = batch_wire["file_results"][0] + for absent in absent_keys: + assert absent not in nested_wire, ( + f"{absent!r} leaked into nested file_results wire" + ) + assert sorted(wire.keys()) == sorted(nested_wire.keys()) + class TestProducerBinding: """Drives the real producer functions in ``file_processing.tasks``. From fb3500ef6e18f208584831d5b48e895344b1b4ae Mon Sep 17 00:00:00 2001 From: ali Date: Tue, 9 Jun 2026 11:55:47 +0530 Subject: [PATCH 5/5] UN-3513 [FIX] Address vishnuszipstack review (7 real fixes + 1 docstring nit) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Seven of Vishnu's PR review findings addressed, all backward-compat with main-branch consumers. 
The three [Important] design-redesign findings (#1 status __post_init__, #2 alias-pair invariant, #3 to_api_dict/to_json dead code) are deferred to a follow-up shared-infra dataclass ticket because they would either fire warning noise on existing call sites (``worker_base.py:211/222``, ``worker_patterns.py:241`` pass wrong-enum status) or change the wire/cache contract — neither acceptable mid-flight while keeping zero regression on main. Changes in this commit are either: * Pure additive (test methods, docstrings, observability) * Or provably equivalent wire output (the typed-count refactor) So a rolling deploy where old workers and new workers run concurrently sees identical wire shapes and identical behaviour for all current valid data; the only observable differences are log content (better context on the existing warning) and the presence of a new opt-in classmethod that nothing currently calls. * **Vishnu #8 [Suggestion]** — ``SkipReason`` docstring claimed "StrEnum semantics" but the class is ``(str, Enum)``, not ``enum.StrEnum``. The two differ on ``__str__``. Rewrote the docstring to describe the actual behaviour. * **Vishnu #4a [Important — log context]** — ``_parse_skipped`` now accepts an optional ``file_execution_id`` kwarg that ``from_dict`` threads through. The warning emitted for unknown wire values now carries the file identifier, so a real rolling-deploy incident is debuggable rather than a context-free warning. Optional kwarg with default — any existing caller passing one positional arg still works. * **Vishnu #9 [Suggestion]** — added ``BatchExecutionResult.from_file_results(...)`` classmethod that derives counters from typed file results. Purely additive: no existing caller uses it; the constructor signature is unchanged so producers that need their own counter semantics keep working. 
* **Vishnu #11 [Suggestion]** — ``process_file_batch_api`` was computing ``skipped_already_completed`` by string-matching the wire dicts AFTER already calling ``from_dict`` on them. Refactored to count from the typed list (single ``from_dict`` pass, enum compare). Provably equivalent for all current wire data. * **Vishnu #4 [Important — test gap]** — added ``test_from_dict_unknown_skipped_is_lenient`` covering the one documented crash-prevention path. A regression to bare ``SkipReason(raw)`` would have re-introduced the rolling-deploy crash and kept every other test green. * **Vishnu #5 [Important — failure-aggregation gap]** — added ``test_process_file_batch_api_batch_wrapper_failure_aggregation`` that drives one success + one failure through the batch wrapper. The existing success-only test never exercised ``failed_files += 1``. * **Vishnu #6 [Important — populated round-trip gap]** — added ``test_round_trip_with_populated_file_results`` and ``test_from_file_results_derives_counters``. The existing ``BatchExecutionResult`` round-trip test used ``file_results=[]``, so the list-comprehension in ``from_dict`` that rebuilds nested ``FileExecutionResult`` objects was never executed with a populated list. * **Vishnu #13 [Suggestion]** — replaced hardcoded line reference in test docstring with a symbol reference. 
Deferred to follow-up shared-infra dataclass-redesign ticket: * #1 ``__post_init__`` status clobber — would emit warning noise on every existing wrong-enum call site * #2 alias-pair invariant — back-fill via __post_init__ would change the wire shape (file_name no longer None → no longer stripped at the top level) * #3 ``to_api_dict``/``to_json`` dead code — looks like a public SDK surface; changing the body could surprise external consumers * #7 recursive ``None``-strip in ``serialize_value`` — touches every dataclass in the codebase * #10 ``Any`` typing tightening — low value, mypy tightening could trip downstream * #12 producer redundant kwargs — depends on #2's reconciliation Tests: workers chord-callback boundary suite 21 -> 25; full workers suite 622 -> 627 (no new failures; 6 pre-existing baseline unchanged). Five deterministic-order runs of the full suite returned exactly 627 passed / 6 pre-existing failed — zero flakiness from this change. Co-Authored-By: Claude Opus 4.7 (1M context) --- .../core/src/unstract/core/worker_models.py | 82 ++++++- workers/file_processing/tasks.py | 16 +- workers/tests/test_chord_callback_boundary.py | 225 +++++++++++++++++- 3 files changed, 310 insertions(+), 13 deletions(-) diff --git a/unstract/core/src/unstract/core/worker_models.py b/unstract/core/src/unstract/core/worker_models.py index c1850051cc..4f3e48d99b 100644 --- a/unstract/core/src/unstract/core/worker_models.py +++ b/unstract/core/src/unstract/core/worker_models.py @@ -117,9 +117,13 @@ class SkipReason(str, Enum): Closed vocabulary so typos at producer call sites fail at construction time rather than silently producing an - unrecognisable value on the wire. StrEnum semantics — members - serialise to their string value and compare equal to the - underlying string. + unrecognisable value on the wire. 
Subclasses ``str`` so members + compare equal to their underlying string value; wire serialisation + extracts ``.value`` via ``serialize_dataclass_to_dict``'s + ``isinstance(value, Enum)`` branch (this is *not* + :class:`enum.StrEnum`, which would also change ``__str__`` to + return the value — that distinction matters if any caller ever + does ``str(SkipReason.X)`` rather than ``SkipReason.X.value``). Values mirror the batch-level skip counters on :class:`BatchExecutionResult` (``skipped_already_completed`` / @@ -320,7 +324,9 @@ def to_json(self) -> dict[str, Any]: return self.to_api_dict() @staticmethod - def _parse_skipped(raw: Any) -> "SkipReason | None": + def _parse_skipped( + raw: Any, file_execution_id: str | None = None + ) -> "SkipReason | None": """Lenient ``SkipReason`` parser for the consumer side. Producer call sites are typed (constructor takes the enum, typos @@ -329,13 +335,23 @@ def _parse_skipped(raw: Any) -> "SkipReason | None": rolling-deploy must not crash the entire batch task on a value the consumer doesn't recognise. Standard "strict on emit, lenient on receive" posture. + + ``file_execution_id`` is optional but threaded through from + ``from_dict`` so the warning log carries file context — without + it, a debugger seeing the warning has no way to find which file + triggered the unknown value. 
""" if not raw: return None try: return SkipReason(raw) except ValueError: - logger.warning("Unknown SkipReason on wire: %r; treating as None", raw) + logger.warning( + "Unknown SkipReason on wire: %r (file_execution_id=%s); " + "treating as None", + raw, + file_execution_id, + ) return None @classmethod @@ -347,9 +363,10 @@ def from_dict(cls, data: dict[str, Any]) -> "FileExecutionResult": if data.get("error") else ApiDeploymentResultStatus.SUCCESS ) + file_execution_id = data.get("file_execution_id") return cls( file=data.get("file", ""), - file_execution_id=data.get("file_execution_id"), + file_execution_id=file_execution_id, status=status, error=data.get("error"), result=data.get("result"), @@ -358,7 +375,9 @@ def from_dict(cls, data: dict[str, Any]) -> "FileExecutionResult": file_size=data.get("file_size", 0), file_name=data.get("file_name"), result_data=data.get("result_data"), - skipped=cls._parse_skipped(data.get("skipped")), + skipped=cls._parse_skipped( + data.get("skipped"), file_execution_id=file_execution_id + ), storage_result=data.get("storage_result"), ) @@ -442,6 +461,55 @@ def from_dict(cls, data: dict[str, Any]) -> "BatchExecutionResult": organization_id=data.get("organization_id"), ) + @classmethod + def from_file_results( + cls, + file_results: list[FileExecutionResult], + *, + execution_time: float, + organization_id: str | None = None, + batch_id: str | None = None, + errors: list[str] | None = None, + ) -> "BatchExecutionResult": + """Build a batch result by deriving counters from typed file results. + + Counts ``successful_files`` / ``failed_files`` / ``skipped_*`` + from the file results themselves rather than letting the caller + pass them as parameters — removes the class of bug where a + producer's hand-rolled counters drift from the underlying + ``file_results`` (e.g. a string-match on the wire that misses a + new ``SkipReason`` member). 
Existing call sites that need to + keep their own counter semantics can keep using the + constructor directly; this is purely additive. + + Note that ``successful_files`` here matches + ``FileExecutionResult.is_successful()`` semantics — a file with + ``skipped`` set is *also* considered successful (consistent with + the API-path producer's "skipped files count as successful" + rule). Callers that need a different split should compute it + themselves and use the constructor. + """ + successful = sum(1 for fr in file_results if fr.is_successful()) + failed = sum(1 for fr in file_results if not fr.is_successful()) + skipped_already_completed = sum( + 1 for fr in file_results if fr.skipped == SkipReason.ALREADY_COMPLETED + ) + skipped_active_duplicate = sum( + 1 for fr in file_results if fr.skipped == SkipReason.ACTIVE_DUPLICATE + ) + return cls( + total_files=len(file_results), + successful_files=successful, + failed_files=failed, + execution_time=execution_time, + file_results=file_results, + batch_id=batch_id, + errors=errors or [], + skipped_already_completed=skipped_already_completed, + skipped_active_duplicate=skipped_active_duplicate, + organization_id=organization_id, + ) + def add_file_result(self, file_result: FileExecutionResult): """Add a file execution result to the batch.""" self.file_results.append(file_result) diff --git a/workers/file_processing/tasks.py b/workers/file_processing/tasks.py index 8b46680f4b..3aa93296e7 100644 --- a/workers/file_processing/tasks.py +++ b/workers/file_processing/tasks.py @@ -1663,16 +1663,24 @@ def process_file_batch_api( # semantic). Separating the skipped count from the successful # count is deferred — would change consumer-visible counters # and is tracked separately.
+ # + # Build typed file results once and derive the skip counter + # from the typed enum rather than re-string-matching against + # the wire — a serialisation/vocab drift in + # ``SkipReason.ALREADY_COMPLETED.value`` would otherwise + # silently zero the count. Equivalent to the previous + # string-compare for current data. + typed_file_results = [FileExecutionResult.from_dict(r) for r in file_results] batch_result = BatchExecutionResult( - total_files=len(file_results), + total_files=len(typed_file_results), successful_files=successful_files, failed_files=failed_files, execution_time=time.time() - batch_start_time, - file_results=[FileExecutionResult.from_dict(r) for r in file_results], + file_results=typed_file_results, skipped_already_completed=sum( 1 - for r in file_results - if r.get("skipped") == SkipReason.ALREADY_COMPLETED.value + for fr in typed_file_results + if fr.skipped == SkipReason.ALREADY_COMPLETED ), organization_id=schema_name, ).to_dict() diff --git a/workers/tests/test_chord_callback_boundary.py b/workers/tests/test_chord_callback_boundary.py index 39fed5347e..524708a3e5 100644 --- a/workers/tests/test_chord_callback_boundary.py +++ b/workers/tests/test_chord_callback_boundary.py @@ -81,10 +81,102 @@ def test_round_trip_preserves_all_fields(self): ) assert round_tripped.organization_id == original.organization_id + def test_round_trip_with_populated_file_results(self): + """The existing round-trip test uses ``file_results=[]``, so the + list-comprehension in ``BatchExecutionResult.from_dict`` that + rebuilds nested ``FileExecutionResult`` objects is never + exercised. A regression that stored raw dicts instead would + otherwise keep every test green. 
+ """ + fr_ok = FileExecutionResult( + file="a.pdf", + file_execution_id="fx-a", + status=ApiDeploymentResultStatus.SUCCESS, + file_name="a.pdf", + storage_result={"stored_at": "s3://k1"}, + ) + fr_failed = FileExecutionResult( + file="b.pdf", + file_execution_id="fx-b", + status=ApiDeploymentResultStatus.FAILED, + file_name="b.pdf", + error="boom", + ) + original = BatchExecutionResult( + total_files=2, + successful_files=1, + failed_files=1, + execution_time=1.0, + file_results=[fr_ok, fr_failed], + organization_id="org-1", + ) + round_tripped = BatchExecutionResult.from_dict(original.to_dict()) + + assert len(round_tripped.file_results) == 2 + # Reconstruction must yield typed objects, not bare dicts. + for fr in round_tripped.file_results: + assert isinstance(fr, FileExecutionResult) + # And every cross-the-wire field survives. + rt_ok = next( + fr for fr in round_tripped.file_results if fr.file_execution_id == "fx-a" + ) + rt_failed = next( + fr for fr in round_tripped.file_results if fr.file_execution_id == "fx-b" + ) + assert rt_ok.storage_result == {"stored_at": "s3://k1"} + assert rt_ok.status == ApiDeploymentResultStatus.SUCCESS + assert rt_failed.error == "boom" + assert rt_failed.status == ApiDeploymentResultStatus.FAILED + def test_wire_is_json_safe(self): wire = self._make().to_dict() assert json.loads(json.dumps(wire)) == wire + def test_from_file_results_derives_counters(self): + """``from_file_results`` derives counters from typed input so a + new ``SkipReason`` member can't silently zero a counter through + wire-vocab drift (the failure mode that the hand-rolled + ``string == SkipReason.X.value`` in ``tasks.py`` would have + and that the typed refactor in this PR eliminates). 
+ """ + results = [ + FileExecutionResult( + file="a.pdf", + file_execution_id="fx-a", + status=ApiDeploymentResultStatus.SUCCESS, + ), + FileExecutionResult( + file="b.pdf", + file_execution_id="fx-b", + status=ApiDeploymentResultStatus.SUCCESS, + skipped=SkipReason.ALREADY_COMPLETED, + ), + FileExecutionResult( + file="c.pdf", + file_execution_id="fx-c", + status=ApiDeploymentResultStatus.SUCCESS, + skipped=SkipReason.ACTIVE_DUPLICATE, + ), + FileExecutionResult( + file="d.pdf", + file_execution_id="fx-d", + status=ApiDeploymentResultStatus.FAILED, + error="boom", + ), + ] + batch = BatchExecutionResult.from_file_results( + results, execution_time=2.0, organization_id="org-1" + ) + assert batch.total_files == 4 + # SUCCESS-status results are counted as successful regardless of + # ``skipped`` — matches the API-path "skipped counts as success" + # rule documented in ``process_file_batch_api``. + assert batch.successful_files == 3 + assert batch.failed_files == 1 + assert batch.skipped_already_completed == 1 + assert batch.skipped_active_duplicate == 1 + assert batch.organization_id == "org-1" + class TestFileExecutionResultWireShape: """API path (``process_file_batch_api`` returns this per-file shape).""" @@ -161,6 +253,37 @@ def test_round_trip_preserves_all_aliases(self): assert round_tripped.skipped == original.skipped assert round_tripped.status == original.status + def test_from_dict_unknown_skipped_is_lenient(self, caplog): + """``_parse_skipped`` is the one documented crash-prevention + path for rolling deploys (newer producer emits a future + ``SkipReason``, older consumer receives it). Without this test + a regression to bare ``SkipReason(raw)`` would re-introduce the + crash and every other test would stay green. 
+ """ + import logging + + with caplog.at_level(logging.WARNING, logger="unstract.core.worker_models"): + result = FileExecutionResult.from_dict( + { + "file": "x.pdf", + "file_execution_id": "fx-future", + "skipped": "teleported_to_2030", + } + ) + + assert result.skipped is None, ( + "unknown skipped value must downgrade to None, not crash" + ) + # Log must include the unknown raw value and the file + # identifier — a context-free warning would be useless for + # debugging a real rolling-deploy incident. + assert any( + "Unknown SkipReason" in r.message + and "teleported_to_2030" in r.message + and "fx-future" in r.message + for r in caplog.records + ) + def test_active_duplicate_skip_reason_round_trips(self): # ``ACTIVE_DUPLICATE`` mirrors the batch-level # ``skipped_active_duplicate`` counter; no producer emits it @@ -394,8 +517,10 @@ def test_process_single_file_api_failure_branch(self, monkeypatch): def test_process_file_batch_api_batch_wrapper(self, monkeypatch): """Drives the API-path batch wrapper. Catches reverts at the - ``BatchExecutionResult(...).to_dict()`` producer site - (file_processing/tasks.py around L1665). + ``BatchExecutionResult(...).to_dict()`` producer site in + ``process_file_batch_api`` (symbol ref rather than a line + number so this docstring doesn't rot the moment ``tasks.py`` + is edited above the producer). """ from file_processing import tasks as tasks_mod from file_processing.worker import app as celery_app @@ -500,6 +625,102 @@ def test_process_file_batch_api_batch_wrapper(self, monkeypatch): ] assert stored and stored[0]["storage_result"] == {"stored": "s3://k"} + def test_process_file_batch_api_batch_wrapper_failure_aggregation( + self, monkeypatch + ): + """Drive the failure-aggregation path of the batch wrapper — + the success-only variant above never exercises ``failed_files += + 1`` (``tasks.py`` ~L1657) or the canonical ``"Failed"`` vocab on + the nested per-file wire. 
A revert to the legacy lowercase + ``"failed"`` for the failure branch would otherwise stay green. + """ + from file_processing import tasks as tasks_mod + from file_processing.worker import app as celery_app + + original = { + "task_always_eager": celery_app.conf.task_always_eager, + "task_eager_propagates": celery_app.conf.task_eager_propagates, + "result_backend": celery_app.conf.result_backend, + } + celery_app.conf.update( + task_always_eager=True, + task_eager_propagates=True, + result_backend="cache+memory://", + ) + + # One success + one failure — the failed leg must arrive at the + # batch boundary as ``status="Failed"`` carrying ``error``. + ok_wire = FileExecutionResult( + file="ok.pdf", + file_execution_id="fx-ok", + status=ApiDeploymentResultStatus.SUCCESS, + file_name="ok.pdf", + result_data={"extracted": "value"}, + ).to_dict() + bad_wire = FileExecutionResult( + file="bad.pdf", + file_execution_id="fx-bad", + status=ApiDeploymentResultStatus.FAILED, + file_name="bad.pdf", + error="boom", + ).to_dict() + per_file_outputs = iter([ok_wire, bad_wire]) + + monkeypatch.setattr( + tasks_mod, + "_process_single_file_api", + lambda **kwargs: next(per_file_outputs), + ) + monkeypatch.setattr(tasks_mod.StateStore, "set", lambda *a, **k: None) + api_client_stub = MagicMock() + api_client_stub.get_workflow_execution.return_value = SimpleNamespace( + success=True, data={"execution": {"execution_log_id": None}} + ) + monkeypatch.setattr( + tasks_mod, "create_api_client", lambda schema_name: api_client_stub + ) + cache_service = MagicMock() + cache_service.return_value.cache_api_result = MagicMock() + import shared.workflow.execution.service as service_mod + + monkeypatch.setattr( + service_mod, "WorkerWorkflowExecutionService", cache_service + ) + + try: + wire = tasks_mod.process_file_batch_api.apply( + args=[ + "org-1", + "wf-1", + "exec-1", + "batch-1", + [ + {"id": "fx-ok", "file_name": "ok.pdf"}, + {"id": "fx-bad", "file_name": "bad.pdf"}, + ], + None, + 
None, + False, + ] + ).get() + finally: + celery_app.conf.update(original) + + # Failure-aggregation counters at the batch level. + assert wire["total_files"] == 2 + assert wire["successful_files"] == 1 + assert wire["failed_files"] == 1 + # The failed per-file result must survive the round-trip and + # arrive with the canonical capitalised vocab. + failed = [ + fr + for fr in wire["file_results"] + if fr.get("file_execution_id") == "fx-bad" + ] + assert failed, "failed per-file result missing from batch wire" + assert failed[0]["status"] == "Failed" + assert failed[0]["error"] == "boom" + class TestRealConsumerTolerance: """Drives the real ``aggregate_file_batch_results`` against the