feat(segmentation): SegmentIndex + prompt-text formatter (spec §9.1) #6

Merged
goldstein merged 1 commit from feat/segment-index into main 2026-04-18 08:54:03 +00:00
3 changed files with 337 additions and 0 deletions

View file

@ -0,0 +1,7 @@
"""Segment-index module: maps short IDs (``p1_l0``) to on-page anchors."""
from __future__ import annotations
from ix.segmentation.segment_index import PageMetadata, SegmentIndex
__all__ = ["PageMetadata", "SegmentIndex"]

View file

@ -0,0 +1,140 @@
"""SegmentIndex — maps short segment IDs (`p1_l0`) to their on-page anchors.
Per spec §9.1. Built from an :class:`OCRResult` + a parallel list of
:class:`PageMetadata` entries carrying the 0-based ``file_index`` for each
flat-list page. Used by two places:
1. :class:`GenAIStep` calls :meth:`SegmentIndex.to_prompt_text` to feed the
LLM a segment-tagged user message.
2. :func:`ix.provenance.mapper.map_segment_refs_to_provenance` calls
:meth:`SegmentIndex.lookup_segment` to resolve each LLM-cited segment ID
back to its page + bbox + text snippet.
Page-tag lines (``<page >`` / ``</page>``) emitted by the OCR step for
visual grounding are explicitly *excluded* from IDs so the LLM can never
cite them as provenance.
"""
from __future__ import annotations
import re
from dataclasses import dataclass, field
from typing import Any
from ix.contracts.provenance import BoundingBox
from ix.contracts.response import OCRResult
_PAGE_TAG_RE = re.compile(r"^\s*<\s*/?\s*page\b", re.IGNORECASE)
@dataclass(slots=True)
class PageMetadata:
"""Per-page metadata travelling alongside an :class:`OCRResult`.
The OCR engine doesn't know which input file a page came from — that
mapping lives in the pipeline's internal context. Passing it in
explicitly keeps the segmentation module decoupled from the ingestion
module.
"""
file_index: int | None = None
def _is_page_tag(text: str | None) -> bool:
if not text:
return False
return bool(_PAGE_TAG_RE.match(text))
def _normalize_bbox(coords: list[float], width: float, height: float) -> BoundingBox:
"""8-coord polygon → 0-1 normalised. Zero dimensions → pass through."""
if width <= 0 or height <= 0 or len(coords) != 8:
return BoundingBox(coordinates=list(coords))
normalised = [
coords[0] / width,
coords[1] / height,
coords[2] / width,
coords[3] / height,
coords[4] / width,
coords[5] / height,
coords[6] / width,
coords[7] / height,
]
return BoundingBox(coordinates=normalised)
@dataclass(slots=True)
class SegmentIndex:
"""Read-mostly segment lookup built once per request.
``_id_to_position`` is a flat dict for O(1) lookup; ``_ordered_ids``
preserves insertion order so ``to_prompt_text`` produces a stable
deterministic rendering (matters for prompt caching).
"""
granularity: str = "line"
_id_to_position: dict[str, dict[str, Any]] = field(default_factory=dict)
_ordered_ids: list[str] = field(default_factory=list)
@classmethod
def build(
cls,
ocr_result: OCRResult,
*,
granularity: str = "line",
pages_metadata: list[PageMetadata],
) -> SegmentIndex:
"""Construct a new index from an OCR result + per-page metadata.
``pages_metadata`` must have the same length as
``ocr_result.result.pages`` and in the same order. The pipeline
builds both in :class:`SetupStep` / :class:`OCRStep` before calling
here.
"""
idx = cls(granularity=granularity)
pages = ocr_result.result.pages
for global_pos, page in enumerate(pages, start=1):
meta = pages_metadata[global_pos - 1] if global_pos - 1 < len(pages_metadata) else PageMetadata()
line_idx_in_page = 0
for line in page.lines:
if _is_page_tag(line.text):
continue
seg_id = f"p{global_pos}_l{line_idx_in_page}"
bbox = _normalize_bbox(
list(line.bounding_box),
float(page.width),
float(page.height),
)
idx._id_to_position[seg_id] = {
"page": global_pos,
"bbox": bbox,
"text": line.text or "",
"file_index": meta.file_index,
}
idx._ordered_ids.append(seg_id)
line_idx_in_page += 1
return idx
def lookup_segment(self, segment_id: str) -> dict[str, Any] | None:
"""Return the stored position dict, or ``None`` if the ID is unknown."""
return self._id_to_position.get(segment_id)
def to_prompt_text(self, context_texts: list[str] | None = None) -> str:
"""Emit the LLM-facing user text.
Format (deterministic):
- One ``[pN_lM] text`` line per indexed segment, in insertion order.
- If ``context_texts`` is non-empty: a blank line, then each entry on
its own line, untagged. This mirrors spec §7.2b pre-OCR text
(Paperless, etc.) is added as untagged context for the model to
cross-reference against the tagged document body.
"""
lines: list[str] = [
f"[{seg_id}] {self._id_to_position[seg_id]['text']}" for seg_id in self._ordered_ids
]
out = "\n".join(lines)
if context_texts:
suffix = "\n\n".join(context_texts)
out = f"{out}\n\n{suffix}" if out else suffix
return out

View file

