infoxtractor/src/ix/provenance/verify.py
Dirk Riemann 1e340c82fa
All checks were successful
tests / test (pull_request) Successful in 1m10s
tests / test (push) Successful in 1m11s
feat(provenance): mapper + verifier for ReliabilityStep (spec §9.4, §6)
Lands the two remaining provenance-subsystem pieces:

mapper.py — map_segment_refs_to_provenance:
- For each LLM SegmentCitation, pick seg-ids per source_type
  (`value` vs `value_and_context`), cap at max_sources_per_field,
  resolve each via SegmentIndex, track invalid references.
- Resolve field values by dot-path (`result.items[0].name` supported —
  `[N]` bracket notation is normalised to `.N` before traversal).
- Skip fields that resolve to zero valid sources (spec §9.4).
- Write quality_metrics with fields_with_provenance / total_fields /
  coverage_rate / invalid_references.

verify.py — verify_field + apply_reliability_flags:
- Dispatches per Pydantic field type: date → parse-both-sides compare;
  int/float/Decimal → normalize + whole-snippet / numeric-token scan;
  IBAN (detected via `iban` in field name) → upper+strip compare;
  Literal / None → flags stay None; else string substring.
- _unwrap_optional handles BOTH typing.Union AND types.UnionType so
  `Decimal | None` (PEP 604, which get_type_hints emits on 3.10+) resolves
  correctly — caught by the integration-style test_writes_flags_and_counters.
- Number comparator scans numeric tokens in the snippet so labels
  ("Closing balance CHF 1'234.56") don't mask the match.
- apply_reliability_flags mutates the passed ProvenanceData in place and
  writes verified_fields / text_agreement_fields to quality_metrics.

Tests cover each comparator, Literal/None skip, short-value skip (strings
and numerics), Decimal via optional union, and end-to-end flag+counter
writing against a Pydantic use-case schema that mirrors bank_statement_header.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
2026-04-18 11:01:19 +02:00

231 lines
7.7 KiB
Python

"""Reliability verifier — writes `provenance_verified` + `text_agreement`.
Implements the dispatch table from spec §6 ReliabilityStep. The normalisers
in :mod:`ix.provenance.normalize` do the actual string/number/date work; this
module chooses which one to run for each field based on its Pydantic
annotation, and writes the two reliability flags onto the existing
:class:`FieldProvenance` records in place.
The summary counters (``verified_fields``, ``text_agreement_fields``) land in
``provenance.quality_metrics`` alongside the existing coverage metrics.
"""
from __future__ import annotations
import re
import types
from datetime import date, datetime
from decimal import Decimal, InvalidOperation
from typing import Any, Literal, Union, get_args, get_origin, get_type_hints
from pydantic import BaseModel
from ix.contracts.provenance import FieldProvenance, ProvenanceData
from ix.provenance.normalize import (
normalize_date,
normalize_iban,
normalize_number,
normalize_string,
should_skip_text_agreement,
)
def _unwrap_optional(tp: Any) -> Any:
"""Return the non-None arm of ``Optional[X]`` / ``X | None``.
Handles both ``typing.Union`` and ``types.UnionType`` (PEP 604 ``X | Y``
unions), which ``get_type_hints`` returns on Python 3.12+.
"""
origin = get_origin(tp)
if origin is Union or origin is types.UnionType:
args = [a for a in get_args(tp) if a is not type(None)]
if len(args) == 1:
return args[0]
return tp
def _is_literal(tp: Any) -> bool:
    """Report whether ``tp`` is a ``typing.Literal[...]`` annotation.

    Callers unwrap ``Optional`` first; an ``Optional[Literal[...]]`` passed
    in directly is a Union and therefore returns False.
    """
    origin = get_origin(tp)
    return origin is Literal
def _is_iban_field(field_path: str) -> bool:
    """Name-based heuristic: any field path mentioning "iban" (any case) is an IBAN."""
    lowered = field_path.lower()
    return lowered.find("iban") != -1
