Compare commits
No commits in common. "b737ed7b21eb4c46067275758f5bd32d1961c819" and "0f045f814af2cbaf9c85cde4691a34c48d9b59fd" have entirely different histories.
b737ed7b21
...
0f045f814a
6 changed files with 6 additions and 521 deletions
|
|
@ -5,19 +5,11 @@ method satisfies the Protocol. :class:`~ix.pipeline.ocr_step.OCRStep`
|
||||||
depends on the Protocol, not a concrete class, so swapping engines
|
depends on the Protocol, not a concrete class, so swapping engines
|
||||||
(``FakeOCRClient`` in tests, ``SuryaOCRClient`` in prod) stays a wiring
|
(``FakeOCRClient`` in tests, ``SuryaOCRClient`` in prod) stays a wiring
|
||||||
change at the app factory.
|
change at the app factory.
|
||||||
|
|
||||||
Per-page source location (``files`` + ``page_metadata``) flows in as
|
|
||||||
optional kwargs: fakes ignore them; the real
|
|
||||||
:class:`~ix.ocr.surya_client.SuryaOCRClient` uses them to render each
|
|
||||||
page's pixels back from disk. Keeping these optional lets unit tests stay
|
|
||||||
pages-only while production wiring (Task 4.3) plumbs through the real
|
|
||||||
filesystem handles.
|
|
||||||
"""
|
"""
|
||||||
|
|
||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
from pathlib import Path
|
from typing import Protocol, runtime_checkable
|
||||||
from typing import Any, Protocol, runtime_checkable
|
|
||||||
|
|
||||||
from ix.contracts import OCRResult, Page
|
from ix.contracts import OCRResult, Page
|
||||||
|
|
||||||
|
|
@ -32,18 +24,8 @@ class OCRClient(Protocol):
|
||||||
per input page (in the same order).
|
per input page (in the same order).
|
||||||
"""
|
"""
|
||||||
|
|
||||||
async def ocr(
|
async def ocr(self, pages: list[Page]) -> OCRResult:
|
||||||
self,
|
"""Run OCR over the input pages; return the structured result."""
|
||||||
pages: list[Page],
|
|
||||||
*,
|
|
||||||
files: list[tuple[Path, str]] | None = None,
|
|
||||||
page_metadata: list[Any] | None = None,
|
|
||||||
) -> OCRResult:
|
|
||||||
"""Run OCR over the input pages; return the structured result.
|
|
||||||
|
|
||||||
``files`` and ``page_metadata`` are optional for hermetic tests;
|
|
||||||
real engines that need to re-render from disk read them.
|
|
||||||
"""
|
|
||||||
...
|
...
|
||||||
|
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -30,17 +30,8 @@ class FakeOCRClient:
|
||||||
self._canned = canned
|
self._canned = canned
|
||||||
self._raise_on_call = raise_on_call
|
self._raise_on_call = raise_on_call
|
||||||
|
|
||||||
async def ocr(
|
async def ocr(self, pages: list[Page]) -> OCRResult:
|
||||||
self,
|
"""Return the canned result or raise the configured error."""
|
||||||
pages: list[Page],
|
|
||||||
**_kwargs: object,
|
|
||||||
) -> OCRResult:
|
|
||||||
"""Return the canned result or raise the configured error.
|
|
||||||
|
|
||||||
Accepts (and ignores) any keyword args the production Protocol may
|
|
||||||
carry — keeps the fake swappable for :class:`SuryaOCRClient` at
|
|
||||||
call sites that pass ``files`` / ``page_metadata``.
|
|
||||||
"""
|
|
||||||
if self._raise_on_call is not None:
|
if self._raise_on_call is not None:
|
||||||
raise self._raise_on_call
|
raise self._raise_on_call
|
||||||
return self._canned
|
return self._canned
|
||||||
|
|
|
||||||
|
|
@ -1,235 +0,0 @@
|
||||||
"""SuryaOCRClient — real :class:`OCRClient` backed by ``surya-ocr``.
|
|
||||||
|
|
||||||
Per spec §6.2: the MVP OCR engine. Runs Surya's detection + recognition
|
|
||||||
predictors over per-page PIL images rendered from the downloaded sources
|
|
||||||
(PDFs via PyMuPDF, images via Pillow).
|
|
||||||
|
|
||||||
Design choices:
|
|
||||||
|
|
||||||
* **Lazy model loading.** ``__init__`` is cheap; the heavy predictors are
|
|
||||||
built on first :meth:`ocr` / :meth:`selfcheck` / explicit :meth:`warm_up`.
|
|
||||||
This keeps FastAPI's lifespan predictable — ops can decide whether to
|
|
||||||
pay the load cost up front or on first request.
|
|
||||||
* **Device is Surya's default.** CUDA on the prod box, MPS on M-series Macs.
|
|
||||||
We deliberately don't pin.
|
|
||||||
* **No text-token reuse from PyMuPDF.** The cross-check against Paperless'
|
|
||||||
Tesseract output (ReliabilityStep's ``text_agreement``) is only meaningful
|
|
||||||
with a truly independent OCR pass, so we always render-and-recognize
|
|
||||||
even for PDFs that carry embedded text.
|
|
||||||
|
|
||||||
The ``surya-ocr`` package pulls torch + heavy model deps, so it's kept
|
|
||||||
behind the ``[ocr]`` extra. All Surya imports are deferred into
|
|
||||||
:meth:`warm_up` so running the unit tests (which patch the predictors)
|
|
||||||
doesn't require the package to be installed.
|
|
||||||
"""
|
|
||||||
|
|
||||||
from __future__ import annotations
|
|
||||||
|
|
||||||
import asyncio
|
|
||||||
import contextlib
|
|
||||||
from pathlib import Path
|
|
||||||
from typing import TYPE_CHECKING, Any, Literal
|
|
||||||
|
|
||||||
from ix.contracts import Line, OCRDetails, OCRResult, Page
|
|
||||||
from ix.segmentation import PageMetadata
|
|
||||||
|
|
||||||
if TYPE_CHECKING: # pragma: no cover
|
|
||||||
from PIL import Image as PILImage
|
|
||||||
|
|
||||||
|
|
||||||
class SuryaOCRClient:
    """Surya-backed OCR engine.

    Both predictors start out as ``None`` and are built lazily by
    :meth:`warm_up`. Unit tests assign mocks directly onto
    ``_recognition_predictor`` / ``_detection_predictor`` so the Surya
    import chain is never touched.
    """

    def __init__(self) -> None:
        """Create the client without loading any model."""
        self._recognition_predictor: Any = None
        self._detection_predictor: Any = None

    def warm_up(self) -> None:
        """Load the detection + recognition predictors. Idempotent.

        Called automatically on the first :meth:`ocr` / :meth:`selfcheck`,
        or explicitly from the app lifespan to front-load the cost. Returns
        immediately when both predictors are already set.

        Raises:
            ImportError: if the optional ``surya`` package is not installed.
        """
        if (
            self._recognition_predictor is not None
            and self._detection_predictor is not None
        ):
            return

        # Deferred imports: only reachable when the optional [ocr] extra is
        # installed. Keeping them inside the method so base-install unit
        # tests (which patch the predictors) don't need surya on sys.path.
        from surya.detection import DetectionPredictor  # type: ignore[import-not-found]
        from surya.foundation import FoundationPredictor  # type: ignore[import-not-found]
        from surya.recognition import RecognitionPredictor  # type: ignore[import-not-found]

        foundation = FoundationPredictor()
        self._recognition_predictor = RecognitionPredictor(foundation)
        self._detection_predictor = DetectionPredictor()

    async def ocr(
        self,
        pages: list[Page],
        *,
        files: list[tuple[Path, str]] | None = None,
        page_metadata: list[Any] | None = None,
    ) -> OCRResult:
        """Render each input page, run Surya, translate back to contracts.

        Args:
            pages: Input pages; one output page is produced per entry, in
                the same order, copying ``page_no`` / ``width`` / ``height``
                / ``angle`` / ``unit`` from the input.
            files: ``(local_path, mime_type)`` records the pages came from.
            page_metadata: Per-page metadata, parallel to ``pages``; its
                ``file_index`` selects the entry of ``files`` to render.

        Returns:
            An :class:`OCRResult` whose flat ``text`` is the newline-joined
            non-empty line texts (``None`` when nothing was recognised).

        Raises:
            ValueError: if Surya returns a different number of results than
                there are input pages (``zip(..., strict=True)``).
        """
        self.warm_up()

        # NOTE(review): rendering runs on the calling (event-loop) thread.
        # Confirm PyMuPDF's threading constraints before moving it into the
        # executor alongside recognition.
        images = self._render_pages(pages, files, page_metadata)

        # Surya is blocking — run it off the event loop.
        loop = asyncio.get_running_loop()
        surya_results = await loop.run_in_executor(
            None, self._run_recognition, images
        )

        out_pages: list[Page] = []
        all_text_fragments: list[str] = []
        for input_page, surya_result in zip(pages, surya_results, strict=True):
            lines: list[Line] = []
            for tl in getattr(surya_result, "text_lines", []) or []:
                flat = self._flatten_polygon(getattr(tl, "polygon", None))
                text = getattr(tl, "text", None)
                lines.append(Line(text=text, bounding_box=flat))
                if text:
                    all_text_fragments.append(text)
            out_pages.append(
                Page(
                    page_no=input_page.page_no,
                    width=input_page.width,
                    height=input_page.height,
                    angle=input_page.angle,
                    unit=input_page.unit,
                    lines=lines,
                )
            )

        details = OCRDetails(
            text="\n".join(all_text_fragments) if all_text_fragments else None,
            pages=out_pages,
        )
        return OCRResult(result=details, meta_data={"engine": "surya"})

    async def selfcheck(self) -> Literal["ok", "fail"]:
        """Run the predictors on a 1x1 image to confirm the stack works.

        Returns:
            ``"ok"`` when loading the predictors and one recognition pass
            both succeed, ``"fail"`` on any exception.
        """
        # Health-check boundary: every failure mode (missing extra, broken
        # device, model error) deliberately collapses to "fail".
        try:
            self.warm_up()
            img = self._placeholder_image()
            loop = asyncio.get_running_loop()
            await loop.run_in_executor(None, self._run_recognition, [img])
        except Exception:
            return "fail"
        return "ok"

    def _run_recognition(self, images: list[PILImage.Image]) -> list[Any]:
        """Invoke the recognition predictor. Kept tiny for threadpool offload."""
        return list(
            self._recognition_predictor(
                images, det_predictor=self._detection_predictor
            )
        )

    @staticmethod
    def _placeholder_image() -> PILImage.Image:
        """Return a blank white 1x1 RGB image used where no pixels exist."""
        from PIL import Image as PILImageRuntime

        return PILImageRuntime.new("RGB", (1, 1), color="white")

    def _render_pages(
        self,
        pages: list[Page],
        files: list[tuple[Path, str]] | None,
        page_metadata: list[Any] | None,
    ) -> list[PILImage.Image]:
        """Render each input :class:`Page` to a PIL image.

        We walk pages + page_metadata in lockstep so we know which source
        file each page came from and (for PDFs) what page-index to render.
        Pages without a usable ``file_index`` (``None`` or out of range for
        ``files``) get a blank 1x1 placeholder so downstream code still gets
        one entry per input page.

        Returns:
            One RGB image per entry of ``pages``, in the same order.
        """
        from PIL import Image as PILImageRuntime

        metas: list[PageMetadata] = list(page_metadata or [])
        file_records: list[tuple[Path, str]] = list(files or [])

        # Per-file lazy PDF openers so we don't re-open across pages.
        pdf_docs: dict[int, Any] = {}

        # Per-file running page-within-file counter. For PDFs we emit one
        # entry per PDF page in order; ``pages`` was built the same way by
        # DocumentIngestor, so a parallel counter reconstructs the mapping.
        per_file_cursor: dict[int, int] = {}

        rendered: list[PILImage.Image] = []
        try:
            for idx, _page in enumerate(pages):
                meta = metas[idx] if idx < len(metas) else PageMetadata()
                file_index = meta.file_index
                # A negative index would otherwise wrap around and silently
                # select a file from the end of ``file_records``.
                if file_index is None or not 0 <= file_index < len(file_records):
                    # Text-only page — placeholder image; Surya returns empty.
                    rendered.append(self._placeholder_image())
                    continue

                local_path, mime = file_records[file_index]
                position = per_file_cursor.get(file_index, 0)

                if mime == "application/pdf":
                    per_file_cursor[file_index] = position + 1
                    doc = pdf_docs.get(file_index)
                    if doc is None:
                        import fitz  # PyMuPDF

                        doc = fitz.open(str(local_path))
                        pdf_docs[file_index] = doc
                    pix = doc.load_page(position).get_pixmap(dpi=200)
                    rendered.append(
                        PILImageRuntime.frombytes(
                            "RGB", (pix.width, pix.height), pix.samples
                        )
                    )
                elif mime in ("image/png", "image/jpeg", "image/tiff"):
                    per_file_cursor[file_index] = position + 1
                    # ``convert`` returns an independent in-memory copy, so
                    # the source file handle can be released right away
                    # instead of lingering until garbage collection.
                    with PILImageRuntime.open(local_path) as source:
                        # Handle multi-frame (TIFF) — seek to the right
                        # frame; a missing frame keeps the current one.
                        with contextlib.suppress(EOFError):
                            source.seek(position)
                        rendered.append(source.convert("RGB"))
                else:  # pragma: no cover - ingestor already rejected
                    rendered.append(self._placeholder_image())
        finally:
            # Best-effort cleanup: a failing close must not mask the
            # original error or discard images already rendered.
            for doc in pdf_docs.values():
                with contextlib.suppress(Exception):
                    doc.close()
        return rendered

    @staticmethod
    def _flatten_polygon(polygon: Any) -> list[float]:
        """Flatten ``[[x1,y1],[x2,y2],[x3,y3],[x4,y4]]`` → 8-float list.

        Surya emits 4 quad corners. The spec wants 8 raw-pixel coords so
        downstream provenance normalisation can consume them directly.
        Falsy input yields ``[]``; points that are not a list/tuple with at
        least two items are skipped.
        """
        if not polygon:
            return []
        flat: list[float] = []
        for point in polygon:
            if isinstance(point, (list, tuple)) and len(point) >= 2:
                flat.append(float(point[0]))
                flat.append(float(point[1]))
        return flat
