Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions backend/.env.example
Original file line number Diff line number Diff line change
Expand Up @@ -158,3 +158,7 @@ OCR_OPENAI_MODEL=gpt-5-mini # OpenAI vision model for OCR
# JWT_TENANT_CLAIM=tenant_id # claim -> tenant
# JWT_ROLES_CLAIM=roles # claim -> roles
# JWT_GROUPS_CLAIM=groups # claim -> groups

# Figure OCR throughput/cost (E3)
# OCR_MAX_CONCURRENCY=4 # parallel region-OCR HTTP calls per document
# OCR_MAX_CALLS_PER_DOC=60 # per-doc OCR call cap (0 = unlimited); extra figures get no OCR
4 changes: 4 additions & 0 deletions backend/app/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -82,6 +82,10 @@ class Settings(BaseSettings):
ocr_timeout: int = 120 # seconds for full page OCR
ocr_region_timeout: int = 60 # seconds for region OCR
ocr_max_retries: int = 3
# Figure-OCR throughput/cost controls (E3). Region OCR calls are I/O-bound
# HTTP requests, run with bounded concurrency; the per-doc cap bounds spend.
ocr_max_concurrency: int = 4 # parallel region-OCR HTTP calls per doc
ocr_max_calls_per_doc: int = 60 # 0 = unlimited; beyond it, figures get empty OCR
ocr_text_layer_fallback_enabled: bool = False # OCR low-quality pages even when text layer exists
ocr_text_layer_min_chars: int = 200 # Only OCR text-layer pages below this char count

