diff --git a/backend/.env.example b/backend/.env.example index f7d094e..a5cc2dd 100644 --- a/backend/.env.example +++ b/backend/.env.example @@ -158,3 +158,7 @@ OCR_OPENAI_MODEL=gpt-5-mini # OpenAI vision model for OCR # JWT_TENANT_CLAIM=tenant_id # claim -> tenant # JWT_ROLES_CLAIM=roles # claim -> roles # JWT_GROUPS_CLAIM=groups # claim -> groups + +# Figure OCR throughput/cost (E3) +# OCR_MAX_CONCURRENCY=4 # parallel region-OCR HTTP calls per document +# OCR_MAX_CALLS_PER_DOC=60 # per-doc OCR call cap (0 = unlimited); extra figures get no OCR diff --git a/backend/app/config.py b/backend/app/config.py index 7292c70..4d85aba 100644 --- a/backend/app/config.py +++ b/backend/app/config.py @@ -82,6 +82,10 @@ class Settings(BaseSettings): ocr_timeout: int = 120 # seconds for full page OCR ocr_region_timeout: int = 60 # seconds for region OCR ocr_max_retries: int = 3 + # Figure-OCR throughput/cost controls (E3). Region OCR calls are I/O-bound + # HTTP requests, run with bounded concurrency; the per-doc cap bounds spend. + ocr_max_concurrency: int = 4 # parallel region-OCR HTTP calls per doc + ocr_max_calls_per_doc: int = 60 # 0 = unlimited; beyond it, figures get empty OCR ocr_text_layer_fallback_enabled: bool = False # OCR low-quality pages even when text layer exists ocr_text_layer_min_chars: int = 200 # Only OCR text-layer pages below this char count diff --git a/backend/app/graph/nodes.py b/backend/app/graph/nodes.py index 4fd1243..97568fa 100644 --- a/backend/app/graph/nodes.py +++ b/backend/app/graph/nodes.py @@ -7,10 +7,12 @@ """ import logging +from concurrent.futures import ThreadPoolExecutor, as_completed from typing import List, Optional import fitz # PyMuPDF +from app.config import get_settings from app.db.graph_models import Node, NodeType from app.ocr import get_ocr_client, OCRClient from .ids import compute_node_id, compute_text_hash @@ -118,48 +120,109 @@ def create_figure_nodes( page_extractor = PageExtractor() nodes: List[Node] = [] - + + settings = get_settings() + max_concurrency = max(1, settings.ocr_max_concurrency) + max_calls = settings.ocr_max_calls_per_doc # 0 == unlimited + # Open PDF doc = fitz.open(stream=pdf_bytes, filetype="pdf") - + try: - # Group figures by page for efficient processing + # Group figures by page (insertion order preserves document order). figures_by_page = {} for fig in figures: - if fig.page_no not in figures_by_page: - figures_by_page[fig.page_no] = [] - figures_by_page[fig.page_no].append(fig) - - # Process each page + figures_by_page.setdefault(fig.page_no, []).append(fig) + + # --- Phase A (single-threaded): render region images in document order. + # PyMuPDF pages are NOT thread-safe, so all fitz access stays here; the + # per-doc spend cap is applied deterministically to the first N figures. + work = [] # ordered: {"fig", "fig_index", "image_bytes"} + ocr_calls_planned = 0 + over_cap = 0 for page_no, page_figures in figures_by_page.items(): - page_idx = page_no - 1 # 0-indexed - - if page_idx < 0 or page_idx >= len(doc): + page_idx = page_no - 1 + page = None + if 0 <= page_idx < len(doc): + page = doc[page_idx] + else: logger.warning(f"Page {page_no} out of range for doc {doc_id}") - continue - - page = doc[page_idx] - for fig_idx, fig in enumerate(page_figures): - node = _create_single_figure_node( - fig=fig, - page=page, - fig_index=fig_idx, - doc_id=doc_id, - version=version, - page_extractor=page_extractor, - ocr_client=ocr_client, - skip_ocr=skip_ocr - ) - - if node: - nodes.append(node) - + image_bytes = None + needs_ocr = bool(fig.bbox and not skip_ocr and ocr_client and page is not None) + if needs_ocr and max_calls and ocr_calls_planned >= max_calls: + over_cap += 1 # beyond the spend cap -> no OCR (empty text) + needs_ocr = False + if needs_ocr: + try: + image_bytes = page_extractor.render_region_image( + page=page, bbox=fig.bbox, zoom=2.0, + ) + ocr_calls_planned += 1 + except Exception as e: + logger.warning(f"Region render failed for {fig.label or 'figure'}: {e}") + image_bytes = None + work.append({"fig": fig, "fig_index": fig_idx, "image_bytes": image_bytes}) + + # --- Phase B (bounded concurrency): OCR the rendered regions. ocr_region + # is a stateless, thread-safe HTTP call. + ocr_text_by_item = {} # id(work_item) -> ocr_text + + def _ocr(item): + fig = item["fig"] + return ocr_client.ocr_region( + image_bytes=item["image_bytes"], + doc_id=doc_id, + page_no=fig.page_no, + region_type=fig.figure_type, + ) + + ocr_items = [w for w in work if w["image_bytes"] is not None] + if ocr_items: + workers = min(max_concurrency, len(ocr_items)) + if workers > 1: + with ThreadPoolExecutor(max_workers=workers) as pool: + futures = {pool.submit(_ocr, w): w for w in ocr_items} + for fut in as_completed(futures): + w = futures[fut] + try: + ocr_text_by_item[id(w)] = fut.result() or "" + except Exception as e: + logger.warning(f"OCR failed for {w['fig'].label or 'figure'}: {e}") + ocr_text_by_item[id(w)] = "" + else: + for w in ocr_items: + try: + ocr_text_by_item[id(w)] = _ocr(w) or "" + except Exception as e: + logger.warning(f"OCR failed for {w['fig'].label or 'figure'}: {e}") + ocr_text_by_item[id(w)] = "" + + # --- Phase C (single-threaded, document order): build nodes. + for w in work: + node = _build_figure_node( + fig=w["fig"], + fig_index=w["fig_index"], + doc_id=doc_id, + version=version, + ocr_text=ocr_text_by_item.get(id(w), ""), + ) + if node: + nodes.append(node) + finally: doc.close() - - logger.info(f"Created {len(nodes)} figure/table nodes for doc {doc_id} v{version}") - + + if over_cap: + logger.warning( + "OCR spend cap (%d) reached for doc %s; %d figure(s) left un-OCR'd", + max_calls, doc_id, over_cap, + ) + logger.info( + "Created %d figure/table nodes for doc %s v%d (%d OCR calls, concurrency<=%d)", + len(nodes), doc_id, version, ocr_calls_planned, max_concurrency, + ) + return nodes @@ -250,28 +313,25 @@ def create_figure_nodes_from_data( return nodes -def _create_single_figure_node( +def _build_figure_node( fig: FigureData, - page: fitz.Page, fig_index: int, doc_id: str, version: int, - page_extractor: PageExtractor, - ocr_client: Optional[OCRClient], - skip_ocr: bool = False + ocr_text: str = "", ) -> Optional[Node]: - """Create a single figure/table node. - + """Build a single figure/table node from pre-computed OCR text. + + OCR (and the fitz rendering it needs) is performed by create_figure_nodes + before this is called, so this is pure, fitz-free, and thread-safe to build. + Args: fig: FigureData - page: PyMuPDF page object fig_index: Index of figure on this page doc_id: Document ID version: Document version - page_extractor: PageExtractor for image rendering - ocr_client: OCR client (can be None if skip_ocr=True) - skip_ocr: If True, skip OCR for this figure - + ocr_text: OCR result for this region ("" if none/skipped/failed) + Returns: Node object or None if creation fails """ @@ -297,32 +357,7 @@ def _create_single_figure_node( node_type=fig.figure_type ) - # OCR the region if bbox available and OCR not skipped - ocr_text = "" - if fig.bbox and not skip_ocr and ocr_client: - try: - # Crop and render region - image_bytes = page_extractor.render_region_image( - page=page, - bbox=fig.bbox, - zoom=2.0 - ) - - # Run OCR - ocr_text = ocr_client.ocr_region( - image_bytes=image_bytes, - doc_id=doc_id, - page_no=fig.page_no, - region_type=fig.figure_type - ) - except Exception as e: - logger.warning(f"OCR failed for {fig.label or 'unknown'}: {e}") - elif skip_ocr: - logger.debug(f"OCR skipped for {fig.label or 'figure'}") - elif not fig.bbox: - logger.debug(f"No bbox for {fig.label or 'figure'}, skipping OCR") - - # Build text_md: label + caption + OCR + # Build text_md: label + caption + OCR (ocr_text is pre-computed) text_parts = [] if fig.label: text_parts.append(f"**{fig.label}**") diff --git a/backend/tests/test_figure_ocr_batching.py b/backend/tests/test_figure_ocr_batching.py new file mode 100644 index 0000000..fc48acc --- /dev/null +++ b/backend/tests/test_figure_ocr_batching.py @@ -0,0 +1,108 @@ +"""E3: bounded-concurrency figure OCR + per-doc spend cap. + +Covers the pure node builder and the create_figure_nodes orchestration (with +fitz / PageExtractor / OCR client mocked, no real PDF or network). +""" + +from unittest.mock import MagicMock, patch + +import pytest + +from app.config import Settings +from app.graph.figure_detector import FigureData +from app.graph.nodes import _build_figure_node, create_figure_nodes + + +def _fig(page_no, label, ftype="figure"): + return FigureData( + page_no=page_no, + label=label, + caption=f"cap-{label}", + bbox={"x0": 0, "y0": 0, "x1": 10, "y1": 10}, + figure_type=ftype, + ) + + +def test_build_figure_node_composition(): + fig = _fig(1, "Figure 1") + node = _build_figure_node(fig, 0, "doc1", 1, ocr_text="OCRTEXT") + assert "Figure 1" in node.text_md + assert "cap-Figure 1" in node.text_md + assert "OCRTEXT" in node.text_md + assert node.meta["has_ocr"] is True + + no_ocr = _build_figure_node(fig, 0, "doc1", 1, ocr_text="") + assert no_ocr.meta["has_ocr"] is False + assert "OCRTEXT" not in no_ocr.text_md + + +def _mock_env(monkeypatch, ocr, *, max_calls=60, max_concurrency=4): + monkeypatch.setattr( + "app.graph.nodes.get_settings", + lambda: Settings(ocr_max_calls_per_doc=max_calls, ocr_max_concurrency=max_concurrency), + ) + mock_doc = MagicMock() + mock_doc.__len__.return_value = 10 + mock_doc.__getitem__.return_value = MagicMock() # a page + monkeypatch.setattr("app.graph.nodes.fitz.open", lambda *a, **k: mock_doc) + pe = MagicMock() + pe.render_region_image.return_value = b"img-bytes" + monkeypatch.setattr("app.graph.nodes.PageExtractor", lambda: pe) + + +def test_order_preserved_and_each_figure_ocrd(monkeypatch): + ocr = MagicMock() + ocr.ocr_region.side_effect = lambda image_bytes, doc_id, page_no, region_type: f"ocr-{page_no}-{region_type}" + _mock_env(monkeypatch, ocr) + + figs = [_fig(1, "Figure 1"), _fig(1, "Figure 2"), _fig(2, "Table 1", "table")] + out = create_figure_nodes(figs, b"%PDF", "doc1", 1, ocr_client=ocr, skip_ocr=False) + + assert [n.label for n in out] == ["Figure 1", "Figure 2", "Table 1"] # document order + assert ocr.ocr_region.call_count == 3 + assert "ocr-2-table" in out[2].text_md + + +def test_spend_cap_limits_ocr_deterministically(monkeypatch): + ocr = MagicMock() + ocr.ocr_region.side_effect = lambda **k: "OCR" + _mock_env(monkeypatch, ocr, max_calls=1) + + figs = [_fig(1, "Figure 1"), _fig(1, "Figure 2"), _fig(1, "Figure 3")] + out = create_figure_nodes(figs, b"%PDF", "doc1", 1, ocr_client=ocr, skip_ocr=False) + + assert ocr.ocr_region.call_count == 1 # only the first figure + assert len(out) == 3 # all nodes still built + assert out[0].meta["has_ocr"] is True + assert out[1].meta["has_ocr"] is False # over the cap -> empty OCR + assert out[2].meta["has_ocr"] is False + + +def test_single_ocr_failure_is_non_fatal(monkeypatch): + calls = {"n": 0} + + def flaky(**k): + calls["n"] += 1 + if calls["n"] == 2: + raise RuntimeError("ocr boom") + return "OCR" + + ocr = MagicMock() + ocr.ocr_region.side_effect = flaky + _mock_env(monkeypatch, ocr, max_concurrency=1) # deterministic order for this test + + figs = [_fig(1, "Figure 1"), _fig(1, "Figure 2"), _fig(1, "Figure 3")] + out = create_figure_nodes(figs, b"%PDF", "doc1", 1, ocr_client=ocr, skip_ocr=False) + + assert len(out) == 3 # failure didn't drop the doc + assert out[1].meta["has_ocr"] is False # the failed one has empty OCR + assert out[0].meta["has_ocr"] is True + + +def test_skip_ocr_builds_nodes_without_calls(monkeypatch): + ocr = MagicMock() + _mock_env(monkeypatch, ocr) + figs = [_fig(1, "Figure 1"), _fig(2, "Table 1", "table")] + out = create_figure_nodes(figs, b"%PDF", "doc1", 1, ocr_client=ocr, skip_ocr=True) + assert len(out) == 2 + assert ocr.ocr_region.call_count == 0