From 81054baa06f8022d0a5e040394b54bc2dc69fd1a Mon Sep 17 00:00:00 2001 From: Dirk Riemann Date: Sat, 18 Apr 2026 11:15:46 +0200 Subject: [PATCH] =?UTF-8?q?feat(pipeline):=20OCRStep=20=E2=80=94=20run=20O?= =?UTF-8?q?CR=20+=20page=20tags=20+=20SegmentIndex=20(spec=20=C2=A76.2)?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Runs after SetupStep. Dispatches the flat page list to the injected OCRClient, writes the raw OCRResult onto response.ocr_result, injects open/close tag lines around each page's content, and builds a SegmentIndex over the non-tag lines when provenance is on. Validate follows the spec triad rule: - include_geometries/include_ocr_text/ocr_only + no files -> IX_000_004 - no files -> False (skip) - files + (use_ocr or triad) -> True 9 unit tests in tests/unit/test_ocr_step.py cover all three validate branches, OCRResult written, page tags injected (format + file_index), SegmentIndex built iff provenance on. Co-Authored-By: Claude Opus 4.7 (1M context) --- src/ix/pipeline/ocr_step.py | 91 +++++++++++++++++ tests/unit/test_ocr_step.py | 199 ++++++++++++++++++++++++++++++++++++ 2 files changed, 290 insertions(+) create mode 100644 src/ix/pipeline/ocr_step.py create mode 100644 tests/unit/test_ocr_step.py diff --git a/src/ix/pipeline/ocr_step.py b/src/ix/pipeline/ocr_step.py new file mode 100644 index 0000000..53010e0 --- /dev/null +++ b/src/ix/pipeline/ocr_step.py @@ -0,0 +1,91 @@ +"""OCRStep — run OCR, inject page tags, build the SegmentIndex (spec §6.2). + +Runs after :class:`~ix.pipeline.setup_step.SetupStep`. Three things +happen: + +1. Dispatch the flat page list to the injected :class:`OCRClient` and + write the raw :class:`~ix.contracts.OCRResult` onto the response. +2. Inject ``<page file="…" number="…">`` / ``</page>`` tag lines + around each page's content so the GenAIStep can ground citations + visually (spec §6.2c). +3. 
from __future__ import annotations

from ix.contracts import Line, RequestIX, ResponseIX
from ix.errors import IXErrorCode, IXException
from ix.ocr.client import OCRClient
from ix.pipeline.step import Step
from ix.segmentation import PageMetadata, SegmentIndex