|
|
||||||
|
|
||||||
|
|
||||||
__all__ = ["SuryaOCRClient"]
|
|
||||||
|
|
@ -56,11 +56,7 @@ class OCRStep(Step):
|
||||||
assert ctx is not None, "SetupStep must populate response_ix.context"
|
assert ctx is not None, "SetupStep must populate response_ix.context"
|
||||||
|
|
||||||
pages = list(getattr(ctx, "pages", []))
|
pages = list(getattr(ctx, "pages", []))
|
||||||
files = list(getattr(ctx, "files", []) or [])
|
ocr_result = await self._client.ocr(pages)
|
||||||
page_metadata = list(getattr(ctx, "page_metadata", []) or [])
|
|
||||||
ocr_result = await self._client.ocr(
|
|
||||||
pages, files=files, page_metadata=page_metadata
|
|
||||||
)
|
|
||||||
|
|
||||||
# Inject page tags around each OCR page's content so the LLM can
|
# Inject page tags around each OCR page's content so the LLM can
|
||||||
# cross-reference the visual anchor without a separate prompt hack.
|
# cross-reference the visual anchor without a separate prompt hack.
|
||||||
|
|
|
||||||
|
|
@ -1,83 +0,0 @@
|
||||||
"""Live test for :class:`SuryaOCRClient` — gated on ``IX_TEST_OLLAMA=1``.
|
|
||||||
|
|
||||||
Downloads real Surya models (hundreds of MB) on first run. Never runs in
|
|
||||||
CI. Exercised locally with::
|
|
||||||
|
|
||||||
IX_TEST_OLLAMA=1 uv run pytest tests/live/test_surya_client_live.py -v
|
|
||||||
|
|
||||||
Note: requires the ``[ocr]`` extra — ``uv sync --extra ocr --extra dev``.
|
|
||||||
"""
|
|
||||||
|
|
||||||
from __future__ import annotations
|
|
||||||
|
|
||||||
import os
|
|
||||||
from pathlib import Path
|
|
||||||
|
|
||||||
import pytest
|
|
||||||
|
|
||||||
from ix.contracts import Page
|
|
||||||
from ix.segmentation import PageMetadata
|
|
||||||
|
|
||||||
# Module-wide markers: tag every test in this file as ``live`` and skip the
# whole module unless the ``IX_TEST_OLLAMA`` environment variable is "1".
pytestmark = [
    pytest.mark.live,
    pytest.mark.skipif(
        os.environ.get("IX_TEST_OLLAMA") != "1",
        reason="live: IX_TEST_OLLAMA=1 required",
    ),
]
|
|
||||||
|
|
||||||
|
|
||||||
async def test_extracts_dkb_and_iban_from_synthetic_giro() -> None:
    """Run real Surya over ``tests/fixtures/synthetic_giro.pdf``.

    The recognised text must contain ``"DKB"`` and, once spaces are
    stripped, the fixture's IBAN.
    """
    import fitz

    from ix.ocr.surya_client import SuryaOCRClient

    fixture = Path(__file__).parents[1] / "fixtures" / "synthetic_giro.pdf"
    assert fixture.exists(), f"missing fixture: {fixture}"

    # Mirror DocumentIngestor: one Page per PDF page, sized from PyMuPDF,
    # so the client receives the right number of inputs.
    pages = []
    doc = fitz.open(str(fixture))
    try:
        for number, pdf_page in enumerate(doc, start=1):
            rect = pdf_page.rect
            pages.append(
                Page(
                    page_no=number,
                    width=float(rect.width),
                    height=float(rect.height),
                    lines=[],
                )
            )
    finally:
        doc.close()

    result = await SuryaOCRClient().ocr(
        pages,
        files=[(fixture, "application/pdf")],
        page_metadata=[PageMetadata(file_index=0) for _ in pages],
    )

    details = result.result
    # Fall back to the per-page line texts if the flat text is missing.
    flat_text = details.text or "\n".join(
        line.text or "" for page in details.pages for line in page.lines
    )

    assert "DKB" in flat_text
    assert "DE89370400440532013000" in flat_text.replace(" ", "")