Expand Down
175 changes: 105 additions & 70 deletions backend/app/graph/nodes.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,10 +7,12 @@
"""

import logging
from concurrent.futures import ThreadPoolExecutor, as_completed
from typing import List, Optional

import fitz # PyMuPDF

from app.config import get_settings
from app.db.graph_models import Node, NodeType
from app.ocr import get_ocr_client, OCRClient
from .ids import compute_node_id, compute_text_hash
Expand Down Expand Up @@ -118,48 +120,109 @@ def create_figure_nodes(
page_extractor = PageExtractor()

nodes: List[Node] = []


settings = get_settings()
max_concurrency = max(1, settings.ocr_max_concurrency)
max_calls = settings.ocr_max_calls_per_doc # 0 == unlimited

# Open PDF
doc = fitz.open(stream=pdf_bytes, filetype="pdf")

try:
# Group figures by page for efficient processing
# Group figures by page (insertion order preserves document order).
figures_by_page = {}
for fig in figures:
if fig.page_no not in figures_by_page:
figures_by_page[fig.page_no] = []
figures_by_page[fig.page_no].append(fig)

# Process each page
figures_by_page.setdefault(fig.page_no, []).append(fig)

# --- Phase A (single-threaded): render region images in document order.
# PyMuPDF pages are NOT thread-safe, so all fitz access stays here; the
# per-doc spend cap is applied deterministically to the first N figures.
work = [] # ordered: {"fig", "fig_index", "image_bytes"}
ocr_calls_planned = 0
over_cap = 0
for page_no, page_figures in figures_by_page.items():
page_idx = page_no - 1 # 0-indexed

if page_idx < 0 or page_idx >= len(doc):
page_idx = page_no - 1
page = None
if 0 <= page_idx < len(doc):
page = doc[page_idx]
else:
logger.warning(f"Page {page_no} out of range for doc {doc_id}")
continue

page = doc[page_idx]

for fig_idx, fig in enumerate(page_figures):
node = _create_single_figure_node(
fig=fig,
page=page,
fig_index=fig_idx,
doc_id=doc_id,
version=version,
page_extractor=page_extractor,
ocr_client=ocr_client,
skip_ocr=skip_ocr
)

if node:
nodes.append(node)

image_bytes = None
needs_ocr = bool(fig.bbox and not skip_ocr and ocr_client and page is not None)
if needs_ocr and max_calls and ocr_calls_planned >= max_calls:
over_cap += 1 # beyond the spend cap -> no OCR (empty text)
needs_ocr = False
if needs_ocr:
try:
image_bytes = page_extractor.render_region_image(
page=page, bbox=fig.bbox, zoom=2.0,
)
ocr_calls_planned += 1
except Exception as e:
logger.warning(f"Region render failed for {fig.label or 'figure'}: {e}")
image_bytes = None
work.append({"fig": fig, "fig_index": fig_idx, "image_bytes": image_bytes})

# --- Phase B (bounded concurrency): OCR the rendered regions. ocr_region
# is a stateless, thread-safe HTTP call.
ocr_text_by_item = {} # id(work_item) -> ocr_text

def _ocr(item):
fig = item["fig"]
return ocr_client.ocr_region(
image_bytes=item["image_bytes"],
doc_id=doc_id,
page_no=fig.page_no,
region_type=fig.figure_type,
)

ocr_items = [w for w in work if w["image_bytes"] is not None]
if ocr_items:
workers = min(max_concurrency, len(ocr_items))
if workers > 1:
with ThreadPoolExecutor(max_workers=workers) as pool:
futures = {pool.submit(_ocr, w): w for w in ocr_items}
for fut in as_completed(futures):
w = futures[fut]
try:
ocr_text_by_item[id(w)] = fut.result() or ""
except Exception as e:
logger.warning(f"OCR failed for {w['fig'].label or 'figure'}: {e}")
ocr_text_by_item[id(w)] = ""
else:
for w in ocr_items:
try:
ocr_text_by_item[id(w)] = _ocr(w) or ""
except Exception as e:
logger.warning(f"OCR failed for {w['fig'].label or 'figure'}: {e}")
ocr_text_by_item[id(w)] = ""

# --- Phase C (single-threaded, document order): build nodes.
for w in work:
node = _build_figure_node(
fig=w["fig"],
fig_index=w["fig_index"],
doc_id=doc_id,
version=version,
ocr_text=ocr_text_by_item.get(id(w), ""),
)
if node:
nodes.append(node)

finally:
doc.close()

logger.info(f"Created {len(nodes)} figure/table nodes for doc {doc_id} v{version}")


if over_cap:
logger.warning(
"OCR spend cap (%d) reached for doc %s; %d figure(s) left un-OCR'd",
max_calls, doc_id, over_cap,
)
logger.info(
"Created %d figure/table nodes for doc %s v%d (%d OCR calls, concurrency<=%d)",
len(nodes), doc_id, version, ocr_calls_planned, max_concurrency,
)

return nodes


Expand Down Expand Up @@ -250,28 +313,25 @@ def create_figure_nodes_from_data(
return nodes


def _create_single_figure_node(
def _build_figure_node(
fig: FigureData,
page: fitz.Page,
fig_index: int,
doc_id: str,
version: int,
page_extractor: PageExtractor,
ocr_client: Optional[OCRClient],
skip_ocr: bool = False
ocr_text: str = "",
) -> Optional[Node]:
"""Create a single figure/table node.

"""Build a single figure/table node from pre-computed OCR text.

OCR (and the fitz rendering it needs) is performed by create_figure_nodes
before this is called, so this is pure, fitz-free, and thread-safe to build.

Args:
fig: FigureData
page: PyMuPDF page object
fig_index: Index of figure on this page
doc_id: Document ID
version: Document version
page_extractor: PageExtractor for image rendering
ocr_client: OCR client (can be None if skip_ocr=True)
skip_ocr: If True, skip OCR for this figure

ocr_text: OCR result for this region ("" if none/skipped/failed)

Returns:
Node object or None if creation fails
"""
Expand All @@ -297,32 +357,7 @@ def _create_single_figure_node(
node_type=fig.figure_type
)

# OCR the region if bbox available and OCR not skipped
ocr_text = ""
if fig.bbox and not skip_ocr and ocr_client:
try:
# Crop and render region
image_bytes = page_extractor.render_region_image(
page=page,
bbox=fig.bbox,
zoom=2.0
)

