feat(segmentation): SegmentIndex + prompt-text formatter (spec §9.1) (#6)
Some checks are pending
tests / test (push) Waiting to run
Some checks are pending
tests / test (push) Waiting to run
SegmentIndex lands.
This commit is contained in:
commit
b2ff27c1ca
3 changed files with 337 additions and 0 deletions
7
src/ix/segmentation/__init__.py
Normal file
7
src/ix/segmentation/__init__.py
Normal file
|
|
@ -0,0 +1,7 @@
|
|||
"""Segment-index module: maps short IDs (``p1_l0``) to on-page anchors."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from ix.segmentation.segment_index import PageMetadata, SegmentIndex
|
||||
|
||||
# Public names of this package, re-exported from ``ix.segmentation.segment_index``.
__all__ = ["PageMetadata", "SegmentIndex"]
|
||||
140
src/ix/segmentation/segment_index.py
Normal file
140
src/ix/segmentation/segment_index.py
Normal file
|
|
@ -0,0 +1,140 @@
|
|||
"""SegmentIndex — maps short segment IDs (`p1_l0`) to their on-page anchors.
|
||||
|
||||
Per spec §9.1. Built from an :class:`OCRResult` + a parallel list of
|
||||
:class:`PageMetadata` entries carrying the 0-based ``file_index`` for each
|
||||
flat-list page. Used in two places:
|
||||
|
||||
1. :class:`GenAIStep` calls :meth:`SegmentIndex.to_prompt_text` to feed the
|
||||
LLM a segment-tagged user message.
|
||||
2. :func:`ix.provenance.mapper.map_segment_refs_to_provenance` calls
|
||||
:meth:`SegmentIndex.lookup_segment` to resolve each LLM-cited segment ID
|
||||
back to its page + bbox + text snippet.
|
||||
|
||||
Page-tag lines (``<page …>`` / ``</page>``) emitted by the OCR step for
|
||||
visual grounding are explicitly *excluded* from IDs so the LLM can never
|
||||
cite them as provenance.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import re
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Any
|
||||
|
||||
from ix.contracts.provenance import BoundingBox
|
||||
from ix.contracts.response import OCRResult
|
||||
|
||||
# Matches a line that starts with an opening or closing page tag: optional
# leading whitespace, ``<``, an optional ``/``, then the whole word ``page``
# (case-insensitive). Only the start of the line is checked.
_PAGE_TAG_RE = re.compile(r"^\s*<\s*/?\s*page\b", re.IGNORECASE)
|
||||
|
||||
|
||||
@dataclass(slots=True)
|
||||
class PageMetadata:
|
||||
"""Per-page metadata travelling alongside an :class:`OCRResult`.
|
||||
|
||||
The OCR engine doesn't know which input file a page came from — that
|
||||
mapping lives in the pipeline's internal context. Passing it in
|
||||
explicitly keeps the segmentation module decoupled from the ingestion
|
||||
module.
|
||||
"""
|
||||
|
||||
file_index: int | None = None
|
||||
|
||||
|
||||
def _is_page_tag(text: str | None) -> bool:
    """Return ``True`` when ``text`` starts with a ``<page …>`` / ``</page>`` tag.

    ``None`` and the empty string are never page tags.
    """
    return bool(text) and _PAGE_TAG_RE.match(text) is not None
|
||||
|
||||
|
||||
def _normalize_bbox(coords: list[float], width: float, height: float) -> BoundingBox:
    """Scale an 8-value polygon by the page dimensions.

    ``coords`` is treated as interleaved x/y values: entries at even positions
    are divided by ``width``, entries at odd positions by ``height``. When
    either dimension is zero or negative, or ``coords`` does not hold exactly
    eight values, the coordinates are copied through unscaled.
    """
    if len(coords) != 8 or width <= 0 or height <= 0:
        return BoundingBox(coordinates=list(coords))
    scaled = [
        value / (height if position % 2 else width)
        for position, value in enumerate(coords)
    ]
    return BoundingBox(coordinates=scaled)
|
||||
|
||||
|
||||
@dataclass(slots=True)
|
||||
class SegmentIndex:
|
||||
"""Read-mostly segment lookup built once per request.
|
||||
|
||||
``_id_to_position`` is a flat dict for O(1) lookup; ``_ordered_ids``
|
||||
preserves insertion order so ``to_prompt_text`` produces a stable
|
||||
deterministic rendering (matters for prompt caching).
|
||||
"""
|
||||
|
||||
granularity: str = "line"
|
||||
_id_to_position: dict[str, dict[str, Any]] = field(default_factory=dict)
|
||||
_ordered_ids: list[str] = field(default_factory=list)
|
||||
|
||||
@classmethod
|
||||
def build(
|
||||
cls,
|
||||
ocr_result: OCRResult,
|
||||
*,
|
||||
granularity: str = "line",
|
||||
pages_metadata: list[PageMetadata],
|
||||
) -> SegmentIndex:
|
||||
"""Construct a new index from an OCR result + per-page metadata.
|
||||
|
||||
``pages_metadata`` must have the same length as
|
||||
``ocr_result.result.pages`` and in the same order. The pipeline
|
||||
builds both in :class:`SetupStep` / :class:`OCRStep` before calling
|
||||
here.
|
||||
"""
|
||||
idx = cls(granularity=granularity)
|
||||
pages = ocr_result.result.pages
|
||||
for global_pos, page in enumerate(pages, start=1):
|
||||
meta = pages_metadata[global_pos - 1] if global_pos - 1 < len(pages_metadata) else PageMetadata()
|
||||
line_idx_in_page = 0
|
||||
for line in page.lines:
|
||||
if _is_page_tag(line.text):
|
||||
continue
|
||||
seg_id = f"p{global_pos}_l{line_idx_in_page}"
|
||||
bbox = _normalize_bbox(
|
||||
list(line.bounding_box),
|
||||
float(page.width),
|
||||
float(page.height),
|
||||
)
|
||||
idx._id_to_position[seg_id] = {
|
||||
"page": global_pos,
|
||||
"bbox": bbox,
|
||||
"text": line.text or "",
|
||||
"file_index": meta.file_index,
|
||||
}
|
||||
idx._ordered_ids.append(seg_id)
|
||||
line_idx_in_page += 1
|
||||
return idx
|
||||
|
||||
def lookup_segment(self, segment_id: str) -> dict[str, Any] | None:
|
||||
"""Return the stored position dict, or ``None`` if the ID is unknown."""
|
||||
return self._id_to_position.get(segment_id)
|
||||
|
||||
def to_prompt_text(self, context_texts: list[str] | None = None) -> str:
|
||||
"""Emit the LLM-facing user text.
|
||||
|
||||
Format (deterministic):
|
||||
|
||||
- One ``[pN_lM] text`` line per indexed segment, in insertion order.
|
||||
- If ``context_texts`` is non-empty: a blank line, then each entry on
|
||||
its own line, untagged. This mirrors spec §7.2b — pre-OCR text
|
||||
(Paperless, etc.) is added as untagged context for the model to
|
||||
cross-reference against the tagged document body.
|
||||
"""
|
||||
lines: list[str] = [
|
||||
f"[{seg_id}] {self._id_to_position[seg_id]['text']}" for seg_id in self._ordered_ids
|
||||
]
|
||||
out = "\n".join(lines)
|
||||
if context_texts:
|
||||
suffix = "\n\n".join(context_texts)
|
||||
out = f"{out}\n\n{suffix}" if out else suffix
|
||||
return out
|
||||
190
tests/unit/test_segment_index.py
Normal file
190
tests/unit/test_segment_index.py
Normal file
|
|
@ -0,0 +1,190 @@
|
|||
"""Tests for the SegmentIndex — spec §9.1."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from ix.contracts import Line, OCRDetails, OCRResult, Page
|
||||
from ix.segmentation import PageMetadata, SegmentIndex
|
||||
|
||||
|
||||
def _make_pages_metadata(n: int, file_index: int = 0) -> list[PageMetadata]:
    """Return ``n`` separate ``PageMetadata`` objects, all with the given ``file_index``."""
    entries = [PageMetadata(file_index=file_index) for _ in range(n)]
    return entries
|
||||
|
||||
|
||||
def _line(text: str, bbox: list[float]) -> Line:
    """Shorthand for building a ``Line`` from its text and bounding box."""
    return Line(
        text=text,
        bounding_box=bbox,
    )
|
||||
|
||||
|
||||
def _page(page_no: int, width: float, height: float, lines: list[Line]) -> Page:
    """Shorthand for building a ``Page`` with the given number, size and lines."""
    return Page(
        page_no=page_no,
        width=width,
        height=height,
        lines=lines,
    )
|
||||
|
||||
|
||||
class TestBuild:
    """ID assignment, page-tag filtering and unknown-ID lookup."""

    def test_ids_per_page(self) -> None:
        """Pages are numbered from 1; line numbering restarts at 0 on each page."""
        first_page = _page(1, 100.0, 200.0, [_line("hello", [0, 0, 10, 0, 10, 20, 0, 20])])
        second_page = _page(
            2,
            100.0,
            200.0,
            [
                _line("foo", [0, 0, 10, 0, 10, 20, 0, 20]),
                _line("bar", [0, 30, 10, 30, 10, 50, 0, 50]),
            ],
        )
        ocr = OCRResult(result=OCRDetails(pages=[first_page, second_page]))

        idx = SegmentIndex.build(
            ocr_result=ocr,
            granularity="line",
            pages_metadata=_make_pages_metadata(2),
        )

        assert idx._ordered_ids == ["p1_l0", "p2_l0", "p2_l1"]

        pos = idx.lookup_segment("p1_l0")
        assert pos is not None
        assert pos["page"] == 1
        assert pos["text"] == "hello"
        assert pos["file_index"] == 0

    def test_page_tag_lines_excluded(self) -> None:
        """``<page …>`` / ``</page>`` lines get no ID and do not consume a line index."""
        tagged_page = _page(
            1,
            100.0,
            200.0,
            [
                _line('<page file="0" number="1">', [0, 0, 10, 0, 10, 5, 0, 5]),
                _line("first real line", [0, 10, 10, 10, 10, 20, 0, 20]),
                _line("</page>", [0, 25, 10, 25, 10, 30, 0, 30]),
            ],
        )
        ocr = OCRResult(result=OCRDetails(pages=[tagged_page]))

        idx = SegmentIndex.build(
            ocr_result=ocr,
            granularity="line",
            pages_metadata=_make_pages_metadata(1),
        )

        assert idx._ordered_ids == ["p1_l0"]
        assert idx.lookup_segment("p1_l0")["text"] == "first real line"  # type: ignore[index]

    def test_lookup_unknown_returns_none(self) -> None:
        """Looking up an ID that was never indexed yields ``None``."""
        empty_ocr = OCRResult(result=OCRDetails(pages=[]))

        idx = SegmentIndex.build(
            ocr_result=empty_ocr,
            granularity="line",
            pages_metadata=[],
        )

        assert idx.lookup_segment("pX_l99") is None
|
||||
|
||||
|
||||
class TestBboxNormalization:
    """Bounding boxes are stored relative to the page size."""

    def test_divides_by_page_width_and_height(self) -> None:
        """x-coords get /width, y-coords get /height."""
        import math

        ocr = OCRResult(
            result=OCRDetails(
                pages=[
                    _page(
                        1,
                        width=200.0,
                        height=400.0,
                        lines=[_line("x", [50, 100, 150, 100, 150, 300, 50, 300])],
                    )
                ]
            )
        )
        idx = SegmentIndex.build(
            ocr_result=ocr,
            granularity="line",
            pages_metadata=_make_pages_metadata(1),
        )
        pos = idx.lookup_segment("p1_l0")
        assert pos is not None
        bbox = pos["bbox"]

        expected = [0.25, 0.25, 0.75, 0.25, 0.75, 0.75, 0.25, 0.75]
        # Check the length separately: ``zip`` below would silently stop at
        # the shorter sequence.
        assert len(bbox.coordinates) == len(expected)
        # Compare with a bit of float slack rather than exact equality, so the
        # test keeps working if the inputs change to values whose quotients
        # are not exactly representable.
        assert all(
            math.isclose(got, want) for got, want in zip(bbox.coordinates, expected)
        )
|
||||
|
||||
|
||||
class TestPromptFormat:
    """Rendering of the LLM-facing prompt text."""

    def test_tagged_lines_and_untagged_texts_appended(self) -> None:
        """Tagged OCR lines come first; context texts follow without tags."""
        page_one = _page(
            1,
            100.0,
            200.0,
            [
                _line("line one", [0, 0, 10, 0, 10, 5, 0, 5]),
                _line("line two", [0, 10, 10, 10, 10, 15, 0, 15]),
            ],
        )
        page_two = _page(2, 100.0, 200.0, [_line("line A", [0, 0, 10, 0, 10, 5, 0, 5])])
        ocr = OCRResult(result=OCRDetails(pages=[page_one, page_two]))

        idx = SegmentIndex.build(
            ocr_result=ocr,
            granularity="line",
            pages_metadata=_make_pages_metadata(2),
        )
        text = idx.to_prompt_text(context_texts=["extra paperless text", "another"])

        rendered = text.split("\n")
        # The tagged OCR lines lead the output, in insertion order.
        assert rendered[0] == "[p1_l0] line one"
        assert rendered[1] == "[p1_l1] line two"
        assert rendered[2] == "[p2_l0] line A"
        # Both context texts are present, without tags.
        assert "extra paperless text" in text
        assert "another" in text
        # The untagged context must come AFTER the last tagged line.
        last_tag_at = text.index("[p2_l0]")
        context_at = text.index("extra paperless text")
        assert context_at > last_tag_at

    def test_prompt_text_without_extra_texts(self) -> None:
        """An empty context list leaves only the tagged lines."""
        single_page = _page(1, 100.0, 200.0, [_line("only", [0, 0, 10, 0, 10, 5, 0, 5])])
        ocr = OCRResult(result=OCRDetails(pages=[single_page]))

        idx = SegmentIndex.build(
            ocr_result=ocr,
            granularity="line",
            pages_metadata=_make_pages_metadata(1),
        )
        text = idx.to_prompt_text(context_texts=[])

        assert text.strip() == "[p1_l0] only"
|
||||
|
||||
|
||||
class TestFileIndexPassthrough:
    """``file_index`` is copied from the page's metadata entry into each record."""

    def test_file_index_from_metadata(self) -> None:
        """Two pages paired with two metadata entries keep their own ``file_index``."""
        page_a = _page(1, 100.0, 200.0, [_line("a", [0, 0, 10, 0, 10, 5, 0, 5])])
        page_b = _page(1, 100.0, 200.0, [_line("b", [0, 0, 10, 0, 10, 5, 0, 5])])
        ocr = OCRResult(result=OCRDetails(pages=[page_a, page_b]))
        pages_meta = [PageMetadata(file_index=0), PageMetadata(file_index=1)]

        idx = SegmentIndex.build(
            ocr_result=ocr,
            granularity="line",
            pages_metadata=pages_meta,
        )

        assert idx.lookup_segment("p1_l0")["file_index"] == 0  # type: ignore[index]
        assert idx.lookup_segment("p2_l0")["file_index"] == 1  # type: ignore[index]
|
||||
Loading…
Reference in a new issue