|
|
||||||
|
|
||||||
|
|
||||||
async def test_selfcheck_ok_against_real_predictors() -> None:
    """A fresh client's ``selfcheck()`` reports ``ok`` with real predictors."""
    from ix.ocr.surya_client import SuryaOCRClient

    status = await SuryaOCRClient().selfcheck()
    assert status == "ok"
|
|
||||||
|
|
@ -1,166 +0,0 @@
|
||||||
"""Tests for :class:`SuryaOCRClient` — hermetic, no model download.
|
|
||||||
|
|
||||||
The real Surya predictors are patched out with :class:`unittest.mock.MagicMock`
|
|
||||||
that return trivially-shaped line objects. The tests assert the client's
|
|
||||||
translation layer — flattening polygons, mapping text_lines → ``Line``,
|
|
||||||
preserving ``page_no``/``width``/``height`` per input page.
|
|
||||||
"""
|
|
||||||
|
|
||||||
from __future__ import annotations
|
|
||||||
|
|
||||||
from pathlib import Path
|
|
||||||
from types import SimpleNamespace
|
|
||||||
from unittest.mock import MagicMock, patch
|
|
||||||
|
|
||||||
import pytest
|
|
||||||
|
|
||||||
from ix.contracts import Page
|
|
||||||
from ix.ocr.surya_client import SuryaOCRClient
|
|
||||||
from ix.segmentation import PageMetadata
|
|
||||||
|
|
||||||
|
|
||||||
def _make_surya_line(text: str, polygon: list[list[float]]) -> SimpleNamespace:
|
|
||||||
"""Mimic ``surya.recognition.schema.TextLine`` duck-typing-style."""
|
|
||||||
return SimpleNamespace(text=text, polygon=polygon, confidence=0.95)
|
|
||||||
|
|
||||||
|
|
||||||
def _make_surya_ocr_result(lines: list[SimpleNamespace]) -> SimpleNamespace:
|
|
||||||
"""Mimic ``surya.recognition.schema.OCRResult``."""
|
|
||||||
return SimpleNamespace(text_lines=lines, image_bbox=[0, 0, 100, 100])
|
|
||||||
|
|
||||||
|
|
||||||
class TestOCRBuildsOCRResultFromMockedPredictors:
    """``ocr()`` translation layer exercised against mocked Surya predictors."""

    async def test_one_image_one_line_flatten_polygon(self, tmp_path: Path) -> None:
        """One PNG with one line: geometry is copied, the quad is flattened."""
        png = tmp_path / "a.png"
        _write_tiny_png(png)

        quad = [[1.0, 2.0], [3.0, 4.0], [5.0, 6.0], [7.0, 8.0]]
        canned = [_make_surya_ocr_result([_make_surya_line(text="hello", polygon=quad)])]

        # Skip the real warm_up; inject the mocks directly.
        client = SuryaOCRClient()
        client._recognition_predictor = MagicMock(return_value=canned)
        client._detection_predictor = MagicMock()

        result = await client.ocr(
            [Page(page_no=1, width=100.0, height=50.0, lines=[])],
            files=[(png, "image/png")],
            page_metadata=[PageMetadata(file_index=0)],
        )

        out_pages = result.result.pages
        assert len(out_pages) == 1
        page = out_pages[0]
        assert page.page_no == 1
        assert page.width == 100.0
        assert page.height == 50.0
        assert len(page.lines) == 1
        only_line = page.lines[0]
        assert only_line.text == "hello"
        assert only_line.bounding_box == [
            1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0
        ]

    async def test_multiple_pages_preserves_order(self, tmp_path: Path) -> None:
        """Two single-page PNG inputs come back in the order they went in."""
        first_png = tmp_path / "a.png"
        second_png = tmp_path / "b.png"
        for png in (first_png, second_png):
            _write_tiny_png(png)

        canned = [
            _make_surya_ocr_result(
                [_make_surya_line(label, [[0, 0], [1, 0], [1, 1], [0, 1]])]
            )
            for label in ("a-line", "b-line")
        ]

        client = SuryaOCRClient()
        client._recognition_predictor = MagicMock(return_value=canned)
        client._detection_predictor = MagicMock()

        result = await client.ocr(
            [
                Page(page_no=1, width=10.0, height=20.0, lines=[]),
                Page(page_no=2, width=10.0, height=20.0, lines=[]),
            ],
            files=[(first_png, "image/png"), (second_png, "image/png")],
            page_metadata=[
                PageMetadata(file_index=0),
                PageMetadata(file_index=1),
            ],
        )

        texts = [page.lines[0].text for page in result.result.pages]
        assert texts == ["a-line", "b-line"]

    async def test_lazy_warm_up_on_first_ocr(self, tmp_path: Path) -> None:
        """The first ``ocr()`` call triggers ``warm_up`` exactly once."""
        png = tmp_path / "x.png"
        _write_tiny_png(png)

        client = SuryaOCRClient()

        def install_fake_predictors() -> None:
            # Stand-in for the real warm_up: once it has run, both
            # predictors must be assigned on the client.
            client._recognition_predictor = MagicMock(
                return_value=[
                    _make_surya_ocr_result(
                        [_make_surya_line("hi", [[0, 0], [1, 0], [1, 1], [0, 1]])]
                    )
                ]
            )
            client._detection_predictor = MagicMock()

        # Patch warm_up on the instance so no real Surya module is loaded.
        with patch.object(client, "warm_up", autospec=True) as mocked_warm_up:
            mocked_warm_up.side_effect = install_fake_predictors

            await client.ocr(
                [Page(page_no=1, width=10.0, height=10.0, lines=[])],
                files=[(png, "image/png")],
                page_metadata=[PageMetadata(file_index=0)],
            )
            mocked_warm_up.assert_called_once()