@ -0,0 +1,190 @@
"""Tests for the SegmentIndex — spec §9.1."""
from __future__ import annotations
from ix.contracts import Line, OCRDetails, OCRResult, Page
from ix.segmentation import PageMetadata, SegmentIndex
def _make_pages_metadata(n: int, file_index: int = 0) -> list[PageMetadata]:
"""Build ``n`` flat-list page entries carrying only file_index."""
return [PageMetadata(file_index=file_index) for _ in range(n)]
def _line(text: str, bbox: list[float]) -> Line:
return Line(text=text, bounding_box=bbox)
def _page(page_no: int, width: float, height: float, lines: list[Line]) -> Page:
return Page(page_no=page_no, width=width, height=height, lines=lines)
class TestBuild:
def test_ids_per_page(self) -> None:
ocr = OCRResult(
result=OCRDetails(
pages=[
_page(1, 100.0, 200.0, [_line("hello", [0, 0, 10, 0, 10, 20, 0, 20])]),
_page(
2,
100.0,
200.0,
[
_line("foo", [0, 0, 10, 0, 10, 20, 0, 20]),
_line("bar", [0, 30, 10, 30, 10, 50, 0, 50]),
],
),
]
)
)
idx = SegmentIndex.build(
ocr_result=ocr,
granularity="line",
pages_metadata=_make_pages_metadata(2),
)
assert idx._ordered_ids == ["p1_l0", "p2_l0", "p2_l1"]
pos = idx.lookup_segment("p1_l0")
assert pos is not None
assert pos["page"] == 1
assert pos["text"] == "hello"
assert pos["file_index"] == 0
def test_page_tag_lines_excluded(self) -> None:
ocr = OCRResult(
result=OCRDetails(
pages=[
_page(
1,
100.0,
200.0,
[
_line('<page file="0" number="1">', [0, 0, 10, 0, 10, 5, 0, 5]),
_line("first real line", [0, 10, 10, 10, 10, 20, 0, 20]),
_line("</page>", [0, 25, 10, 25, 10, 30, 0, 30]),
],
)
]
)
)
idx = SegmentIndex.build(
ocr_result=ocr,
granularity="line",
pages_metadata=_make_pages_metadata(1),
)
assert idx._ordered_ids == ["p1_l0"]
assert idx.lookup_segment("p1_l0")["text"] == "first real line" # type: ignore[index]
def test_lookup_unknown_returns_none(self) -> None:
idx = SegmentIndex.build(
ocr_result=OCRResult(result=OCRDetails(pages=[])),
granularity="line",
pages_metadata=[],
)
assert idx.lookup_segment("pX_l99") is None
class TestBboxNormalization:
def test_divides_by_page_width_and_height(self) -> None:
# x-coords get /width, y-coords get /height.
ocr = OCRResult(
result=OCRDetails(
pages=[
_page(
1,
width=200.0,
height=400.0,
lines=[_line("x", [50, 100, 150, 100, 150, 300, 50, 300])],
)
]
)
)
idx = SegmentIndex.build(
ocr_result=ocr,
granularity="line",
pages_metadata=_make_pages_metadata(1),
)
pos = idx.lookup_segment("p1_l0")
assert pos is not None
bbox = pos["bbox"]
# Compare with a bit of float slack.
assert bbox.coordinates == [0.25, 0.25, 0.75, 0.25, 0.75, 0.75, 0.25, 0.75]
class TestPromptFormat:
def test_tagged_lines_and_untagged_texts_appended(self) -> None:
ocr = OCRResult(
result=OCRDetails(
pages=[
_page(
1,
100.0,
200.0,
[
_line("line one", [0, 0, 10, 0, 10, 5, 0, 5]),
_line("line two", [0, 10, 10, 10, 10, 15, 0, 15]),
],
),
_page(2, 100.0, 200.0, [_line("line A", [0, 0, 10, 0, 10, 5, 0, 5])]),
]
)
)
idx = SegmentIndex.build(
ocr_result=ocr,
granularity="line",
pages_metadata=_make_pages_metadata(2),
)
text = idx.to_prompt_text(context_texts=["extra paperless text", "another"])
lines = text.split("\n")
# Tagged OCR lines first, in insertion order.
assert lines[0] == "[p1_l0] line one"
assert lines[1] == "[p1_l1] line two"
assert lines[2] == "[p2_l0] line A"
# The extra texts are appended untagged.
assert "extra paperless text" in text
assert "another" in text
# Sanity: the pageless texts should appear AFTER the last tagged line.
p2_idx = text.index("[p2_l0]")
extra_idx = text.index("extra paperless text")
assert extra_idx > p2_idx
def test_prompt_text_without_extra_texts(self) -> None:
ocr = OCRResult(
result=OCRDetails(
pages=[
_page(1, 100.0, 200.0, [_line("only", [0, 0, 10, 0, 10, 5, 0, 5])]),
]
)
)
idx = SegmentIndex.build(
ocr_result=ocr,
granularity="line",
pages_metadata=_make_pages_metadata(1),
)
text = idx.to_prompt_text(context_texts=[])
assert text.strip() == "[p1_l0] only"
class TestFileIndexPassthrough:
def test_file_index_from_metadata(self) -> None:
pages_meta = [
PageMetadata(file_index=0),
PageMetadata(file_index=1),
]
ocr = OCRResult(
result=OCRDetails(
pages=[
_page(1, 100.0, 200.0, [_line("a", [0, 0, 10, 0, 10, 5, 0, 5])]),
_page(1, 100.0, 200.0, [_line("b", [0, 0, 10, 0, 10, 5, 0, 5])]),
]
)
)
idx = SegmentIndex.build(
ocr_result=ocr,
granularity="line",
pages_metadata=pages_meta,
)
assert idx.lookup_segment("p1_l0")["file_index"] == 0 # type: ignore[index]
assert idx.lookup_segment("p2_l0")["file_index"] == 1 # type: ignore[index]