class OCRStep(Step):
    """Inject-and-index OCR stage.

    ``process`` does three things:

    1. Sends the context's flat page list to the injected OCR client and
       stores the raw result on ``response_ix.ocr_result``.
    2. Wraps each OCR page's lines in an opening and a closing page-tag
       line.
    3. When provenance is on, builds a :class:`SegmentIndex` over the
       *non-tag* lines and stashes it on the internal context.

    Validation follows the triad-or-use_ocr rule from the spec: if any of
    ``include_geometries`` / ``include_ocr_text`` / ``ocr_only`` is set but
    ``context.files`` is empty, raise ``IX_000_004``. If OCR isn't wanted
    (text-only request), return ``False`` to skip the step silently.
    """

    def __init__(self, ocr_client: OCRClient) -> None:
        """Store the OCR client that ``process`` dispatches pages to."""
        self._client = ocr_client

    async def validate(self, request_ix: RequestIX, response_ix: ResponseIX) -> bool:
        """Decide whether the OCR step should run for this request.

        Returns:
            ``False`` when the context holds no files; otherwise ``True``
            iff ``use_ocr`` or any OCR artifact flag is set.

        Raises:
            IXException: with ``IXErrorCode.IX_000_004`` when an OCR
                artifact flag is set but the context holds no files.
        """
        opts = request_ix.options.ocr
        ctx = response_ix.context
        # ``or []`` also covers an attribute that exists but is None, the
        # same way ``process`` treats ``page_metadata``.
        files = list(getattr(ctx, "files", None) or []) if ctx is not None else []

        ocr_artifacts_requested = (
            opts.include_geometries or opts.include_ocr_text or opts.ocr_only
        )
        if ocr_artifacts_requested and not files:
            raise IXException(IXErrorCode.IX_000_004)

        if not files:
            return False

        # OCR runs if use_ocr OR any of the artifact flags is set.
        return bool(opts.use_ocr or ocr_artifacts_requested)

    async def process(
        self, request_ix: RequestIX, response_ix: ResponseIX
    ) -> ResponseIX:
        """Run OCR, inject page tags, and optionally build the SegmentIndex.

        Mutates ``response_ix`` in place (``ocr_result`` and, when
        provenance is on, ``context.segment_index``) and returns it.
        """
        ctx = response_ix.context
        assert ctx is not None, "SetupStep must populate response_ix.context"

        pages = list(getattr(ctx, "pages", None) or [])
        ocr_result = await self._client.ocr(pages)

        # Inject page tags around each OCR page's content so the LLM can
        # cross-reference the visual anchor without a separate prompt hack.
        page_metadata: list[PageMetadata] = list(
            getattr(ctx, "page_metadata", []) or []
        )
        for idx, ocr_page in enumerate(ocr_result.result.pages):
            # Pages without a metadata entry fall back to defaults, and a
            # missing file index is reported as file 0.
            meta = page_metadata[idx] if idx < len(page_metadata) else PageMetadata()
            file_idx = meta.file_index if meta.file_index is not None else 0
            # NOTE(review): the tag literals were empty in this copy of the
            # patch. The format below is reconstructed from
            # tests/unit/test_ocr_step.py (``file="…"`` / ``number="…"``
            # attributes); the tag name and the use of ``page_no`` are
            # assumptions — confirm against spec §6.2c and the tag pattern
            # SegmentIndex.build skips.
            open_tag = Line(
                text=f'<page file="{file_idx}" number="{ocr_page.page_no}">',
                bounding_box=[],
            )
            close_tag = Line(text="</page>", bounding_box=[])
            ocr_page.lines = [open_tag, *ocr_page.lines, close_tag]

        response_ix.ocr_result = ocr_result

        # Build SegmentIndex only when provenance is on. Segment IDs
        # deliberately skip page-tag lines (see SegmentIndex.build).
        if request_ix.options.provenance.include_provenance:
            ctx.segment_index = SegmentIndex.build(
                ocr_result=ocr_result,
                granularity="line",
                pages_metadata=page_metadata,
            )

        return response_ix