|
|
||||||
|
|
||||||
|
|
||||||
class TestSelfcheck:
    """``selfcheck()`` outcomes with the predictors mocked out."""

    async def test_selfcheck_ok_with_mocked_predictors(self) -> None:
        """A recognition predictor that returns normally yields ``ok``."""
        recognizer = MagicMock(return_value=[_make_surya_ocr_result([])])

        client = SuryaOCRClient()
        client._recognition_predictor = recognizer
        client._detection_predictor = MagicMock()

        status = await client.selfcheck()
        assert status == "ok"

    async def test_selfcheck_fail_when_predictor_raises(self) -> None:
        """A recognition predictor that raises yields ``fail``."""
        recognizer = MagicMock(side_effect=RuntimeError("cuda broken"))

        client = SuryaOCRClient()
        client._recognition_predictor = recognizer
        client._detection_predictor = MagicMock()

        status = await client.selfcheck()
        assert status == "fail"
|
|
||||||
|
|
||||||
|
|
||||||
def _write_tiny_png(path: Path) -> None:
    """Save a 2x2 all-white RGB PNG at ``path`` so PIL has a file to open."""
    from PIL import Image

    tiny = Image.new(mode="RGB", size=(2, 2), color="white")
    tiny.save(path, format="PNG")
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.parametrize("unused", [None])  # keep pytest happy if file ever runs alone
def test_module_imports(unused: None) -> None:
    """Smoke test: the module imported and ``SuryaOCRClient`` is bound.

    NOTE(review): the single-value parametrization only feeds the unused
    argument; its purpose beyond the inline comment is not visible here.
    """
    assert SuryaOCRClient is not None
|
|
||||||
Loading…
Reference in a new issue