# Run OCR
ocr_text = ocr_client.ocr_region(
image_bytes=image_bytes,
doc_id=doc_id,
page_no=fig.page_no,
region_type=fig.figure_type
)
except Exception as e:
logger.warning(f"OCR failed for {fig.label or 'unknown'}: {e}")
elif skip_ocr:
logger.debug(f"OCR skipped for {fig.label or 'figure'}")
elif not fig.bbox:
logger.debug(f"No bbox for {fig.label or 'figure'}, skipping OCR")

# Build text_md: label + caption + OCR
# Build text_md: label + caption + OCR (ocr_text is pre-computed)
text_parts = []
if fig.label:
text_parts.append(f"**{fig.label}**")
Expand Down
108 changes: 108 additions & 0 deletions backend/tests/test_figure_ocr_batching.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,108 @@
"""E3: bounded-concurrency figure OCR + per-doc spend cap.

Covers the pure node builder and the create_figure_nodes orchestration (with
fitz / PageExtractor / OCR client mocked, no real PDF or network).
"""

from unittest.mock import MagicMock, patch

import pytest

from app.config import Settings
from app.graph.figure_detector import FigureData
from app.graph.nodes import _build_figure_node, create_figure_nodes


def _fig(page_no, label, ftype="figure"):
    """Return a FigureData fixture on *page_no* with a fixed 10x10 bbox.

    The caption is derived from the label ("cap-<label>") so a test can tell
    the caption apart from the label inside the built node text.
    """
    unit_bbox = {"x0": 0, "y0": 0, "x1": 10, "y1": 10}
    caption = f"cap-{label}"
    return FigureData(
        page_no=page_no,
        label=label,
        caption=caption,
        bbox=unit_bbox,
        figure_type=ftype,
    )


def test_build_figure_node_composition():
    """Label, caption and OCR text all appear in text_md; has_ocr tracks OCR text."""
    figure = _fig(1, "Figure 1")

    with_ocr = _build_figure_node(figure, 0, "doc1", 1, ocr_text="OCRTEXT")
    for expected in ("Figure 1", "cap-Figure 1", "OCRTEXT"):
        assert expected in with_ocr.text_md
    assert with_ocr.meta["has_ocr"] is True

    # Same figure with empty OCR text: the flag flips and no OCR text leaks in.
    without_ocr = _build_figure_node(figure, 0, "doc1", 1, ocr_text="")
    assert without_ocr.meta["has_ocr"] is False
    assert "OCRTEXT" not in without_ocr.text_md


def _mock_env(monkeypatch, ocr, *, max_calls=60, max_concurrency=4):
    """Patch app.graph.nodes so create_figure_nodes runs with no PDF or network.

    Replaces three names on the module under test:
      * get_settings  -> returns Settings with the given OCR cap / concurrency
      * fitz.open     -> returns a mock document reporting 10 pages
      * PageExtractor -> returns a mock whose render_region_image yields b"img-bytes"

    Args:
        monkeypatch: pytest monkeypatch fixture (patches are undone after the test).
        ocr: OCR client mock. Not used by this helper; tests pass the client to
            create_figure_nodes themselves.
        max_calls: value for Settings.ocr_max_calls_per_doc.
        max_concurrency: value for Settings.ocr_max_concurrency.
    """

    def fake_get_settings():
        return Settings(
            ocr_max_calls_per_doc=max_calls,
            ocr_max_concurrency=max_concurrency,
        )

    fake_doc = MagicMock()
    fake_doc.__len__.return_value = 10
    fake_doc.__getitem__.return_value = MagicMock()  # a page

    extractor = MagicMock()
    extractor.render_region_image.return_value = b"img-bytes"

    monkeypatch.setattr("app.graph.nodes.get_settings", fake_get_settings)
    monkeypatch.setattr("app.graph.nodes.fitz.open", lambda *a, **k: fake_doc)
    monkeypatch.setattr("app.graph.nodes.PageExtractor", lambda: extractor)