def _compare_string(value: str, snippet: str) -> bool:
    """Return True when the normalised ``value`` occurs in the normalised ``snippet``.

    Both sides pass through ``normalize_string`` before a plain substring
    test. An empty normalised value never matches. ``"" in s`` is always
    True in Python, so without this guard a blank extraction would count as
    verified against any snippet.

    Any exception from the normaliser is treated as "no match": this is a
    best-effort check and must not abort the caller.
    """
    try:
        needle = normalize_string(value)
        if not needle:
            return False
        return needle in normalize_string(snippet)
    except Exception:
        return False
# Candidate numeric tokens: a signed run that starts and ends with a digit and
# may contain apostrophes, commas, dots or whitespace (thousands / decimal
# separators), or a bare signed integer.
_NUMBER_TOKEN_RE = re.compile(r"[+\-]?[\d][\d',.\s]*\d|[+\-]?\d+")
# Non-whitespace pieces inside a token; used to split tokens that fused
# several neighbouring numbers together.
_TOKEN_PIECE_RE = re.compile(r"\S+")
# Widest run of whitespace-separated pieces retried as one number
# ("1 234 567 890.12" is four pieces). Bounds the sub-run search on long
# fused tokens such as rows of a numeric table.
_MAX_NUMBER_PIECES = 5


def _number_candidates(snippet: str) -> list[str]:
    """Return substrings of ``snippet`` that may each hold a single number.

    Every regex token is returned whole first, so space-separated thousands
    ("1 234.56") still parse. The token pattern allows internal whitespace,
    so adjacent numbers ("31.12.2023 1'234.56") can fuse into one token that
    does not parse. For such tokens, every contiguous run of up to
    ``_MAX_NUMBER_PIECES`` whitespace-separated pieces is also returned,
    widest runs first.
    """
    candidates: list[str] = []
    for match in _NUMBER_TOKEN_RE.finditer(snippet):
        token = match.group()
        candidates.append(token)
        spans = [piece.span() for piece in _TOKEN_PIECE_RE.finditer(token)]
        if len(spans) < 2:
            continue
        # The full-width run is the token itself, already added above.
        widest = min(len(spans) - 1, _MAX_NUMBER_PIECES)
        for width in range(widest, 0, -1):
            for first in range(len(spans) - width + 1):
                start = spans[first][0]
                end = spans[first + width - 1][1]
                candidates.append(token[start:end])
    return candidates


def _compare_number(value: Any, snippet: str) -> bool:
    """Return True when ``snippet`` contains a number equal to ``value``.

    Both sides are canonicalised with ``normalize_number``. An unparseable
    ``value`` yields False. Unparseable candidates in the snippet are
    skipped.
    """
    try:
        canonical = normalize_number(value)
    except (InvalidOperation, ValueError):
        return False

    def _equals_value(text: str) -> bool:
        # True when ``text`` normalises to the same canonical number.
        try:
            return bool(normalize_number(text) == canonical)
        except (InvalidOperation, ValueError):
            return False

    # Try the whole snippet first (cheap path when the snippet IS the number).
    if _equals_value(snippet):
        return True
    # Fall back to numeric substrings. OCR snippets commonly carry labels
    # ("Closing balance CHF 1'234.56") that confuse a whole-string parse.
    return any(_equals_value(candidate) for candidate in _number_candidates(snippet))
# Exceptions that mean "this text is not a date". ``normalize_date`` is
# backed by dateutil, whose ``parse`` raises ``ParserError`` (a
# ``ValueError`` subclass) for unparseable text. It raises ``OverflowError``
# when a long digit run, such as an account or reference number in OCR text,
# exceeds the platform's C integer range.
_DATE_PARSE_ERRORS = (ValueError, TypeError, OverflowError)


