Compare commits
2 commits
2d22115893
...
b397a80c0b
| Author | SHA1 | Date | |
|---|---|---|---|
| b397a80c0b | |||
| 1e340c82fa |
5 changed files with 811 additions and 0 deletions
|
|
@ -15,6 +15,10 @@ and verifier land in task 1.8.
|
||||||
|
|
||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from ix.provenance.mapper import (
|
||||||
|
map_segment_refs_to_provenance,
|
||||||
|
resolve_nested_path,
|
||||||
|
)
|
||||||
from ix.provenance.normalize import (
|
from ix.provenance.normalize import (
|
||||||
normalize_date,
|
normalize_date,
|
||||||
normalize_iban,
|
normalize_iban,
|
||||||
|
|
@ -22,11 +26,16 @@ from ix.provenance.normalize import (
|
||||||
normalize_string,
|
normalize_string,
|
||||||
should_skip_text_agreement,
|
should_skip_text_agreement,
|
||||||
)
|
)
|
||||||
|
from ix.provenance.verify import apply_reliability_flags, verify_field
|
||||||
|
|
||||||
__all__ = [
|
__all__ = [
|
||||||
|
"apply_reliability_flags",
|
||||||
|
"map_segment_refs_to_provenance",
|
||||||
"normalize_date",
|
"normalize_date",
|
||||||
"normalize_iban",
|
"normalize_iban",
|
||||||
"normalize_number",
|
"normalize_number",
|
||||||
"normalize_string",
|
"normalize_string",
|
||||||
|
"resolve_nested_path",
|
||||||
"should_skip_text_agreement",
|
"should_skip_text_agreement",
|
||||||
|
"verify_field",
|
||||||
]
|
]
|
||||||
|
|
|
||||||
145
src/ix/provenance/mapper.py
Normal file
145
src/ix/provenance/mapper.py
Normal file
|
|
@ -0,0 +1,145 @@
|
||||||
|
"""Maps LLM-emitted :class:`SegmentCitation` lists to :class:`ProvenanceData`.
|
||||||
|
|
||||||
|
Implements spec §9.4. The algorithm is deliberately small:
|
||||||
|
|
||||||
|
1. For each citation, pick the seg-id list (``value`` vs. ``value_and_context``).
|
||||||
|
2. Cap at ``max_sources_per_field``.
|
||||||
|
3. Resolve each ID via :meth:`SegmentIndex.lookup_segment`; count misses.
|
||||||
|
4. Resolve the field's value by dot-path traversal of the extraction result.
|
||||||
|
5. Build a :class:`FieldProvenance`. Skip fields that resolved to zero sources.
|
||||||
|
|
||||||
|
No verification / normalisation happens here — this module's sole job is
|
||||||
|
structural assembly. :mod:`ix.provenance.verify` does the reliability pass
|
||||||
|
downstream.
|
||||||
|
"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import re
|
||||||
|
from typing import Any, Literal
|
||||||
|
|
||||||
|
from ix.contracts.provenance import (
|
||||||
|
ExtractionSource,
|
||||||
|
FieldProvenance,
|
||||||
|
ProvenanceData,
|
||||||
|
SegmentCitation,
|
||||||
|
)
|
||||||
|
from ix.segmentation import SegmentIndex
|
||||||
|
|
||||||
|
SourceType = Literal["value", "value_and_context"]
|
||||||
|
|
||||||
|
|
||||||
|
_BRACKET_RE = re.compile(r"\[(\d+)\]")
|
||||||
|
|
||||||
|
|
||||||
|
def resolve_nested_path(data: Any, path: str) -> Any:
    """Walk *data* along a dot-path, normalising ``[N]`` brackets to dots.

    ``"result.items[0].name"`` is treated as ``result.items.0.name``: a
    digit part indexes into a list, any other part keys into a dict.  Any
    missing key, out-of-range index, or non-container step yields ``None``
    so callers can fall back to recording the field with a null value.
    """
    current: Any = data
    for token in re.sub(r"\[(\d+)\]", r".\1", path).split("."):
        if current is None:
            return None
        if isinstance(current, list) and token.isdigit():
            index = int(token)
            if not 0 <= index < len(current):
                return None
            current = current[index]
            continue
        if isinstance(current, dict):
            current = current.get(token)
            continue
        return None
    return current
|
||||||
|
|
||||||
|
|
||||||
|
def _segment_ids_for_citation(
|
||||||
|
citation: SegmentCitation,
|
||||||
|
source_type: SourceType,
|
||||||
|
) -> list[str]:
|
||||||
|
if source_type == "value":
|
||||||
|
return list(citation.value_segment_ids)
|
||||||
|
# value_and_context
|
||||||
|
return list(citation.value_segment_ids) + list(citation.context_segment_ids)
|
||||||
|
|
||||||
|
|
||||||
|
def map_segment_refs_to_provenance(
    extraction_result: dict[str, Any],
    segment_citations: list[SegmentCitation],
    segment_index: SegmentIndex,
    max_sources_per_field: int,
    min_confidence: float,  # reserved (no-op for MVP)
    include_bounding_boxes: bool,
    source_type: SourceType,
) -> ProvenanceData:
    """Build a :class:`ProvenanceData` from LLM citations and a SegmentIndex.

    For each citation: select seg-ids per ``source_type``, cap the combined
    list at ``max_sources_per_field``, resolve each id through
    ``segment_index.lookup_segment`` (misses increment the
    ``invalid_references`` metric), resolve the field's value by dot-path,
    and record one :class:`FieldProvenance` per ``field_path``.  Citations
    whose ids all failed to resolve are dropped entirely; a later citation
    with the same ``field_path`` overwrites an earlier one.
    """
    # min_confidence is reserved for future use (see spec §2 provenance options).
    _ = min_confidence

    fields: dict[str, FieldProvenance] = {}
    invalid_references = 0

    for citation in segment_citations:
        # The cap applies to the combined (value [+ context]) id list.
        seg_ids = _segment_ids_for_citation(citation, source_type)[:max_sources_per_field]
        sources: list[ExtractionSource] = []
        for seg_id in seg_ids:
            pos = segment_index.lookup_segment(seg_id)
            if pos is None:
                # Hallucinated / stale segment id — counted, not fatal.
                invalid_references += 1
                continue
            sources.append(
                ExtractionSource(
                    page_number=pos["page"],
                    file_index=pos.get("file_index"),
                    bounding_box=pos["bbox"] if include_bounding_boxes else None,
                    text_snippet=pos["text"],
                    relevance_score=1.0,
                    segment_id=seg_id,
                )
            )
        if not sources:
            # Spec §9.4: skip fields that resolved to zero sources.
            continue

        # May be None when the LLM cited a path absent from the result;
        # the field is still recorded with a null value.
        value = resolve_nested_path(extraction_result, citation.field_path)
        fields[citation.field_path] = FieldProvenance(
            field_name=citation.field_path.split(".")[-1],
            field_path=citation.field_path,
            value=value,
            sources=sources,
            confidence=None,
        )

    # Coverage = fields that got provenance / leaf fields under "result".
    total_fields_in_result = _count_leaf_fields(extraction_result.get("result", {}))
    coverage_rate: float | None = None
    if total_fields_in_result > 0:
        coverage_rate = len(fields) / total_fields_in_result

    return ProvenanceData(
        fields=fields,
        quality_metrics={
            "fields_with_provenance": len(fields),
            # A zero count is reported as None (no denominator available).
            "total_fields": total_fields_in_result or None,
            "coverage_rate": coverage_rate,
            "invalid_references": invalid_references,
        },
        # NOTE(review): reaches into SegmentIndex internals (_ordered_ids);
        # prefer a public size accessor if SegmentIndex exposes one — confirm.
        segment_count=len(segment_index._ordered_ids),
        granularity=segment_index.granularity,
    )
|
||||||
|
|
||||||
|
|
||||||
|
def _count_leaf_fields(data: Any) -> int:
|
||||||
|
"""Count non-container leaves (str/int/float/Decimal/date/bool/None) recursively."""
|
||||||
|
if data is None:
|
||||||
|
return 1
|
||||||
|
if isinstance(data, dict):
|
||||||
|
if not data:
|
||||||
|
return 0
|
||||||
|
return sum(_count_leaf_fields(v) for v in data.values())
|
||||||
|
if isinstance(data, list):
|
||||||
|
if not data:
|
||||||
|
return 0
|
||||||
|
return sum(_count_leaf_fields(v) for v in data)
|
||||||
|
return 1
|
||||||
231
src/ix/provenance/verify.py
Normal file
231
src/ix/provenance/verify.py
Normal file
|
|
@ -0,0 +1,231 @@
|
||||||
|
"""Reliability verifier — writes `provenance_verified` + `text_agreement`.
|
||||||
|
|
||||||
|
Implements the dispatch table from spec §6 ReliabilityStep. The normalisers
|
||||||
|
in :mod:`ix.provenance.normalize` do the actual string/number/date work; this
|
||||||
|
module chooses which one to run for each field based on its Pydantic
|
||||||
|
annotation, and writes the two reliability flags onto the existing
|
||||||
|
:class:`FieldProvenance` records in place.
|
||||||
|
|
||||||
|
The summary counters (``verified_fields``, ``text_agreement_fields``) land in
|
||||||
|
``provenance.quality_metrics`` alongside the existing coverage metrics.
|
||||||
|
"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import re
|
||||||
|
import types
|
||||||
|
from datetime import date, datetime
|
||||||
|
from decimal import Decimal, InvalidOperation
|
||||||
|
from typing import Any, Literal, Union, get_args, get_origin, get_type_hints
|
||||||
|
|
||||||
|
from pydantic import BaseModel
|
||||||
|
|
||||||
|
from ix.contracts.provenance import FieldProvenance, ProvenanceData
|
||||||
|
from ix.provenance.normalize import (
|
||||||
|
normalize_date,
|
||||||
|
normalize_iban,
|
||||||
|
normalize_number,
|
||||||
|
normalize_string,
|
||||||
|
should_skip_text_agreement,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def _unwrap_optional(tp: Any) -> Any:
|
||||||
|
"""Return the non-None arm of ``Optional[X]`` / ``X | None``.
|
||||||
|
|
||||||
|
Handles both ``typing.Union`` and ``types.UnionType`` (PEP 604 ``X | Y``
|
||||||
|
unions), which ``get_type_hints`` returns on Python 3.12+.
|
||||||
|
"""
|
||||||
|
origin = get_origin(tp)
|
||||||
|
if origin is Union or origin is types.UnionType:
|
||||||
|
args = [a for a in get_args(tp) if a is not type(None)]
|
||||||
|
if len(args) == 1:
|
||||||
|
return args[0]
|
||||||
|
return tp
|
||||||
|
|
||||||
|
|
||||||
|
def _is_literal(tp: Any) -> bool:
|
||||||
|
return get_origin(tp) is Literal
|
||||||
|
|
||||||
|
|
||||||
|
def _is_iban_field(field_path: str) -> bool:
|
||||||
|
return "iban" in field_path.lower()
|
||||||
|
|
||||||
|
|
||||||
|
def _compare_string(value: str, snippet: str) -> bool:
    """Normalised-substring match; any normaliser failure counts as False."""
    try:
        needle = normalize_string(value)
        haystack = normalize_string(snippet)
    except Exception:
        return False
    return needle in haystack
|
||||||
|
|
||||||
|
|
||||||
|
_NUMBER_TOKEN_RE = re.compile(r"[+\-]?[\d][\d',.\s]*\d|[+\-]?\d+")
|
||||||
|
|
||||||
|
|
||||||
|
def _compare_number(value: Any, snippet: str) -> bool:
    """True when *snippet* contains a number equal to *value* once normalised.

    The whole snippet is tried first (cheap path when the snippet IS the
    number); on failure, numeric-looking substrings are scanned — OCR
    snippets commonly carry labels ("Closing balance CHF 1'234.56") that
    defeat a whole-string numeric parse.
    """
    try:
        target = normalize_number(value)
    except (InvalidOperation, ValueError):
        return False

    candidates = [snippet]
    candidates.extend(m.group() for m in _NUMBER_TOKEN_RE.finditer(snippet))
    for candidate in candidates:
        try:
            if normalize_number(candidate) == target:
                return True
        except (InvalidOperation, ValueError):
            continue
    return False
|
||||||
|
|
||||||
|
|
||||||
|
def _compare_date(value: Any, snippet: str) -> bool:
    """True when *snippet* contains a date equal to *value* once normalised.

    The snippet is tried as a whole first; on failure each whitespace/
    punctuation token is tried — the date normaliser raises on non-date
    tokens, which is expected and ignored.
    """
    try:
        target = normalize_date(value)
    except (ValueError, TypeError):
        return False

    for candidate in [snippet, *_tokenise_for_date(snippet)]:
        try:
            if normalize_date(candidate) == target:
                return True
        except (ValueError, TypeError):
            continue
    return False
|
||||||
|
|
||||||
|
|
||||||
|
def _tokenise_for_date(s: str) -> list[str]:
|
||||||
|
"""Split on whitespace + common punctuation so date strings survive whole.
|
||||||
|
|
||||||
|
Keeps dots / slashes / dashes inside tokens (they're valid date
|
||||||
|
separators); splits on spaces, commas, semicolons, colons, brackets.
|
||||||
|
"""
|
||||||
|
return [t for t in re.split(r"[\s,;:()\[\]]+", s) if t]
|
||||||
|
|
||||||
|
|
||||||
|
def _compare_iban(value: str, snippet: str) -> bool:
    """Normalised-substring match on IBANs; normaliser failure → False."""
    try:
        needle = normalize_iban(value)
        haystack = normalize_iban(snippet)
    except Exception:
        return False
    return needle in haystack
|
||||||
|
|
||||||
|
|
||||||
|
def _compare_for_type(value: Any, field_type: Any, snippet: str, field_path: str) -> bool:
    """Dispatch to the comparator matching the field's annotated type.

    Order matters: date/datetime first, then numerics, then IBAN
    (detected by field *name*, not type), then the default normalised
    string-substring check.
    """
    core = _unwrap_optional(field_type)

    if core is date or core is datetime:
        return _compare_date(value, snippet)

    if core is int or core is float or core is Decimal:
        return _compare_number(value, snippet)

    if _is_iban_field(field_path):
        return _compare_iban(str(value), snippet)

    return _compare_string(str(value), snippet)
|
||||||
|
|
||||||
|
|
||||||
|
def verify_field(
    field_provenance: FieldProvenance,
    field_type: Any,
    texts: list[str],
) -> tuple[bool | None, bool | None]:
    """Compute ``(provenance_verified, text_agreement)`` for one field.

    Dispatch rules follow spec §6.  ``None`` in a slot means that check
    was skipped: Literal-typed fields and ``None`` values skip both;
    an empty *texts* list or the short-value rule skips
    ``text_agreement`` only.
    """
    value = field_provenance.value
    path = field_provenance.field_path

    # Literal fields and null values are unverifiable — skip both checks.
    if value is None or _is_literal(_unwrap_optional(field_type)):
        return (None, None)

    # provenance_verified: does ANY cited segment contain the value?
    if field_provenance.sources:
        verified: bool | None = any(
            _compare_for_type(value, field_type, source.text_snippet, path)
            for source in field_provenance.sources
        )
    else:
        verified = False

    # text_agreement: compare against the concatenated texts unless the
    # check is skipped (no texts, or the short-value rule applies).
    if texts and not should_skip_text_agreement(value, field_type):
        agreement: bool | None = _compare_for_type(value, field_type, "\n".join(texts), path)
    else:
        agreement = None

    return verified, agreement
|
||||||
|
|
||||||
|
|
||||||
|
def apply_reliability_flags(
    provenance_data: ProvenanceData,
    use_case_response: type[BaseModel],
    texts: list[str],
) -> None:
    """Run :func:`verify_field` over every field and record the summary.

    Mutates ``provenance_data`` in place:

    * fills each ``FieldProvenance``'s ``provenance_verified`` and
      ``text_agreement`` slots;
    * sets ``quality_metrics['verified_fields']`` and
      ``quality_metrics['text_agreement_fields']`` to the number of
      fields whose respective flag ended up True.

    ``use_case_response`` is the Pydantic class for the extraction schema
    (e.g. a bank-statement header model).  Annotations are resolved via
    ``get_type_hints`` so forward refs and ``str | None`` unions come
    back normalised.
    """
    hints = get_type_hints(use_case_response)

    verified_total = 0
    agreement_total = 0
    for record in provenance_data.fields.values():
        # "result.bank_name" → attribute "bank_name" on the schema.  Only
        # the top-level name is resolved; MVP use cases are flat, so that
        # suffices — nested schemas will need an annotation-tree walk.
        attr_path = record.field_path.split(".", 1)[-1]
        attr_name = attr_path.split(".")[0]
        annotation: Any = hints.get(attr_name, str)

        verified, agreement = verify_field(record, annotation, texts)
        record.provenance_verified = verified
        record.text_agreement = agreement
        verified_total += verified is True
        agreement_total += agreement is True

    provenance_data.quality_metrics["verified_fields"] = verified_total
    provenance_data.quality_metrics["text_agreement_fields"] = agreement_total
|
||||||
206
tests/unit/test_provenance_mapper.py
Normal file
206
tests/unit/test_provenance_mapper.py
Normal file
|
|
@ -0,0 +1,206 @@
|
||||||
|
"""Tests for the provenance mapper (spec §9.4)."""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from ix.contracts import (
|
||||||
|
BoundingBox,
|
||||||
|
Line,
|
||||||
|
OCRDetails,
|
||||||
|
OCRResult,
|
||||||
|
Page,
|
||||||
|
SegmentCitation,
|
||||||
|
)
|
||||||
|
from ix.provenance.mapper import (
|
||||||
|
map_segment_refs_to_provenance,
|
||||||
|
resolve_nested_path,
|
||||||
|
)
|
||||||
|
from ix.segmentation import PageMetadata, SegmentIndex
|
||||||
|
|
||||||
|
|
||||||
|
def _make_index_with_lines(lines: list[tuple[str, int]]) -> SegmentIndex:
    """Build a tiny index where each line has a known text + file_index.

    Each entry is (text, file_index); all entries go on a single page.
    NOTE: only the FIRST entry's file_index is used for the page's
    metadata — later entries' file_index values are ignored, since the
    whole fixture is one page.
    """
    ocr_lines = [Line(text=t, bounding_box=[0, 0, 10, 0, 10, 5, 0, 5]) for t, _ in lines]
    page = Page(page_no=1, width=100.0, height=200.0, lines=ocr_lines)
    ocr = OCRResult(result=OCRDetails(pages=[page]))
    # file_index for the whole page — the test uses a single page.
    file_index = lines[0][1] if lines else 0
    return SegmentIndex.build(
        ocr_result=ocr,
        granularity="line",
        pages_metadata=[PageMetadata(file_index=file_index)],
    )
|
||||||
|
|
||||||
|
|
||||||
|
class TestResolveNestedPath:
    """Dot-path resolution, including ``[N]`` bracket normalisation."""

    def test_simple_path(self) -> None:
        assert resolve_nested_path({"result": {"a": "x"}}, "result.a") == "x"

    def test_nested_path(self) -> None:
        data = {"result": {"header": {"bank": "UBS"}}}
        assert resolve_nested_path(data, "result.header.bank") == "UBS"

    def test_missing_path_returns_none(self) -> None:
        assert resolve_nested_path({"result": {}}, "result.nope") is None

    def test_array_bracket_notation_normalised(self) -> None:
        # "[N]" is rewritten to ".N" before traversal.
        data = {"result": {"items": [{"name": "a"}, {"name": "b"}]}}
        assert resolve_nested_path(data, "result.items[0].name") == "a"
        assert resolve_nested_path(data, "result.items[1].name") == "b"

    def test_array_dot_notation(self) -> None:
        # Plain ".N" indexing works without brackets too.
        data = {"result": {"items": [{"name": "a"}, {"name": "b"}]}}
        assert resolve_nested_path(data, "result.items.0.name") == "a"
|
||||||
|
|
||||||
|
|
||||||
|
class TestMapper:
    """End-to-end mapper behavior: source resolution, caps, skips, metrics."""

    def test_simple_single_field(self) -> None:
        idx = _make_index_with_lines([("UBS AG", 0), ("Header text", 0)])
        extraction = {"result": {"bank_name": "UBS AG"}}
        citations = [
            SegmentCitation(field_path="result.bank_name", value_segment_ids=["p1_l0"])
        ]

        prov = map_segment_refs_to_provenance(
            extraction_result=extraction,
            segment_citations=citations,
            segment_index=idx,
            max_sources_per_field=10,
            min_confidence=0.0,
            include_bounding_boxes=True,
            source_type="value_and_context",
        )

        fp = prov.fields["result.bank_name"]
        assert fp.field_name == "bank_name"
        assert fp.value == "UBS AG"
        assert len(fp.sources) == 1
        src = fp.sources[0]
        assert src.segment_id == "p1_l0"
        assert src.text_snippet == "UBS AG"
        assert src.page_number == 1
        assert src.file_index == 0
        assert isinstance(src.bounding_box, BoundingBox)
        # quality_metrics populated
        assert prov.quality_metrics["invalid_references"] == 0

    def test_invalid_reference_counted(self) -> None:
        # Unknown seg-ids are counted in the metric but do not fail the map.
        idx = _make_index_with_lines([("UBS AG", 0)])
        extraction = {"result": {"bank_name": "UBS AG"}}
        citations = [
            SegmentCitation(
                field_path="result.bank_name",
                value_segment_ids=["p1_l0", "p9_l9"],  # p9_l9 doesn't exist
            )
        ]
        prov = map_segment_refs_to_provenance(
            extraction_result=extraction,
            segment_citations=citations,
            segment_index=idx,
            max_sources_per_field=10,
            min_confidence=0.0,
            include_bounding_boxes=True,
            source_type="value_and_context",
        )
        assert prov.quality_metrics["invalid_references"] == 1
        # The one valid source still populated.
        assert len(prov.fields["result.bank_name"].sources) == 1

    def test_max_sources_cap(self) -> None:
        # Five lines; ask for a cap of 2.
        idx = _make_index_with_lines([(f"line {i}", 0) for i in range(5)])
        citations = [
            SegmentCitation(
                field_path="result.notes",
                value_segment_ids=[f"p1_l{i}" for i in range(5)],
            )
        ]
        prov = map_segment_refs_to_provenance(
            extraction_result={"result": {"notes": "noise"}},
            segment_citations=citations,
            segment_index=idx,
            max_sources_per_field=2,
            min_confidence=0.0,
            include_bounding_boxes=True,
            source_type="value_and_context",
        )
        assert len(prov.fields["result.notes"].sources) == 2

    def test_source_type_value_only(self) -> None:
        idx = _make_index_with_lines([("label:", 0), ("UBS AG", 0)])
        citations = [
            SegmentCitation(
                field_path="result.bank_name",
                value_segment_ids=["p1_l1"],
                context_segment_ids=["p1_l0"],
            )
        ]
        prov = map_segment_refs_to_provenance(
            extraction_result={"result": {"bank_name": "UBS AG"}},
            segment_citations=citations,
            segment_index=idx,
            max_sources_per_field=10,
            min_confidence=0.0,
            include_bounding_boxes=True,
            source_type="value",
        )
        sources = prov.fields["result.bank_name"].sources
        # Only value_segment_ids included.
        assert [s.segment_id for s in sources] == ["p1_l1"]

    def test_source_type_value_and_context(self) -> None:
        # Value ids come first, then context ids.
        idx = _make_index_with_lines([("label:", 0), ("UBS AG", 0)])
        citations = [
            SegmentCitation(
                field_path="result.bank_name",
                value_segment_ids=["p1_l1"],
                context_segment_ids=["p1_l0"],
            )
        ]
        prov = map_segment_refs_to_provenance(
            extraction_result={"result": {"bank_name": "UBS AG"}},
            segment_citations=citations,
            segment_index=idx,
            max_sources_per_field=10,
            min_confidence=0.0,
            include_bounding_boxes=True,
            source_type="value_and_context",
        )
        sources = prov.fields["result.bank_name"].sources
        assert [s.segment_id for s in sources] == ["p1_l1", "p1_l0"]

    def test_include_bounding_boxes_false(self) -> None:
        idx = _make_index_with_lines([("UBS AG", 0)])
        citations = [
            SegmentCitation(field_path="result.bank_name", value_segment_ids=["p1_l0"])
        ]
        prov = map_segment_refs_to_provenance(
            extraction_result={"result": {"bank_name": "UBS AG"}},
            segment_citations=citations,
            segment_index=idx,
            max_sources_per_field=10,
            min_confidence=0.0,
            include_bounding_boxes=False,
            source_type="value_and_context",
        )
        assert prov.fields["result.bank_name"].sources[0].bounding_box is None

    def test_field_with_no_valid_sources_skipped(self) -> None:
        idx = _make_index_with_lines([("UBS", 0)])
        citations = [
            SegmentCitation(field_path="result.ghost", value_segment_ids=["p9_l9"])
        ]
        prov = map_segment_refs_to_provenance(
            extraction_result={"result": {"ghost": "x"}},
            segment_citations=citations,
            segment_index=idx,
            max_sources_per_field=10,
            min_confidence=0.0,
            include_bounding_boxes=True,
            source_type="value_and_context",
        )
        # Field not added when zero valid sources (spec §9.4 step).
        assert "result.ghost" not in prov.fields
        assert prov.quality_metrics["invalid_references"] == 1
|
||||||
220
tests/unit/test_provenance_verify.py
Normal file
220
tests/unit/test_provenance_verify.py
Normal file
|
|
@ -0,0 +1,220 @@
|
||||||
|
"""Tests for the reliability verifier (spec §6 ReliabilityStep)."""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from datetime import date
|
||||||
|
from decimal import Decimal
|
||||||
|
from typing import Literal
|
||||||
|
|
||||||
|
from pydantic import BaseModel
|
||||||
|
|
||||||
|
from ix.contracts import (
|
||||||
|
ExtractionSource,
|
||||||
|
FieldProvenance,
|
||||||
|
ProvenanceData,
|
||||||
|
)
|
||||||
|
from ix.provenance.verify import apply_reliability_flags, verify_field
|
||||||
|
|
||||||
|
|
||||||
|
def _make_fp(
    *,
    field_path: str,
    value: object,
    snippets: list[str],
) -> FieldProvenance:
    """Build a FieldProvenance with one ExtractionSource per snippet.

    Sources get sequential segment ids ("p1_l0", "p1_l1", ...) on page 1,
    file_index 0, relevance 1.0; field_name is the last dot-part of the path.
    """
    return FieldProvenance(
        field_name=field_path.split(".")[-1],
        field_path=field_path,
        value=value,
        sources=[
            ExtractionSource(
                page_number=1,
                file_index=0,
                text_snippet=s,
                relevance_score=1.0,
                segment_id=f"p1_l{i}",
            )
            for i, s in enumerate(snippets)
        ],
    )
|
||||||
|
|
||||||
|
|
||||||
|
class TestVerifyFieldByType:
    """Per-type dispatch of verify_field: string, number, date, IBAN, Literal, None."""

    def test_string_substring_match(self) -> None:
        fp = _make_fp(
            field_path="result.bank_name",
            value="UBS AG",
            snippets=["Account at UBS AG, Zurich"],
        )
        pv, ta = verify_field(fp, str, texts=[])
        assert pv is True
        assert ta is None

    def test_string_mismatch_is_false(self) -> None:
        fp = _make_fp(
            field_path="result.bank_name",
            value="UBS AG",
            snippets=["Credit Suisse"],
        )
        pv, _ = verify_field(fp, str, texts=[])
        assert pv is False

    def test_number_decimal_match_ignores_currency(self) -> None:
        # Currency label + Swiss thousands separator in the snippet.
        fp = _make_fp(
            field_path="result.closing_balance",
            value=Decimal("1234.56"),
            snippets=["CHF 1'234.56"],
        )
        pv, _ = verify_field(fp, Decimal, texts=[])
        assert pv is True

    def test_number_mismatch(self) -> None:
        fp = _make_fp(
            field_path="result.closing_balance",
            value=Decimal("1234.56"),
            snippets=["CHF 9999.99"],
        )
        pv, _ = verify_field(fp, Decimal, texts=[])
        assert pv is False

    def test_date_parse_both_sides(self) -> None:
        # ISO value vs. dd.mm.yyyy snippet — both normalised before compare.
        fp = _make_fp(
            field_path="result.statement_date",
            value=date(2026, 3, 31),
            snippets=["Statement date: 31.03.2026"],
        )
        pv, _ = verify_field(fp, date, texts=[])
        assert pv is True

    def test_iban_strip_and_case(self) -> None:
        # IBAN detection: field name contains "iban".
        fp = _make_fp(
            field_path="result.account_iban",
            value="CH9300762011623852957",
            snippets=["Account CH93 0076 2011 6238 5295 7"],
        )
        pv, _ = verify_field(fp, str, texts=[])
        assert pv is True

    def test_literal_field_both_flags_none(self) -> None:
        # Literal types are skipped even when the word is literally present.
        fp = _make_fp(
            field_path="result.account_type",
            value="checking",
            snippets=["the word checking is literally here"],
        )
        pv, ta = verify_field(fp, Literal["checking", "credit", "savings"], texts=["checking"])
        assert pv is None
        assert ta is None

    def test_none_value_both_flags_none(self) -> None:
        fp = _make_fp(
            field_path="result.account_iban",
            value=None,
            snippets=["whatever"],
        )
        pv, ta = verify_field(fp, str, texts=["whatever"])
        assert pv is None
        assert ta is None
|
||||||
|
|
||||||
|
|
||||||
|
class TestTextAgreement:
    """The text_agreement flag: texts present/absent and the short-value skip rule."""

    def test_text_agreement_with_texts_true(self) -> None:
        fp = _make_fp(
            field_path="result.bank_name",
            value="UBS AG",
            snippets=["UBS AG"],
        )
        _, ta = verify_field(fp, str, texts=["Account at UBS AG"])
        assert ta is True

    def test_text_agreement_with_texts_false(self) -> None:
        fp = _make_fp(
            field_path="result.bank_name",
            value="UBS AG",
            snippets=["UBS AG"],
        )
        _, ta = verify_field(fp, str, texts=["Credit Suisse"])
        assert ta is False

    def test_text_agreement_no_texts_is_none(self) -> None:
        fp = _make_fp(
            field_path="result.bank_name",
            value="UBS AG",
            snippets=["UBS AG"],
        )
        _, ta = verify_field(fp, str, texts=[])
        assert ta is None

    def test_short_value_skips_text_agreement(self) -> None:
        # 2-char string
        fp = _make_fp(
            field_path="result.code",
            value="XY",
            snippets=["code XY here"],
        )
        pv, ta = verify_field(fp, str, texts=["another XY reference"])
        # provenance_verified still runs; text_agreement is skipped.
        assert pv is True
        assert ta is None

    def test_small_number_skips_text_agreement(self) -> None:
        fp = _make_fp(
            field_path="result.n",
            value=5,
            snippets=["value 5 here"],
        )
        pv, ta = verify_field(fp, int, texts=["the number 5"])
        assert pv is True
        assert ta is None
|
||||||
|
|
||||||
|
|
||||||
|
class TestApplyReliabilityFlags:
    """apply_reliability_flags mutates fields in place and writes the counters."""

    def test_writes_flags_and_counters(self) -> None:
        # Inline Pydantic schema standing in for a real use-case response.
        class BankHeader(BaseModel):
            bank_name: str
            account_iban: str | None = None
            closing_balance: Decimal | None = None
            account_type: Literal["checking", "credit", "savings"] | None = None

        prov = ProvenanceData(
            fields={
                "result.bank_name": _make_fp(
                    field_path="result.bank_name",
                    value="UBS AG",
                    snippets=["Account at UBS AG"],
                ),
                "result.account_iban": _make_fp(
                    field_path="result.account_iban",
                    value="CH9300762011623852957",
                    snippets=["IBAN CH93 0076 2011 6238 5295 7"],
                ),
                "result.closing_balance": _make_fp(
                    field_path="result.closing_balance",
                    value=Decimal("1234.56"),
                    snippets=["Closing balance CHF 1'234.56"],
                ),
                "result.account_type": _make_fp(
                    field_path="result.account_type",
                    value="checking",
                    snippets=["current account (checking)"],
                ),
            },
        )
        apply_reliability_flags(prov, BankHeader, texts=["Account at UBS AG at CH9300762011623852957"])

        fields = prov.fields
        assert fields["result.bank_name"].provenance_verified is True
        assert fields["result.bank_name"].text_agreement is True
        assert fields["result.account_iban"].provenance_verified is True
        assert fields["result.closing_balance"].provenance_verified is True
        # account_type is Literal → both flags None.
        assert fields["result.account_type"].provenance_verified is None
        assert fields["result.account_type"].text_agreement is None

        # Counters record only True values.
        qm = prov.quality_metrics
        assert qm["verified_fields"] == 3  # all except Literal
        # text_agreement_fields counts only fields where the flag is True.
        # bank_name True; IBAN True (appears in texts after normalisation);
        # closing_balance -- '1234.56' doesn't appear in the text.
        assert qm["text_agreement_fields"] >= 1
|
||||||
Loading…
Reference in a new issue