"""Reliability verifier — writes `provenance_verified` + `text_agreement`.

Implements the dispatch table from spec §6 ReliabilityStep. The normalisers in
:mod:`ix.provenance.normalize` do the actual string/number/date work; this
module chooses which one to run for each field based on its Pydantic
annotation, and writes the two reliability flags onto the existing
:class:`FieldProvenance` records in place. The summary counters
(``verified_fields``, ``text_agreement_fields``) land in
``provenance.quality_metrics`` alongside the existing coverage metrics.
"""

from __future__ import annotations

import re
import types
from datetime import date, datetime
from decimal import Decimal, InvalidOperation
from typing import Any, Literal, Union, get_args, get_origin, get_type_hints

from pydantic import BaseModel

from ix.contracts.provenance import FieldProvenance, ProvenanceData
from ix.provenance.normalize import (
    normalize_date,
    normalize_iban,
    normalize_number,
    normalize_string,
    should_skip_text_agreement,
)


def _unwrap_optional(tp: Any) -> Any:
    """Return the non-None arm of ``Optional[X]`` / ``X | None``.

    Handles both ``typing.Union`` and ``types.UnionType`` (PEP 604 ``X | Y``
    unions). Anything that is not a union, or a union with more than one
    non-None member, is returned unchanged.
    """
    origin = get_origin(tp)
    if origin is Union or origin is types.UnionType:
        args = [a for a in get_args(tp) if a is not type(None)]
        if len(args) == 1:
            return args[0]
    return tp


def _is_literal(tp: Any) -> bool:
    """Return True when ``tp`` is a ``typing.Literal[...]`` annotation."""
    return get_origin(tp) is Literal


def _is_iban_field(field_path: str) -> bool:
    """Return True when ``field_path`` contains "iban" (case-insensitive)."""
    return "iban" in field_path.lower()


def _compare_string(value: str, snippet: str) -> bool:
    """Return True when the normalised ``value`` occurs in the normalised ``snippet``.

    A value that normalises to the empty string never matches: ``"" in s`` is
    always true in Python, which would otherwise mark an empty value as
    verified against any snippet. Any exception raised by the normaliser is
    treated as "no match" (deliberate best-effort).
    """
    try:
        needle = normalize_string(value)
        # Empty needle guard — see docstring.
        return bool(needle) and needle in normalize_string(snippet)
    except Exception:
        return False


# Numeric token: optional sign, then either a digit-bounded run of digits,
# apostrophes, commas, dots and whitespace, or a plain run of digits.
# NOTE(review): because ``\s`` is allowed inside the run, two adjacent numbers
# separated only by whitespace are captured as a single token — confirm this
# is acceptable for the snippets seen in practice.
_NUMBER_TOKEN_RE = re.compile(r"[+\-]?[\d][\d',.\s]*\d|[+\-]?\d+")


def _compare_number(value: Any, snippet: str) -> bool:
    """Return True when ``snippet`` contains a number equal to ``value``.

    Both sides go through ``normalize_number``. The whole snippet is tried
    first; if that does not match, every numeric-looking token found by
    ``_NUMBER_TOKEN_RE`` is tried in turn. A value or token that cannot be
    normalised (``InvalidOperation`` / ``ValueError``) counts as "no match".
    """
    try:
        canonical = normalize_number(value)
    except (InvalidOperation, ValueError):
        return False

    # Try the whole snippet first (cheap path when the snippet IS the number).
    try:
        if normalize_number(snippet) == canonical:
            return True
    except (InvalidOperation, ValueError):
        pass

    # Fall back to scanning numeric substrings — OCR snippets commonly carry
    # labels ("Closing balance CHF 1'234.56") that confuse a whole-string
    # numeric parse.
    for match in _NUMBER_TOKEN_RE.finditer(snippet):
        token = match.group()
        try:
            if normalize_number(token) == canonical:
                return True
        except (InvalidOperation, ValueError):
            continue
    return False


# Exceptions tolerated while probing text for dates. ``OverflowError`` is
# included because dateutil documents raising it for out-of-range numeric
# input, and OCR snippets routinely contain long digit runs; without it a
# single such token would abort the whole verification pass.
_DATE_PARSE_ERRORS = (ValueError, TypeError, OverflowError)


def _compare_date(value: Any, snippet: str) -> bool:
    """Return True when ``snippet`` contains a date equal to ``value``.

    Both sides go through ``normalize_date``. The whole snippet is tried
    first, then each token from ``_tokenise_for_date``. A value or token that
    cannot be parsed counts as "no match".
    """
    try:
        iso_value = normalize_date(value)
    except _DATE_PARSE_ERRORS:
        return False

    # Simplest heuristic: try snippet as a whole; on failure, scan tokens.
    try:
        if normalize_date(snippet) == iso_value:
            return True
    except _DATE_PARSE_ERRORS:
        pass

    # Token scan — dateutil will raise on the non-date tokens, which is fine.
    for token in _tokenise_for_date(snippet):
        try:
            if normalize_date(token) == iso_value:
                return True
        except _DATE_PARSE_ERRORS:
            continue
    return False