def _compare_date(value: Any, snippet: str) -> bool:
    """Return True when ``snippet`` contains a date equal to ``value``.

    ``value`` and each candidate are compared after ``normalize_date``
    (ISO form). The whole snippet is tried first. If that fails, each token
    from :func:`_tokenise_for_date` is tried in turn.

    Unparseable candidates are skipped, and an unparseable ``value`` yields
    False, so one stray long number in the snippet cannot abort the check.
    """
    try:
        iso_value = normalize_date(value)
    except _DATE_PARSE_ERRORS:
        return False
    # Cheap path: the snippet is the date itself.
    try:
        if normalize_date(snippet) == iso_value:
            return True
    except _DATE_PARSE_ERRORS:
        pass
    # Token scan: dateutil raises on the non-date tokens, which is fine.
    for token in _tokenise_for_date(snippet):
        try:
            if normalize_date(token) == iso_value:
                return True
        except _DATE_PARSE_ERRORS:
            continue
    return False
def _tokenise_for_date(s: str) -> list[str]:
    """Break ``s`` into candidate date tokens.

    Splits on runs of whitespace, commas, semicolons, colons, parentheses
    and square brackets. Dots, slashes and dashes stay inside tokens because
    they are ordinary date separators ("31.12.2023", "2023-12-31",
    "12/31/2023"). Empty pieces are dropped.
    """
    pieces = re.split(r"[\s,;:()\[\]]+", s)
    return list(filter(None, pieces))
def _compare_iban(value: str, snippet: str) -> bool:
    """Return True when the normalised IBAN ``value`` occurs in ``snippet``.

    Both sides go through ``normalize_iban`` before a substring test, so
    formatting differences the normaliser removes do not matter.

    An empty normalised value never matches. ``"" in s`` is always True and
    would otherwise "verify" a blank IBAN field. Any normaliser exception is
    treated as "no match" (best-effort check).
    """
    try:
        needle = normalize_iban(value)
        if not needle:
            return False
        return needle in normalize_iban(snippet)
    except Exception:
        return False
def _compare_for_type(value: Any, field_type: Any, snippet: str, field_path: str) -> bool:
    """Run the comparator that matches ``field_type`` against ``snippet``.

    Dispatch order, after unwrapping ``Optional``:

    1. ``date`` / ``datetime`` uses the date comparator.
    2. ``int`` / ``float`` / ``Decimal`` uses the number comparator.
    3. A field whose path mentions "iban" uses the IBAN comparator.
    4. Anything else uses normalised substring matching.

    The IBAN and string comparators receive ``str(value)``.
    """
    target = _unwrap_optional(field_type)
    if target in (date, datetime):
        return _compare_date(value, snippet)
    if target in (int, float, Decimal):
        return _compare_number(value, snippet)
    as_text = str(value)
    if _is_iban_field(field_path):
        return _compare_iban(as_text, snippet)
    return _compare_string(as_text, snippet)
def verify_field(
    field_provenance: FieldProvenance,
    field_type: Any,
    texts: list[str],
) -> tuple[bool | None, bool | None]:
    """Compute ``(provenance_verified, text_agreement)`` for one field.

    * Both flags are ``None`` when the annotation (after unwrapping
      ``Optional``) is a ``Literal``, or when the extracted value is ``None``.
    * ``provenance_verified`` is ``False`` when the field cites no sources.
      Otherwise it is ``True`` iff at least one cited segment's
      ``text_snippet`` matches under the type-specific comparator.
    * ``text_agreement`` is ``None`` when ``texts`` is empty or when
      ``should_skip_text_agreement`` asks to skip (short values). Otherwise
      it is the comparator's verdict against all ``texts`` joined with
      newlines.

    See spec §6 for the dispatch rules.
    """
    value = field_provenance.value
    path = field_provenance.field_path

    if _is_literal(_unwrap_optional(field_type)) or value is None:
        return (None, None)

    cited = field_provenance.sources
    provenance_verified: bool | None = (
        any(_compare_for_type(value, field_type, src.text_snippet, path) for src in cited)
        if cited
        else False
    )

    text_agreement: bool | None = None
    if texts and not should_skip_text_agreement(value, field_type):
        joined = "\n".join(texts)
        text_agreement = _compare_for_type(value, field_type, joined, path)

    return provenance_verified, text_agreement
