Runs Surya's detection + recognition over PIL images rendered from each Page's source file (PDFs via PyMuPDF, images via Pillow). Lazy warm_up so FastAPI lifespan start stays predictable. Deferred Surya/torch imports keep the base install slim — the heavy deps stay under [ocr]. Extends OCRClient Protocol with optional files + page_metadata kwargs so the engine can resolve each page back to its on-disk source; Fake accepts-and-ignores to keep hermetic tests unchanged. selfcheck() runs the predictors on a 1x1 PIL image — wired into /healthz by Task 4.3. Tests: 6 hermetic unit tests (Surya predictors mocked, no model download); 2 live tests gated on IX_TEST_OLLAMA=1 (never run in CI). Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
235 lines
9.2 KiB
Python
"""SuryaOCRClient — real :class:`OCRClient` backed by ``surya-ocr``.
|
|
|
|
Per spec §6.2: the MVP OCR engine. Runs Surya's detection + recognition
|
|
predictors over per-page PIL images rendered from the downloaded sources
|
|
(PDFs via PyMuPDF, images via Pillow).
|
|
|
|
Design choices:
|
|
|
|
* **Lazy model loading.** ``__init__`` is cheap; the heavy predictors are
|
|
built on first :meth:`ocr` / :meth:`selfcheck` / explicit :meth:`warm_up`.
|
|
This keeps FastAPI's lifespan predictable — ops can decide whether to
|
|
pay the load cost up front or on first request.
|
|
* **Device is Surya's default.** CUDA on the prod box, MPS on M-series Macs.
|
|
We deliberately don't pin.
|
|
* **No text-token reuse from PyMuPDF.** The cross-check against Paperless'
|
|
Tesseract output (ReliabilityStep's ``text_agreement``) is only meaningful
|
|
with a truly independent OCR pass, so we always render-and-recognize
|
|
even for PDFs that carry embedded text.
|
|
|
|
The ``surya-ocr`` package pulls torch + heavy model deps, so it's kept
|
|
behind the ``[ocr]`` extra. All Surya imports are deferred into
|
|
:meth:`warm_up` so running the unit tests (which patch the predictors)
|
|
doesn't require the package to be installed.
|
|
"""
|
|
|
|
from __future__ import annotations
|
|
|
|
import asyncio
|
|
import contextlib
|
|
from pathlib import Path
|
|
from typing import TYPE_CHECKING, Any, Literal
|
|
|
|
from ix.contracts import Line, OCRDetails, OCRResult, Page
|
|
from ix.segmentation import PageMetadata
|
|
|
|
if TYPE_CHECKING: # pragma: no cover
|
|
from PIL import Image as PILImage
|
|
|
|
|
|
class SuryaOCRClient:
    """Surya-backed OCR engine.

    Attributes are created lazily by :meth:`warm_up`. The unit tests inject
    mocks directly onto ``_recognition_predictor`` / ``_detection_predictor``
    to avoid the Surya import chain.
    """

    # Rasterisation resolution for PDF pages. Named here (instead of an
    # inline magic number) so there is one obvious place to tune it.
    _PDF_RENDER_DPI = 200

    def __init__(self) -> None:
        # Cheap by design: no model loading here. Both predictors stay
        # None until warm_up() runs (or a test injects mocks).
        self._recognition_predictor: Any = None
        self._detection_predictor: Any = None

    def warm_up(self) -> None:
        """Load the detection + recognition predictors. Idempotent.

        Called automatically on the first :meth:`ocr` / :meth:`selfcheck`,
        or explicitly from the app lifespan to front-load the cost.
        """
        if (
            self._recognition_predictor is not None
            and self._detection_predictor is not None
        ):
            return

        # Deferred imports: only reachable when the optional [ocr] extra is
        # installed. Keeping them inside the method so base-install unit
        # tests (which patch the predictors) don't need surya on sys.path.
        from surya.detection import DetectionPredictor  # type: ignore[import-not-found]
        from surya.foundation import FoundationPredictor  # type: ignore[import-not-found]
        from surya.recognition import RecognitionPredictor  # type: ignore[import-not-found]

        foundation = FoundationPredictor()
        self._recognition_predictor = RecognitionPredictor(foundation)
        self._detection_predictor = DetectionPredictor()

    async def ocr(
        self,
        pages: list[Page],
        *,
        files: list[tuple[Path, str]] | None = None,
        page_metadata: list[Any] | None = None,
    ) -> OCRResult:
        """Render each input page, run Surya, translate back to contracts.

        Args:
            pages: Contract pages from ingestion; one output page is emitted
                per input page, in the same order.
            files: ``(local_path, mime_type)`` records for the downloaded
                source files, indexed by each page's metadata ``file_index``.
            page_metadata: Per-page records parallel to ``pages`` mapping each
                page back to its source file. Missing or short metadata
                degrades to blank placeholder pages rather than failing.

        Returns:
            :class:`OCRResult` whose pages carry Surya's recognized lines and
            whose ``meta_data`` identifies the engine as ``"surya"``.
        """
        self.warm_up()

        images = self._render_pages(pages, files, page_metadata)

        # Surya is blocking — run it off the event loop.
        loop = asyncio.get_running_loop()
        surya_results = await loop.run_in_executor(
            None, self._run_recognition, images
        )

        out_pages: list[Page] = []
        all_text_fragments: list[str] = []
        # strict=True: Surya must return exactly one result per rendered
        # image; a length mismatch is a programming error, not bad input.
        for input_page, surya_result in zip(pages, surya_results, strict=True):
            lines: list[Line] = []
            for tl in getattr(surya_result, "text_lines", []) or []:
                flat = self._flatten_polygon(getattr(tl, "polygon", None))
                text = getattr(tl, "text", None)
                lines.append(Line(text=text, bounding_box=flat))
                if text:
                    all_text_fragments.append(text)
            out_pages.append(
                Page(
                    page_no=input_page.page_no,
                    width=input_page.width,
                    height=input_page.height,
                    angle=input_page.angle,
                    unit=input_page.unit,
                    lines=lines,
                )
            )

        details = OCRDetails(
            text="\n".join(all_text_fragments) if all_text_fragments else None,
            pages=out_pages,
        )
        return OCRResult(result=details, meta_data={"engine": "surya"})

    async def selfcheck(self) -> Literal["ok", "fail"]:
        """Run the predictors on a 1x1 image to confirm the stack works.

        The broad ``except`` is deliberate: this feeds a health endpoint, so
        every failure mode (missing [ocr] extra, model-load error, inference
        crash) must collapse to ``"fail"`` rather than propagate.
        """
        try:
            self.warm_up()
        except Exception:
            return "fail"

        try:
            from PIL import Image as PILImageRuntime

            img = PILImageRuntime.new("RGB", (1, 1), color="white")
            loop = asyncio.get_running_loop()
            await loop.run_in_executor(None, self._run_recognition, [img])
        except Exception:
            return "fail"
        return "ok"

    def _run_recognition(self, images: list[PILImage.Image]) -> list[Any]:
        """Invoke the recognition predictor. Kept tiny for threadpool offload."""
        return list(
            self._recognition_predictor(
                images, det_predictor=self._detection_predictor
            )
        )

    def _render_pages(
        self,
        pages: list[Page],
        files: list[tuple[Path, str]] | None,
        page_metadata: list[Any] | None,
    ) -> list[PILImage.Image]:
        """Render each input :class:`Page` to a PIL image.

        We walk pages + page_metadata in lockstep so we know which source
        file each page came from and (for PDFs) what page-index to render.
        Text-only pages (``file_index is None``) get a blank 1x1 placeholder
        so Surya returns an empty result and downstream code still gets one
        entry per input page.
        """
        from PIL import Image as PILImageRuntime

        metas: list[PageMetadata] = list(page_metadata or [])
        file_records: list[tuple[Path, str]] = list(files or [])

        # Per-file lazy PDF openers so we don't re-open across pages.
        pdf_docs: dict[int, Any] = {}

        # Per-file running page-within-file counter. For PDFs we emit one
        # entry per PDF page in order; ``pages`` was built the same way by
        # DocumentIngestor, so a parallel counter reconstructs the mapping.
        per_file_cursor: dict[int, int] = {}

        rendered: list[PILImage.Image] = []
        try:
            for idx, _page in enumerate(pages):
                meta = metas[idx] if idx < len(metas) else PageMetadata()
                file_index = meta.file_index
                if file_index is None or file_index >= len(file_records):
                    # Text-only page — placeholder image; Surya returns empty.
                    rendered.append(
                        PILImageRuntime.new("RGB", (1, 1), color="white")
                    )
                    continue

                local_path, mime = file_records[file_index]
                if mime == "application/pdf":
                    doc = pdf_docs.get(file_index)
                    if doc is None:
                        import fitz  # PyMuPDF

                        doc = fitz.open(str(local_path))
                        pdf_docs[file_index] = doc
                    pdf_page_no = per_file_cursor.get(file_index, 0)
                    per_file_cursor[file_index] = pdf_page_no + 1
                    pdf_page = doc.load_page(pdf_page_no)
                    pix = pdf_page.get_pixmap(dpi=self._PDF_RENDER_DPI)
                    img = PILImageRuntime.frombytes(
                        "RGB", (pix.width, pix.height), pix.samples
                    )
                    rendered.append(img)
                elif mime in ("image/png", "image/jpeg", "image/tiff"):
                    frame_no = per_file_cursor.get(file_index, 0)
                    per_file_cursor[file_index] = frame_no + 1
                    # BUGFIX: open inside a context manager so the underlying
                    # file handle is released promptly instead of leaking
                    # until GC. convert("RGB") returns an independent
                    # in-memory copy, so closing the source here is safe.
                    with PILImageRuntime.open(local_path) as img:
                        # Handle multi-frame (TIFF) — seek to the right frame.
                        with contextlib.suppress(EOFError):
                            img.seek(frame_no)
                        rendered.append(img.convert("RGB"))
                else:  # pragma: no cover - ingestor already rejected
                    rendered.append(
                        PILImageRuntime.new("RGB", (1, 1), color="white")
                    )
        finally:
            # Close every PDF we opened even if a render raised mid-loop.
            for doc in pdf_docs.values():
                with contextlib.suppress(Exception):
                    doc.close()
        return rendered

    @staticmethod
    def _flatten_polygon(polygon: Any) -> list[float]:
        """Flatten ``[[x1,y1],[x2,y2],[x3,y3],[x4,y4]]`` → 8-float list.

        Surya emits 4 quad corners. The spec wants 8 raw-pixel coords so
        downstream provenance normalisation can consume them directly.
        Malformed points (non-sequence, fewer than 2 coords) are skipped
        rather than raising; a missing polygon yields ``[]``.
        """
        if not polygon:
            return []
        flat: list[float] = []
        for point in polygon:
            if isinstance(point, (list, tuple)) and len(point) >= 2:
                flat.append(float(point[0]))
                flat.append(float(point[1]))
        return flat
|
|
|
|
|
|
# Explicit public surface: only the engine class is exported.
__all__ = ["SuryaOCRClient"]
|