diff --git a/src/ix/ocr/client.py b/src/ix/ocr/client.py index ca84185..d5ae2b9 100644 --- a/src/ix/ocr/client.py +++ b/src/ix/ocr/client.py @@ -5,11 +5,19 @@ method satisfies the Protocol. :class:`~ix.pipeline.ocr_step.OCRStep` depends on the Protocol, not a concrete class, so swapping engines (``FakeOCRClient`` in tests, ``SuryaOCRClient`` in prod) stays a wiring change at the app factory. + +Per-page source location (``files`` + ``page_metadata``) flows in as +optional kwargs: fakes ignore them; the real +:class:`~ix.ocr.surya_client.SuryaOCRClient` uses them to render each +page's pixels back from disk. Keeping these optional lets unit tests stay +pages-only while production wiring (Task 4.3) plumbs through the real +filesystem handles. """ from __future__ import annotations -from typing import Protocol, runtime_checkable +from pathlib import Path +from typing import Any, Protocol, runtime_checkable from ix.contracts import OCRResult, Page @@ -24,8 +32,18 @@ class OCRClient(Protocol): per input page (in the same order). """ - async def ocr(self, pages: list[Page]) -> OCRResult: - """Run OCR over the input pages; return the structured result.""" + async def ocr( + self, + pages: list[Page], + *, + files: list[tuple[Path, str]] | None = None, + page_metadata: list[Any] | None = None, + ) -> OCRResult: + """Run OCR over the input pages; return the structured result. + + ``files`` and ``page_metadata`` are optional for hermetic tests; + real engines that need to re-render from disk read them. + """ ... diff --git a/src/ix/ocr/fake.py b/src/ix/ocr/fake.py index ca811d8..2e417b9 100644 --- a/src/ix/ocr/fake.py +++ b/src/ix/ocr/fake.py @@ -30,8 +30,17 @@ class FakeOCRClient: self._canned = canned self._raise_on_call = raise_on_call - async def ocr(self, pages: list[Page]) -> OCRResult: - """Return the canned result or raise the configured error.""" + async def ocr( + self, + pages: list[Page], + **_kwargs: object, + ) -> OCRResult: + """Return the canned result or raise the configured error. + + Accepts (and ignores) any keyword args the production Protocol may + carry — keeps the fake swappable for :class:`SuryaOCRClient` at + call sites that pass ``files`` / ``page_metadata``. + """ if self._raise_on_call is not None: raise self._raise_on_call return self._canned diff --git a/src/ix/ocr/surya_client.py b/src/ix/ocr/surya_client.py new file mode 100644 index 0000000..b277be9 --- /dev/null +++ b/src/ix/ocr/surya_client.py @@ -0,0 +1,235 @@ +"""SuryaOCRClient — real :class:`OCRClient` backed by ``surya-ocr``. + +Per spec §6.2: the MVP OCR engine. Runs Surya's detection + recognition +predictors over per-page PIL images rendered from the downloaded sources +(PDFs via PyMuPDF, images via Pillow). + +Design choices: + +* **Lazy model loading.** ``__init__`` is cheap; the heavy predictors are + built on first :meth:`ocr` / :meth:`selfcheck` / explicit :meth:`warm_up`. + This keeps FastAPI's lifespan predictable — ops can decide whether to + pay the load cost up front or on first request. +* **Device is Surya's default.** CUDA on the prod box, MPS on M-series Macs. + We deliberately don't pin. +* **No text-token reuse from PyMuPDF.** The cross-check against Paperless' + Tesseract output (ReliabilityStep's ``text_agreement``) is only meaningful + with a truly independent OCR pass, so we always render-and-recognize + even for PDFs that carry embedded text. + +The ``surya-ocr`` package pulls torch + heavy model deps, so it's kept +behind the ``[ocr]`` extra. All Surya imports are deferred into +:meth:`warm_up` so running the unit tests (which patch the predictors) +doesn't require the package to be installed. +""" + +from __future__ import annotations + +import asyncio +import contextlib +from pathlib import Path +from typing import TYPE_CHECKING, Any, Literal + +from ix.contracts import Line, OCRDetails, OCRResult, Page +from ix.segmentation import PageMetadata + +if TYPE_CHECKING: # pragma: no cover + from PIL import Image as PILImage + + +class SuryaOCRClient: + """Surya-backed OCR engine. + + Attributes are created lazily by :meth:`warm_up`. The unit tests inject + mocks directly onto ``_recognition_predictor`` / ``_detection_predictor`` + to avoid the Surya import chain. + """ + + def __init__(self) -> None: + self._recognition_predictor: Any = None + self._detection_predictor: Any = None + + def warm_up(self) -> None: + """Load the detection + recognition predictors. Idempotent. + + Called automatically on the first :meth:`ocr` / :meth:`selfcheck`, + or explicitly from the app lifespan to front-load the cost. + """ + if ( + self._recognition_predictor is not None + and self._detection_predictor is not None + ): + return + + # Deferred imports: only reachable when the optional [ocr] extra is + # installed. Keeping them inside the method so base-install unit + # tests (which patch the predictors) don't need surya on sys.path. + from surya.detection import DetectionPredictor # type: ignore[import-not-found] + from surya.foundation import FoundationPredictor # type: ignore[import-not-found] + from surya.recognition import RecognitionPredictor # type: ignore[import-not-found] + + foundation = FoundationPredictor() + self._recognition_predictor = RecognitionPredictor(foundation) + self._detection_predictor = DetectionPredictor() + + async def ocr( + self, + pages: list[Page], + *, + files: list[tuple[Path, str]] | None = None, + page_metadata: list[Any] | None = None, + ) -> OCRResult: + """Render each input page, run Surya, translate back to contracts.""" + self.warm_up() + + images = self._render_pages(pages, files, page_metadata) + + # Surya is blocking — run it off the event loop. + loop = asyncio.get_running_loop() + surya_results = await loop.run_in_executor( + None, self._run_recognition, images + ) + + out_pages: list[Page] = [] + all_text_fragments: list[str] = [] + for input_page, surya_result in zip(pages, surya_results, strict=True): + lines: list[Line] = [] + for tl in getattr(surya_result, "text_lines", []) or []: + flat = self._flatten_polygon(getattr(tl, "polygon", None)) + text = getattr(tl, "text", None) + lines.append(Line(text=text, bounding_box=flat)) + if text: + all_text_fragments.append(text) + out_pages.append( + Page( + page_no=input_page.page_no, + width=input_page.width, + height=input_page.height, + angle=input_page.angle, + unit=input_page.unit, + lines=lines, + ) + ) + + details = OCRDetails( + text="\n".join(all_text_fragments) if all_text_fragments else None, + pages=out_pages, + ) + return OCRResult(result=details, meta_data={"engine": "surya"}) + + async def selfcheck(self) -> Literal["ok", "fail"]: + """Run the predictors on a 1x1 image to confirm the stack works.""" + try: + self.warm_up() + except Exception: + return "fail" + + try: + from PIL import Image as PILImageRuntime + + img = PILImageRuntime.new("RGB", (1, 1), color="white") + loop = asyncio.get_running_loop() + await loop.run_in_executor(None, self._run_recognition, [img]) + except Exception: + return "fail" + return "ok" + + def _run_recognition(self, images: list[PILImage.Image]) -> list[Any]: + """Invoke the recognition predictor. Kept tiny for threadpool offload.""" + return list( + self._recognition_predictor( + images, det_predictor=self._detection_predictor + ) + ) + + def _render_pages( + self, + pages: list[Page], + files: list[tuple[Path, str]] | None, + page_metadata: list[Any] | None, + ) -> list[PILImage.Image]: + """Render each input :class:`Page` to a PIL image. + + We walk pages + page_metadata in lockstep so we know which source + file each page came from and (for PDFs) what page-index to render. + Text-only pages (``file_index is None``) get a blank 1x1 placeholder + so Surya returns an empty result and downstream code still gets one + entry per input page. + """ + from PIL import Image as PILImageRuntime + + metas: list[PageMetadata] = list(page_metadata or []) + file_records: list[tuple[Path, str]] = list(files or []) + + # Per-file lazy PDF openers so we don't re-open across pages. + pdf_docs: dict[int, Any] = {} + + # Per-file running page-within-file counter. For PDFs we emit one + # entry per PDF page in order; ``pages`` was built the same way by + # DocumentIngestor, so a parallel counter reconstructs the mapping. + per_file_cursor: dict[int, int] = {} + + rendered: list[PILImage.Image] = [] + try: + for idx, _page in enumerate(pages): + meta = metas[idx] if idx < len(metas) else PageMetadata() + file_index = meta.file_index + if file_index is None or file_index >= len(file_records): + # Text-only page — placeholder image; Surya returns empty. + rendered.append( + PILImageRuntime.new("RGB", (1, 1), color="white") + ) + continue + + local_path, mime = file_records[file_index] + if mime == "application/pdf": + doc = pdf_docs.get(file_index) + if doc is None: + import fitz # PyMuPDF + + doc = fitz.open(str(local_path)) + pdf_docs[file_index] = doc + pdf_page_no = per_file_cursor.get(file_index, 0) + per_file_cursor[file_index] = pdf_page_no + 1 + pdf_page = doc.load_page(pdf_page_no) + pix = pdf_page.get_pixmap(dpi=200) + img = PILImageRuntime.frombytes( + "RGB", (pix.width, pix.height), pix.samples + ) + rendered.append(img) + elif mime in ("image/png", "image/jpeg", "image/tiff"): + frame_no = per_file_cursor.get(file_index, 0) + per_file_cursor[file_index] = frame_no + 1 + img = PILImageRuntime.open(local_path) + # Handle multi-frame (TIFF) — seek to the right frame. + with contextlib.suppress(EOFError): + img.seek(frame_no) + rendered.append(img.convert("RGB")) + else: # pragma: no cover - ingestor already rejected + rendered.append( + PILImageRuntime.new("RGB", (1, 1), color="white") + ) + finally: + for doc in pdf_docs.values(): + with contextlib.suppress(Exception): + doc.close() + return rendered + + @staticmethod + def _flatten_polygon(polygon: Any) -> list[float]: + """Flatten ``[[x1,y1],[x2,y2],[x3,y3],[x4,y4]]`` → 8-float list. + + Surya emits 4 quad corners. The spec wants 8 raw-pixel coords so + downstream provenance normalisation can consume them directly. + """ + if not polygon: + return [] + flat: list[float] = [] + for point in polygon: + if isinstance(point, (list, tuple)) and len(point) >= 2: + flat.append(float(point[0])) + flat.append(float(point[1])) + return flat + + +__all__ = ["SuryaOCRClient"] diff --git a/src/ix/pipeline/ocr_step.py b/src/ix/pipeline/ocr_step.py index 53010e0..c3dba34 100644 --- a/src/ix/pipeline/ocr_step.py +++ b/src/ix/pipeline/ocr_step.py @@ -56,7 +56,11 @@ class OCRStep(Step): assert ctx is not None, "SetupStep must populate response_ix.context" pages = list(getattr(ctx, "pages", [])) - ocr_result = await self._client.ocr(pages) + files = list(getattr(ctx, "files", []) or []) + page_metadata = list(getattr(ctx, "page_metadata", []) or []) + ocr_result = await self._client.ocr( + pages, files=files, page_metadata=page_metadata + ) # Inject page tags around each OCR page's content so the LLM can # cross-reference the visual anchor without a separate prompt hack. diff --git a/tests/live/test_surya_client_live.py b/tests/live/test_surya_client_live.py new file mode 100644 index 0000000..999e49a --- /dev/null +++ b/tests/live/test_surya_client_live.py @@ -0,0 +1,83 @@ +"""Live test for :class:`SuryaOCRClient` — gated on ``IX_TEST_OLLAMA=1``. + +Downloads real Surya models (hundreds of MB) on first run. Never runs in +CI. Exercised locally with:: + + IX_TEST_OLLAMA=1 uv run pytest tests/live/test_surya_client_live.py -v + +Note: requires the ``[ocr]`` extra — ``uv sync --extra ocr --extra dev``. +""" + +from __future__ import annotations + +import os +from pathlib import Path + +import pytest + +from ix.contracts import Page +from ix.segmentation import PageMetadata + +pytestmark = [ + pytest.mark.live, + pytest.mark.skipif( + os.environ.get("IX_TEST_OLLAMA") != "1", + reason="live: IX_TEST_OLLAMA=1 required", + ), +] + + +async def test_extracts_dkb_and_iban_from_synthetic_giro() -> None: + """Real Surya run against ``tests/fixtures/synthetic_giro.pdf``. + + Assert the flat text contains ``"DKB"`` and the IBAN without spaces. + """ + from ix.ocr.surya_client import SuryaOCRClient + + fixture = Path(__file__).parent.parent / "fixtures" / "synthetic_giro.pdf" + assert fixture.exists(), f"missing fixture: {fixture}" + + # Build Pages the way DocumentIngestor would for this PDF: count pages + # via PyMuPDF so we pass the right number of inputs. + import fitz + + doc = fitz.open(str(fixture)) + try: + pages = [ + Page( + page_no=i + 1, + width=float(p.rect.width), + height=float(p.rect.height), + lines=[], + ) + for i, p in enumerate(doc) + ] + finally: + doc.close() + + client = SuryaOCRClient() + result = await client.ocr( + pages, + files=[(fixture, "application/pdf")], + page_metadata=[PageMetadata(file_index=0) for _ in pages], + ) + + flat_text = result.result.text or "" + # Join page-level line texts if flat text missing (shape-safety). + if not flat_text: + flat_text = "\n".join( + line.text or "" + for page in result.result.pages + for line in page.lines + ) + + assert "DKB" in flat_text + assert "DE89370400440532013000" in flat_text.replace(" ", "") + + +async def test_selfcheck_ok_against_real_predictors() -> None: + """``selfcheck()`` returns ``ok`` once Surya's predictors load.""" + from ix.ocr.surya_client import SuryaOCRClient + + client = SuryaOCRClient() + assert await client.selfcheck() == "ok" diff --git a/tests/unit/test_surya_client.py b/tests/unit/test_surya_client.py new file mode 100644 index 0000000..a713813 --- /dev/null +++ b/tests/unit/test_surya_client.py @@ -0,0 +1,166 @@ +"""Tests for :class:`SuryaOCRClient` — hermetic, no model download. + +The real Surya predictors are patched out with :class:`unittest.mock.MagicMock` +that return trivially-shaped line objects. The tests assert the client's +translation layer — flattening polygons, mapping text_lines → ``Line``, +preserving ``page_no``/``width``/``height`` per input page. +""" + +from __future__ import annotations + +from pathlib import Path +from types import SimpleNamespace +from unittest.mock import MagicMock, patch + +import pytest + +from ix.contracts import Page +from ix.ocr.surya_client import SuryaOCRClient +from ix.segmentation import PageMetadata + + +def _make_surya_line(text: str, polygon: list[list[float]]) -> SimpleNamespace: + """Mimic ``surya.recognition.schema.TextLine`` duck-typing-style.""" + return SimpleNamespace(text=text, polygon=polygon, confidence=0.95) + + +def _make_surya_ocr_result(lines: list[SimpleNamespace]) -> SimpleNamespace: + """Mimic ``surya.recognition.schema.OCRResult``.""" + return SimpleNamespace(text_lines=lines, image_bbox=[0, 0, 100, 100]) + + +class TestOCRBuildsOCRResultFromMockedPredictors: + async def test_one_image_one_line_flatten_polygon(self, tmp_path: Path) -> None: + img_path = tmp_path / "a.png" + _write_tiny_png(img_path) + + mock_line = _make_surya_line( + text="hello", + polygon=[[1.0, 2.0], [3.0, 4.0], [5.0, 6.0], [7.0, 8.0]], + ) + mock_predictor = MagicMock( + return_value=[_make_surya_ocr_result([mock_line])] + ) + + client = SuryaOCRClient() + # Skip the real warm_up; inject the mock directly. + client._recognition_predictor = mock_predictor + client._detection_predictor = MagicMock() + + pages = [Page(page_no=1, width=100.0, height=50.0, lines=[])] + result = await client.ocr( + pages, + files=[(img_path, "image/png")], + page_metadata=[PageMetadata(file_index=0)], + ) + + assert len(result.result.pages) == 1 + out_page = result.result.pages[0] + assert out_page.page_no == 1 + assert out_page.width == 100.0 + assert out_page.height == 50.0 + assert len(out_page.lines) == 1 + assert out_page.lines[0].text == "hello" + assert out_page.lines[0].bounding_box == [ + 1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0 + ] + + async def test_multiple_pages_preserves_order(self, tmp_path: Path) -> None: + img_a = tmp_path / "a.png" + img_b = tmp_path / "b.png" + _write_tiny_png(img_a) + _write_tiny_png(img_b) + + mock_predictor = MagicMock( + return_value=[ + _make_surya_ocr_result( + [_make_surya_line("a-line", [[0, 0], [1, 0], [1, 1], [0, 1]])] + ), + _make_surya_ocr_result( + [_make_surya_line("b-line", [[0, 0], [1, 0], [1, 1], [0, 1]])] + ), + ] + ) + + client = SuryaOCRClient() + client._recognition_predictor = mock_predictor + client._detection_predictor = MagicMock() + + pages = [ + Page(page_no=1, width=10.0, height=20.0, lines=[]), + Page(page_no=2, width=10.0, height=20.0, lines=[]), + ] + result = await client.ocr( + pages, + files=[(img_a, "image/png"), (img_b, "image/png")], + page_metadata=[ + PageMetadata(file_index=0), + PageMetadata(file_index=1), + ], + ) + + assert [p.lines[0].text for p in result.result.pages] == ["a-line", "b-line"] + + async def test_lazy_warm_up_on_first_ocr(self, tmp_path: Path) -> None: + img = tmp_path / "x.png" + _write_tiny_png(img) + + client = SuryaOCRClient() + + # Use patch.object on the instance's warm_up so we don't need real + # Surya module loading. + with patch.object(client, "warm_up", autospec=True) as mocked_warm_up: + # After warm_up is called, the predictors must be assigned. + def fake_warm_up(self: SuryaOCRClient) -> None: + self._recognition_predictor = MagicMock( + return_value=[ + _make_surya_ocr_result( + [ + _make_surya_line( + "hi", [[0, 0], [1, 0], [1, 1], [0, 1]] + ) + ] + ) + ] + ) + self._detection_predictor = MagicMock() + + mocked_warm_up.side_effect = lambda: fake_warm_up(client) + + pages = [Page(page_no=1, width=10.0, height=10.0, lines=[])] + await client.ocr( + pages, + files=[(img, "image/png")], + page_metadata=[PageMetadata(file_index=0)], + ) + mocked_warm_up.assert_called_once() + + +class TestSelfcheck: + async def test_selfcheck_ok_with_mocked_predictors(self) -> None: + client = SuryaOCRClient() + client._recognition_predictor = MagicMock( + return_value=[_make_surya_ocr_result([])] + ) + client._detection_predictor = MagicMock() + assert await client.selfcheck() == "ok" + + async def test_selfcheck_fail_when_predictor_raises(self) -> None: + client = SuryaOCRClient() + client._recognition_predictor = MagicMock( + side_effect=RuntimeError("cuda broken") + ) + client._detection_predictor = MagicMock() + assert await client.selfcheck() == "fail" + + +def _write_tiny_png(path: Path) -> None: + """Write a 2x2 white PNG so PIL can open it.""" + from PIL import Image + + Image.new("RGB", (2, 2), color="white").save(path, format="PNG") + + +@pytest.mark.parametrize("unused", [None]) # keep pytest happy if file ever runs alone +def test_module_imports(unused: None) -> None: + assert SuryaOCRClient is not None