# ``[N]`` list-index notation inside a field path, e.g. ``items[0].name``.
_INDEX_RE = re.compile(r"\[(\d+)\]")


def _resolve_field_type(root_hints: dict[str, Any], leaf: str) -> Any:
    """Resolve the annotation for ``leaf`` by walking the schema, defaulting to ``str``.

    ``leaf`` is a field path with its leading prefix removed, e.g.
    ``bank_name`` or ``items[0].amount``. ``[N]`` is first normalised to
    ``.N``. The first segment is looked up in ``root_hints``. Each further
    segment steps through ``Optional`` wrappers into one of:

    * a ``list``/``tuple`` element (numeric segment),
    * a ``dict`` value,
    * a nested Pydantic model's field.

    Any segment that cannot be resolved yields ``str``, which selects the
    plain string comparator.
    """
    segments = _INDEX_RE.sub(r".\1", leaf).split(".")
    head, rest = segments[0], segments[1:]
    if head not in root_hints:
        return str
    current: Any = root_hints[head]
    for seg in rest:
        inner = _unwrap_optional(current)
        origin = get_origin(inner)
        args = get_args(inner)
        if origin is list or origin is tuple:
            if not seg.isdecimal() or not args:
                return str
            # list[X] and tuple[X, ...] are homogeneous; fixed tuples index by position.
            if origin is list or (len(args) == 2 and args[1] is Ellipsis):
                current = args[0]
            elif int(seg) < len(args):
                current = args[int(seg)]
            else:
                return str
        elif origin is dict:
            if len(args) != 2:
                return str
            current = args[1]
        elif origin is None and isinstance(inner, type) and issubclass(inner, BaseModel):
            try:
                nested_hints = get_type_hints(inner)
            except (NameError, TypeError):
                # An unresolvable nested annotation must not abort the step.
                return str
            if seg not in nested_hints:
                return str
            current = nested_hints[seg]
        else:
            return str
    return current


def apply_reliability_flags(
    provenance_data: ProvenanceData,
    use_case_response: type[BaseModel],
    texts: list[str],
) -> None:
    """Apply :func:`verify_field` to every field in ``provenance_data``.

    Mutates ``provenance_data`` in place:

    * Each ``FieldProvenance``'s ``provenance_verified`` and
      ``text_agreement`` slots are filled.
    * ``quality_metrics['verified_fields']`` is set to the number of fields
      whose ``provenance_verified`` is True.
    * ``quality_metrics['text_agreement_fields']`` is set likewise for
      ``text_agreement``.

    ``use_case_response`` is the Pydantic class for the extraction schema
    (e.g. :class:`~ix.use_cases.bank_statement_header.BankStatementHeader`).
    Type hints are resolved via ``get_type_hints``, so forward-refs and
    ``str | None`` unions are normalised consistently.

    Each field's annotation is found by walking its path (the part after the
    first dot) through the schema, including ``[N]`` index paths such as
    ``result.items[0].amount``. Paths that cannot be resolved are checked as
    plain strings. Errors from ``get_type_hints`` on the top-level schema
    (e.g. ``NameError`` for an unresolvable forward reference) propagate.
    """
    type_hints = get_type_hints(use_case_response)
    verified_count = 0
    text_agreement_count = 0
    for fp in provenance_data.fields.values():
        # Field path is something like "result.bank_name": the part after the
        # first dot addresses an attribute on the response schema.
        leaf = fp.field_path.split(".", 1)[-1]
        field_type = _resolve_field_type(type_hints, leaf)
        pv, ta = verify_field(fp, field_type, texts)
        fp.provenance_verified = pv
        fp.text_agreement = ta
        if pv is True:
            verified_count += 1
        if ta is True:
            text_agreement_count += 1
    provenance_data.quality_metrics["verified_fields"] = verified_count
    provenance_data.quality_metrics["text_agreement_fields"] = text_agreement_count