diff --git a/unstract/core/src/unstract/core/__init__.py b/unstract/core/src/unstract/core/__init__.py index deb84687f2..255673b939 100644 --- a/unstract/core/src/unstract/core/__init__.py +++ b/unstract/core/src/unstract/core/__init__.py @@ -43,6 +43,7 @@ PipelineStatus, PipelineStatusUpdateRequest, QueueName, + SkipReason, StatusMappings, TaskError, TaskExecutionContext, @@ -76,6 +77,7 @@ "WebhookResult", "FileExecutionResult", "BatchExecutionResult", + "SkipReason", "CallbackExecutionData", "WorkflowExecutionUpdateRequest", "PipelineStatusUpdateRequest", diff --git a/unstract/core/src/unstract/core/worker_models.py b/unstract/core/src/unstract/core/worker_models.py index 17da3557a2..4f3e48d99b 100644 --- a/unstract/core/src/unstract/core/worker_models.py +++ b/unstract/core/src/unstract/core/worker_models.py @@ -112,6 +112,32 @@ class ApiDeploymentResultStatus(str, Enum): FAILED = "Failed" +class SkipReason(str, Enum): + """Reasons a per-file execution can be skipped. + + Closed vocabulary so typos at producer call sites fail at + construction time rather than silently producing an + unrecognisable value on the wire. Subclasses ``str`` so members + compare equal to their underlying string value; wire serialisation + extracts ``.value`` via ``serialize_dataclass_to_dict``'s + ``isinstance(value, Enum)`` branch (this is *not* + :class:`enum.StrEnum`, which would also change ``__str__`` to + return the value — that distinction matters if any caller ever + does ``str(SkipReason.X)`` rather than ``SkipReason.X.value``). + + Values mirror the batch-level skip counters on + :class:`BatchExecutionResult` (``skipped_already_completed`` / + ``skipped_active_duplicate``) so per-file and batch-level + vocabulary stay in lockstep. + """ + + # WorkflowFileExecution.status was already COMPLETED — another + # worker finished this file before we got to it. + ALREADY_COMPLETED = "already_completed" + # Another worker is currently processing the same file content. + ACTIVE_DUPLICATE = "active_duplicate" + + class NotificationMethod(str, Enum): """Notification delivery methods.""" @@ -239,7 +265,14 @@ def from_dict(cls, data: dict[str, Any]) -> "FinalOutputResult": @dataclass class FileExecutionResult: - """Structured result for file execution tasks.""" + """Structured result for file execution tasks. + + The API-deployment chord path historically returned dicts with + ``file_name`` / ``result_data`` / ``skipped`` rather than the + canonical ``file`` / ``result`` shape. The optional aliases below + let producers populate either vocabulary without the consumer + needing to know which path produced the result. + """ file: str file_execution_id: str | None @@ -249,6 +282,21 @@ class FileExecutionResult: metadata: dict[str, Any] | None = None processing_time: float = 0.0 file_size: int = 0 + # TODO(UN-3516): remove these three legacy API-path aliases once + # consumers have migrated to the canonical ``file`` / ``result`` + # fields and to a typed skip-reason. Kept additive in the meantime + # so the chord-callback producer can preserve its current dict + # vocabulary without consumer breakage. + file_name: str | None = None + result_data: Any | None = None + skipped: SkipReason | None = None + # Storage backend's acknowledgement for a successfully-processed + # file. Forwarded verbatim from + # ``api_client.store_file_execution_result``. No in-tree consumer + # reads it today, but external integrations may inspect it on the + # wire — preserving it as a typed field rather than dict-spreading + # so it survives the BatchExecutionResult round-trip. + storage_result: Any | None = None def __post_init__(self) -> None: if self.error: @@ -275,6 +323,37 @@ def to_json(self) -> dict[str, Any]: """Convert to JSON-serializable dict for backward compatibility.""" return self.to_api_dict() + @staticmethod + def _parse_skipped( + raw: Any, file_execution_id: str | None = None + ) -> "SkipReason | None": + """Lenient ``SkipReason`` parser for the consumer side. + + Producer call sites are typed (constructor takes the enum, typos + fail at construction). On the consumer side we accept an unknown + wire value gracefully — a newer-producer / older-consumer + rolling-deploy must not crash the entire batch task on a value + the consumer doesn't recognise. Standard "strict on emit, + lenient on receive" posture. + + ``file_execution_id`` is optional but threaded through from + ``from_dict`` so the warning log carries file context — without + it, a debugger seeing the warning has no way to find which file + triggered the unknown value. + """ + if not raw: + return None + try: + return SkipReason(raw) + except ValueError: + logger.warning( + "Unknown SkipReason on wire: %r (file_execution_id=%s); " + "treating as None", + raw, + file_execution_id, + ) + return None + @classmethod def from_dict(cls, data: dict[str, Any]) -> "FileExecutionResult": """Create from dictionary (e.g., task result).""" @@ -284,15 +363,22 @@ def from_dict(cls, data: dict[str, Any]) -> "FileExecutionResult": if data.get("error") else ApiDeploymentResultStatus.SUCCESS ) + file_execution_id = data.get("file_execution_id") return cls( file=data.get("file", ""), - file_execution_id=data.get("file_execution_id"), + file_execution_id=file_execution_id, status=status, error=data.get("error"), result=data.get("result"), metadata=data.get("metadata"), processing_time=data.get("processing_time", 0.0), file_size=data.get("file_size", 0), + file_name=data.get("file_name"), + result_data=data.get("result_data"), + skipped=cls._parse_skipped( + data.get("skipped"), file_execution_id=file_execution_id + ), + storage_result=data.get("storage_result"), ) def is_successful(self) -> bool: @@ -306,7 +392,13 @@ def has_error(self) -> bool: @dataclass class BatchExecutionResult: - """Structured result for batch execution tasks.""" + """Structured result for batch execution tasks. + + The general workflow chord path additionally tracks skipped-file + sub-counts and the org context on the wire. The optional fields + below capture that vocabulary without changing existing field + semantics — strictly additive. + """ total_files: int successful_files: int @@ -315,6 +407,12 @@ class BatchExecutionResult: file_results: list[FileExecutionResult] = field(default_factory=list) batch_id: str | None = None errors: list[str] = field(default_factory=list) + # Optional general-path fields. Producers populate; consumers that + # don't know about them are unaffected (existing reads use ``.get()`` + # with defaults). + skipped_already_completed: int = 0 + skipped_active_duplicate: int = 0 + organization_id: str | None = None @property def success_rate(self) -> float: @@ -324,8 +422,23 @@ def success_rate(self) -> float: return (self.successful_files / self.total_files) * 100 def to_dict(self) -> dict[str, Any]: - """Convert to dictionary for API response.""" - return serialize_dataclass_to_dict(self) + """Convert to dictionary for API response. + + Strips ``None`` from nested ``file_results`` so a per-file dict + read from ``batch["file_results"][i]`` has the same wire shape + as a standalone ``FileExecutionResult.to_dict()`` — both omit + unset optional fields. ``serialize_dataclass_to_dict`` only + strips ``None`` at the outer level, so without this fixup the + per-file dicts nested in the batch wire would carry explicit + ``"file_name": None`` etc. entries and break consumers that + rely on membership checks. + """ + wire = serialize_dataclass_to_dict(self) + wire["file_results"] = [ + {k: v for k, v in fr.items() if v is not None} + for fr in wire.get("file_results", []) + ] + return wire @classmethod def from_dict(cls, data: dict[str, Any]) -> "BatchExecutionResult": @@ -343,6 +456,58 @@ def from_dict(cls, data: dict[str, Any]) -> "BatchExecutionResult": file_results=file_results, batch_id=data.get("batch_id"), errors=data.get("errors", []), + skipped_already_completed=data.get("skipped_already_completed", 0), + skipped_active_duplicate=data.get("skipped_active_duplicate", 0), + organization_id=data.get("organization_id"), + ) + + @classmethod + def from_file_results( + cls, + file_results: list[FileExecutionResult], + *, + execution_time: float, + organization_id: str | None = None, + batch_id: str | None = None, + errors: list[str] | None = None, + ) -> "BatchExecutionResult": + """Build a batch result by deriving counters from typed file results. + + Counts ``successful_files`` / ``failed_files`` / ``skipped_*`` + from the file results themselves rather than letting the caller + pass them as parameters — removes the class of bug where a + producer's hand-rolled counters drift from the underlying + ``file_results`` (e.g. a string-match on the wire that misses a + new ``SkipReason`` member). Existing call sites that need to + keep their own counter semantics can keep using the + constructor directly; this is purely additive. + + Note that ``successful_files`` here matches + ``BatchExecutionResult.is_successful()`` semantics — a file with + ``skipped`` set is *also* considered successful (consistent with + the API-path producer's "skipped files count as successful" + rule). Callers that need a different split should compute it + themselves and use the constructor. + """ + successful = sum(1 for fr in file_results if fr.is_successful()) + failed = sum(1 for fr in file_results if not fr.is_successful()) + skipped_already_completed = sum( + 1 for fr in file_results if fr.skipped == SkipReason.ALREADY_COMPLETED + ) + skipped_active_duplicate = sum( + 1 for fr in file_results if fr.skipped == SkipReason.ACTIVE_DUPLICATE + ) + return cls( + total_files=len(file_results), + successful_files=successful, + failed_files=failed, + execution_time=execution_time, + file_results=file_results, + batch_id=batch_id, + errors=errors or [], + skipped_already_completed=skipped_already_completed, + skipped_active_duplicate=skipped_active_duplicate, + organization_id=organization_id, ) def add_file_result(self, file_result: FileExecutionResult): diff --git a/workers/file_processing/tasks.py b/workers/file_processing/tasks.py index 614266376a..3aa93296e7 100644 --- a/workers/file_processing/tasks.py +++ b/workers/file_processing/tasks.py @@ -50,7 +50,13 @@ PreCreatedFileData, WorkerFileData, ) -from unstract.core.worker_models import FileProcessingResult +from unstract.core.worker_models import ( + ApiDeploymentResultStatus, + BatchExecutionResult, + FileExecutionResult, + FileProcessingResult, + SkipReason, +) logger = WorkerLogger.get_logger(__name__) @@ -896,21 +902,19 @@ def _compile_batch_result(context: WorkflowContextData) -> dict[str, Any]: except Exception as cleanup_error: logger.warning(f"Failed to cleanup StateStore context: {cleanup_error}") - # Return the final result matching Django backend format - # Note: Only active duplicates count as failures; already-completed do not - return { - "successful_files": result.successful_files, - "failed_files": result.failed_files, # Includes active duplicates (user error) - "total_files": result.successful_files - + result.failed_files - + len(skipped_already_completed), # Include all files in batch - "skipped_already_completed": len(skipped_already_completed), # Not a failure - "skipped_active_duplicate": len( - skipped_active_duplicate - ), # IS a failure (counted above) - "execution_time": result.execution_time, - "organization_id": context.organization_context.organization_id, - } + # Return the final result matching Django backend format. + # Note: only active duplicates count as failures; already-completed do not. + return BatchExecutionResult( + total_files=( + result.successful_files + result.failed_files + len(skipped_already_completed) + ), + successful_files=result.successful_files, + failed_files=result.failed_files, + execution_time=result.execution_time, + skipped_already_completed=len(skipped_already_completed), + skipped_active_duplicate=len(skipped_active_duplicate), + organization_id=context.organization_context.organization_id, + ).to_dict() # HELPER FUNCTIONS (originally part of the massive process_file_batch function) @@ -1560,6 +1564,7 @@ def process_file_batch_api( f"Processing API file batch {batch_id} with {len(created_files)} files" ) + batch_start_time = time.time() try: # Set organization context exactly like Django backend StateStore.set(Account.ORGANIZATION_ID, schema_name) @@ -1652,11 +1657,33 @@ def process_file_batch_api( else: successful_files += 1 - # Return result matching Django FileBatchResult structure - batch_result = { - "successful_files": successful_files, - "failed_files": failed_files, - } + # Return result matching Django FileBatchResult structure. + # Note: ``successful_files`` here counts every non-error file, + # INCLUDING skipped-already-completed ones (legacy API-path + # semantic). Separating the skipped count from the successful + # count is deferred — would change consumer-visible counters + # and is tracked separately. + # + # Build typed file results once and derive the skip counter + # from the typed enum rather than re-string-matching against + # the wire — a serialisation/vocab drift in + # ``SkipReason.ALREADY_COMPLETED.value`` would otherwise + # silently zero the count. Equivalent to the previous + # string-compare for current data. + typed_file_results = [FileExecutionResult.from_dict(r) for r in file_results] + batch_result = BatchExecutionResult( + total_files=len(typed_file_results), + successful_files=successful_files, + failed_files=failed_files, + execution_time=time.time() - batch_start_time, + file_results=typed_file_results, + skipped_already_completed=sum( + 1 + for fr in typed_file_results + if fr.skipped == SkipReason.ALREADY_COMPLETED + ), + organization_id=schema_name, + ).to_dict() logger.info(f"Successfully processed API file batch {batch_id}") return batch_result @@ -1703,15 +1730,16 @@ def _process_single_file_api( f"Skipping processing for execution_id: {execution_id}, " f"file_execution_id: {file_execution_id}" ) - return { - "file_execution_id": file_execution_id, - "file_name": file_name, - "status": "completed", - "processing_time": 0.0, - "result_data": getattr(workflow_file_execution, "result", None), - "metadata": getattr(workflow_file_execution, "metadata", None) or {}, - "skipped": "already_completed", - } + return FileExecutionResult( + file=file_name, + file_execution_id=file_execution_id, + status=ApiDeploymentResultStatus.SUCCESS, + file_name=file_name, + processing_time=0.0, + result_data=getattr(workflow_file_execution, "result", None), + metadata=getattr(workflow_file_execution, "metadata", None) or {}, + skipped=SkipReason.ALREADY_COMPLETED, + ).to_dict() except Exception as e: logger.exception( f"API path: Failed to validate completion status for {file_execution_id}: {e}. " @@ -1792,14 +1820,15 @@ def _process_single_file_api( processing_time = time.time() - start_time - result = { - "file_execution_id": file_execution_id, - "file_name": file_name, - "status": "completed", - "processing_time": processing_time, - "result_data": runner_result, - "storage_result": storage_result, - } + result = FileExecutionResult( + file=file_name, + file_execution_id=file_execution_id, + status=ApiDeploymentResultStatus.SUCCESS, + file_name=file_name, + processing_time=processing_time, + result_data=runner_result, + storage_result=storage_result, + ).to_dict() logger.info(f"Successfully processed file: {file_name} in {processing_time:.2f}s") return result @@ -1820,13 +1849,14 @@ def _process_single_file_api( except Exception: logger.exception("Failed to update file execution status") - return { - "file_execution_id": file_execution_id, - "file_name": file_name, - "status": "failed", - "processing_time": processing_time, - "error": str(e), - } + return FileExecutionResult( + file=file_name, + file_execution_id=file_execution_id, + status=ApiDeploymentResultStatus.FAILED, + file_name=file_name, + processing_time=processing_time, + error=str(e), + ).to_dict() def _check_file_history( diff --git a/workers/tests/test_chord_callback_boundary.py b/workers/tests/test_chord_callback_boundary.py new file mode 100644 index 0000000000..524708a3e5 --- /dev/null +++ b/workers/tests/test_chord_callback_boundary.py @@ -0,0 +1,771 @@ +"""Wire-shape characterisation for the chord-callback boundary. + +Locks the on-wire contract for the producer paths that feed +``process_batch_callback`` (general path) and ``process_batch_callback_api`` +(API path). Producers build typed dataclasses and serialise via +``.to_dict()``. + +Three layers of test: + +1. **Dataclass wire shape** — ``to_dict`` / ``from_dict`` round-trip + preserves every consumer-read field; JSON-safe. +2. **Producer binding** — drives the real producer functions + (``_compile_batch_result``, ``_process_single_file_api``). Catches + reverts at the producer site that a dataclass-only test would miss. +3. **Real-consumer tolerance** — drives the real + ``aggregate_file_batch_results`` consumer against the typed wire + shape. Catches mismatches the producer-only test would miss. +""" + +from __future__ import annotations + +import json +from types import SimpleNamespace +from unittest.mock import MagicMock + +import pytest +from shared.processing.files.time_utils import aggregate_file_batch_results +from unstract.core.worker_models import ( + ApiDeploymentResultStatus, + BatchExecutionResult, + FileExecutionResult, + SkipReason, +) + + +class TestBatchExecutionResultWireShape: + """General path (``process_file_batch`` returns this shape).""" + + def _make(self) -> BatchExecutionResult: + return BatchExecutionResult( + total_files=10, + successful_files=7, + failed_files=2, + execution_time=12.5, + skipped_already_completed=1, + skipped_active_duplicate=0, + organization_id="org-test", + ) + + def test_wire_contains_consumer_read_fields(self): + wire = self._make().to_dict() + for key in ( + "total_files", + "successful_files", + "failed_files", + "execution_time", + "file_results", + ): + assert key in wire, f"consumer-read field missing: {key}" + + def test_wire_carries_extended_optional_fields(self): + wire = self._make().to_dict() + assert wire["skipped_already_completed"] == 1 + assert wire["skipped_active_duplicate"] == 0 + assert wire["organization_id"] == "org-test" + + def test_round_trip_preserves_all_fields(self): + original = self._make() + round_tripped = BatchExecutionResult.from_dict(original.to_dict()) + assert round_tripped.total_files == original.total_files + assert round_tripped.successful_files == original.successful_files + assert round_tripped.failed_files == original.failed_files + assert round_tripped.execution_time == pytest.approx(original.execution_time) + assert ( + round_tripped.skipped_already_completed + == original.skipped_already_completed + ) + assert ( + round_tripped.skipped_active_duplicate + == original.skipped_active_duplicate + ) + assert round_tripped.organization_id == original.organization_id + + def test_round_trip_with_populated_file_results(self): + """The existing round-trip test uses ``file_results=[]``, so the + list-comprehension in ``BatchExecutionResult.from_dict`` that + rebuilds nested ``FileExecutionResult`` objects is never + exercised. A regression that stored raw dicts instead would + otherwise keep every test green. + """ + fr_ok = FileExecutionResult( + file="a.pdf", + file_execution_id="fx-a", + status=ApiDeploymentResultStatus.SUCCESS, + file_name="a.pdf", + storage_result={"stored_at": "s3://k1"}, + ) + fr_failed = FileExecutionResult( + file="b.pdf", + file_execution_id="fx-b", + status=ApiDeploymentResultStatus.FAILED, + file_name="b.pdf", + error="boom", + ) + original = BatchExecutionResult( + total_files=2, + successful_files=1, + failed_files=1, + execution_time=1.0, + file_results=[fr_ok, fr_failed], + organization_id="org-1", + ) + round_tripped = BatchExecutionResult.from_dict(original.to_dict()) + + assert len(round_tripped.file_results) == 2 + # Reconstruction must yield typed objects, not bare dicts. + for fr in round_tripped.file_results: + assert isinstance(fr, FileExecutionResult) + # And every cross-the-wire field survives. + rt_ok = next( + fr for fr in round_tripped.file_results if fr.file_execution_id == "fx-a" + ) + rt_failed = next( + fr for fr in round_tripped.file_results if fr.file_execution_id == "fx-b" + ) + assert rt_ok.storage_result == {"stored_at": "s3://k1"} + assert rt_ok.status == ApiDeploymentResultStatus.SUCCESS + assert rt_failed.error == "boom" + assert rt_failed.status == ApiDeploymentResultStatus.FAILED + + def test_wire_is_json_safe(self): + wire = self._make().to_dict() + assert json.loads(json.dumps(wire)) == wire + + def test_from_file_results_derives_counters(self): + """``from_file_results`` derives counters from typed input so a + new ``SkipReason`` member can't silently zero a counter through + wire-vocab drift (the failure mode that the hand-rolled + ``string == SkipReason.X.value`` in ``tasks.py`` would have + and that the typed refactor in this PR eliminates). + """ + results = [ + FileExecutionResult( + file="a.pdf", + file_execution_id="fx-a", + status=ApiDeploymentResultStatus.SUCCESS, + ), + FileExecutionResult( + file="b.pdf", + file_execution_id="fx-b", + status=ApiDeploymentResultStatus.SUCCESS, + skipped=SkipReason.ALREADY_COMPLETED, + ), + FileExecutionResult( + file="c.pdf", + file_execution_id="fx-c", + status=ApiDeploymentResultStatus.SUCCESS, + skipped=SkipReason.ACTIVE_DUPLICATE, + ), + FileExecutionResult( + file="d.pdf", + file_execution_id="fx-d", + status=ApiDeploymentResultStatus.FAILED, + error="boom", + ), + ] + batch = BatchExecutionResult.from_file_results( + results, execution_time=2.0, organization_id="org-1" + ) + assert batch.total_files == 4 + # SUCCESS-status results are counted as successful regardless of + # ``skipped`` — matches the API-path "skipped counts as success" + # rule documented in ``process_file_batch_api``. + assert batch.successful_files == 3 + assert batch.failed_files == 1 + assert batch.skipped_already_completed == 1 + assert batch.skipped_active_duplicate == 1 + assert batch.organization_id == "org-1" + + +class TestFileExecutionResultWireShape: + """API path (``process_file_batch_api`` returns this per-file shape).""" + + def _make_success(self) -> FileExecutionResult: + return FileExecutionResult( + file="invoice.pdf", + file_execution_id="fx-1", + status=ApiDeploymentResultStatus.SUCCESS, + file_name="invoice.pdf", + processing_time=4.2, + result_data={"extracted": "value"}, + metadata={"source": "user_upload"}, + ) + + def _make_failure(self) -> FileExecutionResult: + return FileExecutionResult( + file="broken.pdf", + file_execution_id="fx-2", + status=ApiDeploymentResultStatus.FAILED, + file_name="broken.pdf", + processing_time=0.1, + error="extractor crashed", + ) + + def _make_skipped(self) -> FileExecutionResult: + return FileExecutionResult( + file="dup.pdf", + file_execution_id="fx-3", + status=ApiDeploymentResultStatus.SUCCESS, + file_name="dup.pdf", + skipped=SkipReason.ALREADY_COMPLETED, + result_data={"cached": True}, + ) + + def test_wire_carries_file_name_alias(self): + wire = self._make_success().to_dict() + assert wire["file_name"] == "invoice.pdf" + assert wire["file"] == "invoice.pdf" + + def test_wire_carries_result_data_alias(self): + wire = self._make_success().to_dict() + assert wire["result_data"] == {"extracted": "value"} + + def test_wire_carries_skipped_marker(self): + wire = self._make_skipped().to_dict() + assert wire["skipped"] == SkipReason.ALREADY_COMPLETED.value + + def test_success_status_uses_canonical_vocab(self): + wire = self._make_success().to_dict() + assert wire["status"] == "Success" + + def test_failure_status_uses_canonical_vocab(self): + wire = self._make_failure().to_dict() + assert wire["status"] == "Failed" + assert wire["error"] == "extractor crashed" + + def test_post_init_derives_status_from_error(self): + result = FileExecutionResult( + file="x", + file_execution_id="fx", + status=ApiDeploymentResultStatus.SUCCESS, + error="boom", + ) + assert result.status == ApiDeploymentResultStatus.FAILED + + def test_round_trip_preserves_all_aliases(self): + original = self._make_skipped() + round_tripped = FileExecutionResult.from_dict(original.to_dict()) + assert round_tripped.file == original.file + assert round_tripped.file_name == original.file_name + assert round_tripped.file_execution_id == original.file_execution_id + assert round_tripped.result_data == original.result_data + assert round_tripped.skipped == original.skipped + assert round_tripped.status == original.status + + def test_from_dict_unknown_skipped_is_lenient(self, caplog): + """``_parse_skipped`` is the one documented crash-prevention + path for rolling deploys (newer producer emits a future + ``SkipReason``, older consumer receives it). Without this test + a regression to bare ``SkipReason(raw)`` would re-introduce the + crash and every other test would stay green. + """ + import logging + + with caplog.at_level(logging.WARNING, logger="unstract.core.worker_models"): + result = FileExecutionResult.from_dict( + { + "file": "x.pdf", + "file_execution_id": "fx-future", + "skipped": "teleported_to_2030", + } + ) + + assert result.skipped is None, ( + "unknown skipped value must downgrade to None, not crash" + ) + # Log must include the unknown raw value and the file + # identifier — a context-free warning would be useless for + # debugging a real rolling-deploy incident. + assert any( + "Unknown SkipReason" in r.message + and "teleported_to_2030" in r.message + and "fx-future" in r.message + for r in caplog.records + ) + + def test_active_duplicate_skip_reason_round_trips(self): + # ``ACTIVE_DUPLICATE`` mirrors the batch-level + # ``skipped_active_duplicate`` counter; no producer emits it + # per-file today but the enum value must exist and round-trip + # cleanly so a future producer can pick it without re-typing + # the bare string. + original = FileExecutionResult( + file="dup.pdf", + file_execution_id="fx", + status=ApiDeploymentResultStatus.SUCCESS, + skipped=SkipReason.ACTIVE_DUPLICATE, + ) + wire = original.to_dict() + assert wire["skipped"] == SkipReason.ACTIVE_DUPLICATE.value + assert wire["skipped"] == "active_duplicate" + round_tripped = FileExecutionResult.from_dict(wire) + assert round_tripped.skipped == SkipReason.ACTIVE_DUPLICATE + + def test_wire_is_json_safe(self): + for builder in (self._make_success, self._make_failure, self._make_skipped): + wire = builder().to_dict() + assert json.loads(json.dumps(wire)) == wire + + def test_none_valued_optional_fields_stripped_from_wire(self): + """``None`` defaults are stripped both standalone AND when + nested inside ``BatchExecutionResult.file_results``. + + ``serialize_dataclass_to_dict`` only filters ``None`` at the + outermost level; ``BatchExecutionResult.to_dict`` adds a + secondary filter so the per-file shape stays symmetric. Without + that fixup a consumer doing ``"x" in result`` membership checks + would behave differently for a standalone wire vs one read out + of ``batch["file_results"][i]``. + """ + minimal = FileExecutionResult( + file="a.pdf", + file_execution_id="fx", + status=ApiDeploymentResultStatus.SUCCESS, + ) + + # ---- Standalone wire ---- + wire = minimal.to_dict() + # Required fields and zero-valued numerics survive. + assert wire["file"] == "a.pdf" + assert wire["status"] == "Success" + assert wire["processing_time"] == pytest.approx(0.0) + assert wire["file_size"] == 0 + # None defaults are dropped — not in the wire dict at all. + absent_keys = ( + "error", "result", "metadata", "file_name", + "result_data", "skipped", "storage_result", + ) + for absent in absent_keys: + assert absent not in wire, f"expected {absent!r} to be stripped when None" + + # ---- Same shape when nested inside a batch ---- + batch_wire = BatchExecutionResult( + total_files=1, + successful_files=1, + failed_files=0, + execution_time=0.0, + file_results=[minimal], + ).to_dict() + nested_wire = batch_wire["file_results"][0] + for absent in absent_keys: + assert absent not in nested_wire, ( + f"{absent!r} leaked into nested file_results wire" + ) + assert sorted(wire.keys()) == sorted(nested_wire.keys()) + + +class TestProducerBinding: + """Drives the real producer functions in ``file_processing.tasks``. + + A revert at any of these sites back to a hand-rolled dict (or to + the legacy lowercase status strings) keeps the *dataclass* tests + green — these tests catch that by asserting the actual wire shape + the chord callback receives from the producer. + """ + + def test_compile_batch_result_returns_typed_wire(self): + from file_processing.tasks import _compile_batch_result + + # Minimum fake context — _compile_batch_result reads only + # ``metadata["result"]`` (with attribute access), the two + # skipped lists, and ``organization_context.organization_id``. + result = SimpleNamespace( + successful_files=4, failed_files=1, execution_time=2.5 + ) + context = SimpleNamespace( + metadata={ + "result": result, + "workflow_logger": None, # avoids the publish call + "skipped_already_completed": ["a.pdf"], + "skipped_active_duplicate": [], + }, + organization_context=SimpleNamespace(organization_id="org-prod"), + ) + + wire = _compile_batch_result(context) + + # Producer must emit the typed shape, not a hand-rolled dict. + assert wire["total_files"] == 6 # 4 + 1 + 1 skipped + assert wire["successful_files"] == 4 + assert wire["failed_files"] == 1 + assert wire["execution_time"] == pytest.approx(2.5) + assert wire["skipped_already_completed"] == 1 + assert wire["skipped_active_duplicate"] == 0 + assert wire["organization_id"] == "org-prod" + # Dataclass shape gains these defaults — strictly additive. + assert wire["file_results"] == [] + assert wire["errors"] == [] + + def test_process_single_file_api_already_completed_branch(self): + from file_processing.tasks import _process_single_file_api + from unstract.core.data_models import ExecutionStatus + + api_client = MagicMock() + api_client.get_workflow_file_execution.return_value = SimpleNamespace( + status=ExecutionStatus.COMPLETED.value, + result={"cached": "value"}, + metadata={"src": "history"}, + ) + wire = _process_single_file_api( + api_client=api_client, + file_data={"id": "fx-1", "file_name": "doc.pdf"}, + workflow_id="wf-1", + execution_id="exec-1", + pipeline_id=None, + use_file_history=True, + ) + + # Canonical per-file status vocabulary — not the legacy + # lowercase "completed". + assert wire["status"] == "Success" + assert wire["skipped"] == SkipReason.ALREADY_COMPLETED.value + assert wire["file_name"] == "doc.pdf" + assert wire["file_execution_id"] == "fx-1" + # Cached result + metadata propagate so the API consumer can + # short-circuit on the historical extraction. + assert wire["result_data"] == {"cached": "value"} + assert wire["metadata"] == {"src": "history"} + # Producer doesn't set ``error`` → __post_init__ keeps SUCCESS; + # serializer strips the None. + assert "error" not in wire + + def test_process_single_file_api_success_branch(self, monkeypatch): + """Drives the happy-path producer. Catches reverts to the + legacy dict-spread that dropped ``storage_result`` at the + chord-callback boundary (silent data loss on the API path). + """ + from file_processing import tasks as tasks_mod + from unstract.core.data_models import ExecutionStatus + + api_client = MagicMock() + # Not already-completed → fall through to the runner branch. + api_client.get_workflow_file_execution.return_value = SimpleNamespace( + status=ExecutionStatus.PENDING.value + ) + api_client.get_workflow_definition.return_value = {"id": "wf-1"} + api_client.get_file_content.return_value = b"hello world" + api_client.store_file_execution_result.return_value = { + "stored_at": "s3://bucket/key" + } + # Patch the runner-service call — its body shells out and isn't + # under test here. + monkeypatch.setattr( + tasks_mod, + "_call_runner_service", + lambda **kwargs: {"extracted": "value"}, + ) + + wire = tasks_mod._process_single_file_api( + api_client=api_client, + file_data={"id": "fx-ok", "file_name": "ok.pdf"}, + workflow_id="wf-1", + execution_id="exec-1", + pipeline_id=None, + use_file_history=False, + ) + + # Canonical per-file vocabulary — not the legacy lowercase + # ``"completed"`` the pre-typing producer used to emit. + assert wire["status"] == "Success" + assert wire["file_name"] == "ok.pdf" + assert wire["file_execution_id"] == "fx-ok" + assert wire["result_data"] == {"extracted": "value"} + # The whole point of this test: ``storage_result`` must survive + # the typed dataclass round-trip (UN-3513 finding). + assert wire["storage_result"] == {"stored_at": "s3://bucket/key"} + # No error → success branch keeps SUCCESS; ``error`` stripped. + assert "error" not in wire + assert "skipped" not in wire + + def test_process_single_file_api_failure_branch(self, monkeypatch): + """Drives the except-block producer. Catches reverts to the + legacy lowercase ``"failed"`` status string. + """ + from file_processing import tasks as tasks_mod + from unstract.core.data_models import ExecutionStatus + + api_client = MagicMock() + api_client.get_workflow_file_execution.return_value = SimpleNamespace( + status=ExecutionStatus.PENDING.value + ) + api_client.get_workflow_definition.return_value = {"id": "wf-1"} + api_client.get_file_content.return_value = b"payload" + # Runner blows up → producer falls into the except branch. + monkeypatch.setattr( + tasks_mod, + "_call_runner_service", + MagicMock(side_effect=RuntimeError("runner crashed")), + ) + + wire = tasks_mod._process_single_file_api( + api_client=api_client, + file_data={"id": "fx-bad", "file_name": "bad.pdf"}, + workflow_id="wf-1", + execution_id="exec-1", + pipeline_id=None, + use_file_history=False, + ) + + assert wire["status"] == "Failed" + assert wire["file_name"] == "bad.pdf" + assert wire["file_execution_id"] == "fx-bad" + assert wire["error"] == "runner crashed" + # Failure branch carries no result/storage payload. + assert "result_data" not in wire + assert "storage_result" not in wire + + def test_process_file_batch_api_batch_wrapper(self, monkeypatch): + """Drives the API-path batch wrapper. Catches reverts at the + ``BatchExecutionResult(...).to_dict()`` producer site in + ``process_file_batch_api`` (symbol ref rather than a line + number so this docstring doesn't rot the moment ``tasks.py`` + is edited above the producer). + """ + from file_processing import tasks as tasks_mod + from file_processing.worker import app as celery_app + + # Swap postgres result backend for in-memory so ``.apply()`` + # doesn't try to persist the eager task result. + original = { + "task_always_eager": celery_app.conf.task_always_eager, + "task_eager_propagates": celery_app.conf.task_eager_propagates, + "result_backend": celery_app.conf.result_backend, + } + celery_app.conf.update( + task_always_eager=True, + task_eager_propagates=True, + result_backend="cache+memory://", + ) + + # Stub per-file producer with two typed file results — one + # already-completed skip, one fresh success. + skipped_wire = FileExecutionResult( + file="a.pdf", + file_execution_id="fx-a", + status=ApiDeploymentResultStatus.SUCCESS, + file_name="a.pdf", + result_data={"cached": True}, + skipped=SkipReason.ALREADY_COMPLETED, + ).to_dict() + ok_wire = FileExecutionResult( + file="b.pdf", + file_execution_id="fx-b", + status=ApiDeploymentResultStatus.SUCCESS, + file_name="b.pdf", + result_data={"extracted": "value"}, + storage_result={"stored": "s3://k"}, + ).to_dict() + per_file_outputs = iter([skipped_wire, ok_wire]) + + monkeypatch.setattr( + tasks_mod, + "_process_single_file_api", + lambda **kwargs: next(per_file_outputs), + ) + # Neutralise organisation-level side effects. + monkeypatch.setattr( + tasks_mod.StateStore, "set", lambda *a, **k: None + ) + api_client_stub = MagicMock() + api_client_stub.get_workflow_execution.return_value = SimpleNamespace( + success=True, data={"execution": {"execution_log_id": None}} + ) + monkeypatch.setattr( + tasks_mod, "create_api_client", lambda schema_name: api_client_stub + ) + # ``WorkerWorkflowExecutionService`` is imported inline inside + # the batch task; patch the lazy import path. + cache_service = MagicMock() + cache_service.return_value.cache_api_result = MagicMock() + import shared.workflow.execution.service as service_mod + + monkeypatch.setattr( + service_mod, "WorkerWorkflowExecutionService", cache_service + ) + + try: + wire = tasks_mod.process_file_batch_api.apply( + args=[ + "org-1", # schema_name + "wf-1", # workflow_id + "exec-1", # execution_id + "batch-1", # batch_id + [ + {"id": "fx-a", "file_name": "a.pdf"}, + {"id": "fx-b", "file_name": "b.pdf"}, + ], # created_files + None, # pipeline_id + None, # execution_mode + False, # use_file_history + ] + ).get() + finally: + celery_app.conf.update(original) + + # Producer must emit the typed BatchExecutionResult shape. + assert wire["total_files"] == 2 + # Legacy API-path semantic: skipped files count as successful. + assert wire["successful_files"] == 2 + assert wire["failed_files"] == 0 + # Skip counter derived from SkipReason.ALREADY_COMPLETED.value + # — a typo here would silently zero the counter. + assert wire["skipped_already_completed"] == 1 + assert wire["organization_id"] == "org-1" + # ``execution_time`` is a required positional on the dataclass; + # omitting it would crash the task at the producer site. + assert "execution_time" in wire + assert wire["execution_time"] >= 0.0 + # file_results round-trips through FileExecutionResult, so + # ``storage_result`` survives to the batch boundary. + stored = [ + fr + for fr in wire["file_results"] + if fr.get("file_execution_id") == "fx-b" + ] + assert stored and stored[0]["storage_result"] == {"stored": "s3://k"} + + def test_process_file_batch_api_batch_wrapper_failure_aggregation( + self, monkeypatch + ): + """Drive the failure-aggregation path of the batch wrapper — + the success-only variant above never exercises ``failed_files += + 1`` (``tasks.py`` ~L1657) or the canonical ``"Failed"`` vocab on + the nested per-file wire. A revert to the legacy lowercase + ``"failed"`` for the failure branch would otherwise stay green. + """ + from file_processing import tasks as tasks_mod + from file_processing.worker import app as celery_app + + original = { + "task_always_eager": celery_app.conf.task_always_eager, + "task_eager_propagates": celery_app.conf.task_eager_propagates, + "result_backend": celery_app.conf.result_backend, + } + celery_app.conf.update( + task_always_eager=True, + task_eager_propagates=True, + result_backend="cache+memory://", + ) + + # One success + one failure — the failed leg must arrive at the + # batch boundary as ``status="Failed"`` carrying ``error``. + ok_wire = FileExecutionResult( + file="ok.pdf", + file_execution_id="fx-ok", + status=ApiDeploymentResultStatus.SUCCESS, + file_name="ok.pdf", + result_data={"extracted": "value"}, + ).to_dict() + bad_wire = FileExecutionResult( + file="bad.pdf", + file_execution_id="fx-bad", + status=ApiDeploymentResultStatus.FAILED, + file_name="bad.pdf", + error="boom", + ).to_dict() + per_file_outputs = iter([ok_wire, bad_wire]) + + monkeypatch.setattr( + tasks_mod, + "_process_single_file_api", + lambda **kwargs: next(per_file_outputs), + ) + monkeypatch.setattr(tasks_mod.StateStore, "set", lambda *a, **k: None) + api_client_stub = MagicMock() + api_client_stub.get_workflow_execution.return_value = SimpleNamespace( + success=True, data={"execution": {"execution_log_id": None}} + ) + monkeypatch.setattr( + tasks_mod, "create_api_client", lambda schema_name: api_client_stub + ) + cache_service = MagicMock() + cache_service.return_value.cache_api_result = MagicMock() + import shared.workflow.execution.service as service_mod + + monkeypatch.setattr( + service_mod, "WorkerWorkflowExecutionService", cache_service + ) + + try: + wire = tasks_mod.process_file_batch_api.apply( + args=[ + "org-1", + "wf-1", + "exec-1", + "batch-1", + [ + {"id": "fx-ok", "file_name": "ok.pdf"}, + {"id": "fx-bad", "file_name": "bad.pdf"}, + ], + None, + None, + False, + ] + ).get() + finally: + celery_app.conf.update(original) + + # Failure-aggregation counters at the batch level. + assert wire["total_files"] == 2 + assert wire["successful_files"] == 1 + assert wire["failed_files"] == 1 + # The failed per-file result must survive the round-trip and + # arrive with the canonical capitalised vocab. + failed = [ + fr + for fr in wire["file_results"] + if fr.get("file_execution_id") == "fx-bad" + ] + assert failed, "failed per-file result missing from batch wire" + assert failed[0]["status"] == "Failed" + assert failed[0]["error"] == "boom" + + +class TestRealConsumerTolerance: + """Drives the real ``aggregate_file_batch_results`` against the + new wire shape — proves the producer-consumer contract end-to-end. + """ + + def test_aggregator_consumes_general_path_shape(self): + wire = BatchExecutionResult( + total_files=5, + successful_files=4, + failed_files=1, + execution_time=2.0, + skipped_already_completed=0, + skipped_active_duplicate=0, + organization_id="org-1", + ).to_dict() + + aggregated = aggregate_file_batch_results([wire]) + + assert aggregated["total_files"] == 5 + assert aggregated["successful_files"] == 4 + assert aggregated["failed_files"] == 1 + assert aggregated["batches_processed"] == 1 + + def test_aggregator_consumes_multi_batch(self): + batches = [ + BatchExecutionResult( + total_files=3, + successful_files=3, + failed_files=0, + execution_time=1.0, + ).to_dict(), + BatchExecutionResult( + total_files=2, + successful_files=1, + failed_files=1, + execution_time=0.5, + ).to_dict(), + ] + aggregated = aggregate_file_batch_results(batches) + assert aggregated["total_files"] == 5 + assert aggregated["successful_files"] == 4 + assert aggregated["failed_files"] == 1 + assert aggregated["batches_processed"] == 2 + + +if __name__ == "__main__": + pytest.main([__file__, "-v"])