__all__ = ["OCRStep"]
or []), + options=Options( + ocr=OCROptions( + use_ocr=use_ocr, + include_geometries=include_geometries, + include_ocr_text=include_ocr_text, + ocr_only=ocr_only, + ), + provenance=ProvenanceOptions(include_provenance=include_provenance), + ), + ) + + +def _response_with_context( + *, + pages: list[Page] | None = None, + files: list | None = None, + texts: list[str] | None = None, + page_metadata: list[PageMetadata] | None = None, +) -> ResponseIX: + resp = ResponseIX() + resp.context = _InternalContext( + pages=pages or [], + files=files or [], + texts=texts or [], + page_metadata=page_metadata or [], + ) + return resp + + +def _canned_ocr(pages: int = 1) -> OCRResult: + return OCRResult( + result=OCRDetails( + text="\n".join(f"text p{i+1}" for i in range(pages)), + pages=[ + Page( + page_no=i + 1, + width=100.0, + height=200.0, + lines=[ + Line( + text=f"line-content p{i+1}", + bounding_box=[0, 0, 10, 0, 10, 5, 0, 5], + ) + ], + ) + for i in range(pages) + ], + ) + ) + + +class TestValidate: + async def test_ocr_only_without_files_raises_IX_000_004(self) -> None: + step = OCRStep(ocr_client=FakeOCRClient(canned=_canned_ocr())) + req = _make_request(ocr_only=True, files=[], texts=["hi"]) + resp = _response_with_context(files=[]) + with pytest.raises(IXException) as ei: + await step.validate(req, resp) + assert ei.value.code is IXErrorCode.IX_000_004 + + async def test_include_ocr_text_without_files_raises_IX_000_004(self) -> None: + step = OCRStep(ocr_client=FakeOCRClient(canned=_canned_ocr())) + req = _make_request(include_ocr_text=True, files=[], texts=["hi"]) + resp = _response_with_context(files=[]) + with pytest.raises(IXException) as ei: + await step.validate(req, resp) + assert ei.value.code is IXErrorCode.IX_000_004 + + async def test_include_geometries_without_files_raises_IX_000_004(self) -> None: + step = OCRStep(ocr_client=FakeOCRClient(canned=_canned_ocr())) + req = _make_request(include_geometries=True, files=[], texts=["hi"]) + resp = 
_response_with_context(files=[]) + with pytest.raises(IXException) as ei: + await step.validate(req, resp) + assert ei.value.code is IXErrorCode.IX_000_004 + + async def test_text_only_skips_step(self) -> None: + step = OCRStep(ocr_client=FakeOCRClient(canned=_canned_ocr())) + req = _make_request(use_ocr=True, files=[], texts=["hi"]) + resp = _response_with_context(files=[], texts=["hi"]) + assert await step.validate(req, resp) is False + + async def test_ocr_runs_when_files_and_use_ocr(self) -> None: + step = OCRStep(ocr_client=FakeOCRClient(canned=_canned_ocr())) + req = _make_request(use_ocr=True, files=["http://x"]) + resp = _response_with_context(files=[("/tmp/x.pdf", "application/pdf")]) + assert await step.validate(req, resp) is True + + +class TestProcess: + async def test_ocr_result_written_to_response(self) -> None: + canned = _canned_ocr(pages=2) + step = OCRStep(ocr_client=FakeOCRClient(canned=canned)) + req = _make_request(use_ocr=True, files=["http://x"]) + resp = _response_with_context( + pages=[Page(page_no=1, width=100.0, height=200.0, lines=[])] * 2, + files=[("/tmp/x.pdf", "application/pdf")], + page_metadata=[PageMetadata(file_index=0), PageMetadata(file_index=0)], + ) + resp = await step.process(req, resp) + # Full OCR result written. + assert resp.ocr_result.result.text == canned.result.text + # Page tags injected: prepend + append around lines per page. + pages = resp.ocr_result.result.pages + assert len(pages) == 2 + # First line is the opening tag. + assert pages[0].lines[0].text is not None + assert pages[0].lines[0].text.startswith(" tag. 
+ assert pages[0].lines[-1].text == "" + + async def test_segment_index_built_when_provenance_on(self) -> None: + canned = _canned_ocr(pages=1) + step = OCRStep(ocr_client=FakeOCRClient(canned=canned)) + req = _make_request( + use_ocr=True, include_provenance=True, files=["http://x"] + ) + resp = _response_with_context( + pages=[Page(page_no=1, width=100.0, height=200.0, lines=[])], + files=[("/tmp/x.pdf", "application/pdf")], + page_metadata=[PageMetadata(file_index=0)], + ) + resp = await step.process(req, resp) + seg_idx = resp.context.segment_index # type: ignore[union-attr] + assert isinstance(seg_idx, SegmentIndex) + # Page-tag lines are excluded; only the real line becomes a segment. + assert seg_idx._ordered_ids == ["p1_l0"] + pos = seg_idx.lookup_segment("p1_l0") + assert pos is not None + assert pos["text"] == "line-content p1" + + async def test_segment_index_not_built_when_provenance_off(self) -> None: + canned = _canned_ocr(pages=1) + step = OCRStep(ocr_client=FakeOCRClient(canned=canned)) + req = _make_request( + use_ocr=True, include_provenance=False, files=["http://x"] + ) + resp = _response_with_context( + pages=[Page(page_no=1, width=100.0, height=200.0, lines=[])], + files=[("/tmp/x.pdf", "application/pdf")], + page_metadata=[PageMetadata(file_index=0)], + ) + resp = await step.process(req, resp) + assert resp.context.segment_index is None # type: ignore[union-attr] + + async def test_page_tags_include_file_index(self) -> None: + canned = _canned_ocr(pages=1) + step = OCRStep(ocr_client=FakeOCRClient(canned=canned)) + req = _make_request(use_ocr=True, files=["http://x"]) + resp = _response_with_context( + pages=[Page(page_no=1, width=100.0, height=200.0, lines=[])], + files=[("/tmp/x.pdf", "application/pdf")], + page_metadata=[PageMetadata(file_index=3)], + ) + resp = await step.process(req, resp) + first_line = resp.ocr_result.result.pages[0].lines[0].text + assert first_line is not None + assert 'file="3"' in first_line + assert 'number="1"' in 
first_line -- 2.45.2