def _tokenise_for_date(s: str) -> list[str]:
    """Split on whitespace + common punctuation so date strings survive whole.

    Keeps dots / slashes / dashes inside tokens (they're valid date
    separators); splits on spaces, commas, semicolons, colons, brackets.
    Empty tokens are dropped.
    """
    return [t for t in re.split(r"[\s,;:()\[\]]+", s) if t]


def _compare_iban(value: str, snippet: str) -> bool:
    """Return True when the normalised IBAN ``value`` occurs in the normalised ``snippet``.

    As with ``_compare_string``, a value that normalises to the empty string
    never matches, and any normaliser exception is treated as "no match".
    """
    try:
        needle = normalize_iban(value)
        # Empty needle guard: ``"" in s`` would otherwise always be True.
        return bool(needle) and needle in normalize_iban(snippet)
    except Exception:
        return False


def _compare_for_type(value: Any, field_type: Any, snippet: str, field_path: str) -> bool:
    """Dispatch to the comparator that matches the field's annotation.

    Order of precedence: date/datetime, then numeric (int/float/Decimal),
    then IBAN (detected from ``field_path``), then plain string substring.
    ``field_type`` may be ``Optional``; it is unwrapped first.
    """
    unwrapped = _unwrap_optional(field_type)

    # Date / datetime.
    if unwrapped in (date, datetime):
        return _compare_date(value, snippet)

    # Numeric.
    if unwrapped in (int, float, Decimal):
        return _compare_number(value, snippet)

    # IBAN (detected by field name).
    if _is_iban_field(field_path):
        return _compare_iban(str(value), snippet)

    # Default: string substring.
    return _compare_string(str(value), snippet)


def verify_field(
    field_provenance: FieldProvenance,
    field_type: Any,
    texts: list[str],
) -> tuple[bool | None, bool | None]:
    """Compute the two reliability flags for one field.

    Returns ``(provenance_verified, text_agreement)``. See spec §6 for the
    dispatch rules. ``None`` on either slot means the check was skipped
    (Literal, None value, or short-value for text_agreement).

    * ``provenance_verified`` is False when the field cites no sources,
      otherwise True if the value matches at least one cited snippet.
    * ``text_agreement`` is None when ``texts`` is empty or
      ``should_skip_text_agreement`` says to skip; otherwise the value is
      compared against all texts joined with newlines.
    """
    value = field_provenance.value
    unwrapped = _unwrap_optional(field_type)

    # Skip Literal / None value entirely — both flags None.
    if _is_literal(unwrapped) or value is None:
        return (None, None)

    # provenance_verified: scan cited segments.
    provenance_verified: bool | None
    if not field_provenance.sources:
        provenance_verified = False
    else:
        provenance_verified = any(
            _compare_for_type(value, field_type, s.text_snippet, field_provenance.field_path)
            for s in field_provenance.sources
        )

    # text_agreement: None if no texts, else apply short-value rule.
    text_agreement: bool | None
    if not texts or should_skip_text_agreement(value, field_type):
        text_agreement = None
    else:
        concatenated = "\n".join(texts)
        text_agreement = _compare_for_type(
            value, field_type, concatenated, field_provenance.field_path
        )

    return provenance_verified, text_agreement


def apply_reliability_flags(
    provenance_data: ProvenanceData,
    use_case_response: type[BaseModel],
    texts: list[str],
) -> None:
    """Apply :func:`verify_field` to every field in ``provenance_data``.

    Mutates ``provenance_data`` in place:

    * Each ``FieldProvenance``'s ``provenance_verified`` and ``text_agreement``
      slots are filled.
    * ``quality_metrics['verified_fields']`` is set to the number of fields
      whose ``provenance_verified`` is True.
    * ``quality_metrics['text_agreement_fields']`` is set likewise for
      ``text_agreement``.

    ``use_case_response`` is the Pydantic class for the extraction schema
    (e.g. :class:`~ix.use_cases.bank_statement_header.BankStatementHeader`).
    Type hints are resolved via ``get_type_hints`` so forward-refs and
    ``str | None`` unions are normalised consistently. Fields whose top-level
    attribute name is not found in the schema are treated as ``str``.
    """
    type_hints = get_type_hints(use_case_response)

    verified_count = 0
    text_agreement_count = 0
    for fp in provenance_data.fields.values():
        # Field path is something like "result.bank_name" — the part after
        # the first dot is the attribute name on the response schema.
        leaf = fp.field_path.split(".", 1)[-1]
        # For nested shapes we only resolve the top-level name; MVP use cases
        # are flat so that's enough. When we ship nested schemas we'll walk
        # the annotation tree here.
        top_attr = leaf.split(".")[0]
        field_type: Any = type_hints.get(top_attr, str)

        pv, ta = verify_field(fp, field_type, texts)
        fp.provenance_verified = pv
        fp.text_agreement = ta
        if pv is True:
            verified_count += 1
        if ta is True:
            text_agreement_count += 1

    provenance_data.quality_metrics["verified_fields"] = verified_count
    provenance_data.quality_metrics["text_agreement_fields"] = text_agreement_count