def test_order_preserved_and_each_figure_ocrd(monkeypatch):
    """Nodes come back in document order and every node carries its OCR text.

    Runs with the default concurrency (4) so the thread-pool path is used for
    the three figures.
    """
    ocr = MagicMock()
    ocr.ocr_region.side_effect = (
        lambda image_bytes, doc_id, page_no, region_type: f"ocr-{page_no}-{region_type}"
    )
    _mock_env(monkeypatch, ocr)

    figs = [_fig(1, "Figure 1"), _fig(1, "Figure 2"), _fig(2, "Table 1", "table")]
    out = create_figure_nodes(figs, b"%PDF", "doc1", 1, ocr_client=ocr, skip_ocr=False)

    assert [n.label for n in out] == ["Figure 1", "Figure 2", "Table 1"]  # document order
    assert ocr.ocr_region.call_count == 3

    # Check every node, not just the last one: OCR results are gathered as the
    # futures complete, so each must be mapped back to the figure it belongs to.
    expected_ocr = ["ocr-1-figure", "ocr-1-figure", "ocr-2-table"]
    assert len(out) == len(expected_ocr)
    for node, text in zip(out, expected_ocr):
        assert text in node.text_md
        assert node.meta["has_ocr"] is True


def test_spend_cap_limits_ocr_deterministically(monkeypatch):
    """With a per-doc cap of one call, only the first figure gets OCR'd."""
    ocr_client = MagicMock()
    ocr_client.ocr_region.side_effect = lambda **k: "OCR"
    _mock_env(monkeypatch, ocr_client, max_calls=1)

    figures = [_fig(1, label) for label in ("Figure 1", "Figure 2", "Figure 3")]
    nodes = create_figure_nodes(
        figures, b"%PDF", "doc1", 1, ocr_client=ocr_client, skip_ocr=False
    )

    assert ocr_client.ocr_region.call_count == 1  # only the first figure
    assert len(nodes) == 3  # all nodes still built

    first, *over_cap = nodes
    assert first.meta["has_ocr"] is True
    for node in over_cap:  # over the cap -> empty OCR
        assert node.meta["has_ocr"] is False


def test_single_ocr_failure_is_non_fatal(monkeypatch):
    """One failing OCR call empties only that figure; later figures are still OCR'd."""
    ocr = MagicMock()
    # A side_effect sequence is consumed one entry per call; exception
    # instances in it are raised instead of returned. Second call fails.
    ocr.ocr_region.side_effect = ["OCR", RuntimeError("ocr boom"), "OCR"]
    _mock_env(monkeypatch, ocr, max_concurrency=1)  # deterministic order for this test

    figs = [_fig(1, "Figure 1"), _fig(1, "Figure 2"), _fig(1, "Figure 3")]
    out = create_figure_nodes(figs, b"%PDF", "doc1", 1, ocr_client=ocr, skip_ocr=False)

    assert len(out) == 3  # failure didn't drop the doc
    assert ocr.ocr_region.call_count == 3  # processing continued past the failure
    assert out[0].meta["has_ocr"] is True
    assert out[1].meta["has_ocr"] is False  # the failed one has empty OCR
    assert out[2].meta["has_ocr"] is True  # the figure after the failure is unaffected


def test_skip_ocr_builds_nodes_without_calls(monkeypatch):
    """skip_ocr=True still yields one node per figure and never calls ocr_region."""
    ocr_client = MagicMock()
    _mock_env(monkeypatch, ocr_client)

    figures = [_fig(1, "Figure 1"), _fig(2, "Table 1", "table")]
    nodes = create_figure_nodes(
        figures, b"%PDF", "doc1", 1, ocr_client=ocr_client, skip_ocr=True
    )

    assert len(nodes) == 2
    assert ocr_client.ocr_region.call_count == 0
Loading