From 132f110463052594c61fb8c699d4bfa14842bde7 Mon Sep 17 00:00:00 2001 From: Dirk Riemann Date: Sat, 18 Apr 2026 11:20:18 +0200 Subject: [PATCH] =?UTF-8?q?feat(pipeline):=20ReliabilityStep=20=E2=80=94?= =?UTF-8?q?=20writes=20reliability=20flags=20(spec=20=C2=A76)?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Thin wrapper around ix.provenance.apply_reliability_flags. Validate skips entirely when include_provenance is off OR when no provenance data was built (text-only request, etc.). Process reads context.texts + context.use_case_response and lets the verifier mutate the FieldProvenance entries + fill quality_metrics counters in place. 11 unit tests in tests/unit/test_reliability_step.py cover: validate skips on flag off / missing provenance, runs otherwise; per-type flag behaviour (string verified + text_agreement, Literal -> None, None value -> None, short numeric -> text_agreement None, date with both sides parsed, IBAN whitespace-insensitive, disagreement -> False); quality_metrics verified_fields / text_agreement_fields counters. Co-Authored-By: Claude Opus 4.7 (1M context) --- src/ix/pipeline/reliability_step.py | 56 +++++++ tests/unit/test_reliability_step.py | 250 ++++++++++++++++++++++++++++ 2 files changed, 306 insertions(+) create mode 100644 src/ix/pipeline/reliability_step.py create mode 100644 tests/unit/test_reliability_step.py diff --git a/src/ix/pipeline/reliability_step.py b/src/ix/pipeline/reliability_step.py new file mode 100644 index 0000000..bc240bd --- /dev/null +++ b/src/ix/pipeline/reliability_step.py @@ -0,0 +1,56 @@ +"""ReliabilityStep — writes provenance_verified + text_agreement (spec §6). + +Runs after :class:`~ix.pipeline.genai_step.GenAIStep`. Skips entirely +when provenance is off OR when no provenance data was built (OCR-skipped +text-only request, for example). 
class ReliabilityStep(Step):
    """Fills per-field reliability flags on ``response.provenance``.

    Thin wrapper: the step only gathers its inputs from the response and
    hands them to :func:`~ix.provenance.apply_reliability_flags`; it has no
    dispatch logic of its own.
    """

    async def validate(self, request_ix: RequestIX, response_ix: ResponseIX) -> bool:
        """Decide whether the step should run.

        Returns ``False`` when the request switched provenance off or when
        the response carries no provenance data; ``True`` otherwise.
        """
        if not request_ix.options.provenance.include_provenance:
            return False
        return response_ix.provenance is not None

    async def process(
        self, request_ix: RequestIX, response_ix: ResponseIX
    ) -> ResponseIX:
        """Apply the reliability flags and return the same response object.

        Reads ``texts`` and ``use_case_response`` from ``response_ix.context``
        and passes them, together with ``response_ix.provenance``, to
        :func:`~ix.provenance.apply_reliability_flags`.

        Raises:
            AssertionError: if ``response_ix.provenance`` is ``None``, i.e.
                the step was run although :meth:`validate` would have
                skipped it.
        """
        provenance = response_ix.provenance
        if provenance is None:
            # Raised explicitly rather than via ``assert`` so the guard is
            # still enforced when Python runs with ``-O`` (which strips
            # assert statements). The exception type is unchanged.
            raise AssertionError(
                "ReliabilityStep.process() requires response_ix.provenance; "
                "validate() must return True before process() is called"
            )

        # ``getattr`` with a default already copes with ``ctx`` being None
        # (None has neither attribute), so no separate None branch is needed.
        ctx = response_ix.context
        texts: list[str] = list(getattr(ctx, "texts", None) or [])

        # ``cast`` only informs the type checker; nothing is validated at
        # runtime, so a missing ``use_case_response`` is forwarded as None.
        # NOTE(review): confirm apply_reliability_flags tolerates None here.
        use_case_response_cls = cast(
            type[BaseModel],
            getattr(ctx, "use_case_response", None),
        )

        apply_reliability_flags(
            provenance_data=provenance,
            use_case_response=use_case_response_cls,
            texts=texts,
        )
        return response_ix
def _src(
    segment_id: str,
    text: str,
    page: int = 1,
    bbox: list[float] | None = None,
) -> ExtractionSource:
    """Build an ``ExtractionSource`` citing ``text`` in segment ``segment_id``.

    When ``bbox`` is omitted (or empty) the eight-coordinate default
    ``[0, 0, 1, 0, 1, 1, 0, 1]`` is used.
    """
    return ExtractionSource(
        page_number=page,
        file_index=0,
        bounding_box=BoundingBox(coordinates=bbox or [0, 0, 1, 0, 1, 1, 0, 1]),
        text_snippet=text,
        relevance_score=1.0,
        segment_id=segment_id,
    )


def _make_request(
    include_provenance: bool = True, texts: list[str] | None = None
) -> RequestIX:
    """Build a ``bank_statement_header`` request with the given provenance flag."""
    return RequestIX(
        use_case="bank_statement_header",
        ix_client_id="test",
        request_id="r-1",
        context=Context(files=[], texts=texts or []),
        options=Options(
            ocr=OCROptions(),
            provenance=ProvenanceOptions(include_provenance=include_provenance),
        ),
    )


def _response_with_provenance(
    fields: dict[str, FieldProvenance],
    texts: list[str] | None = None,
) -> ResponseIX:
    """Build a response carrying provenance ``fields`` and an internal context.

    The context always uses ``BankStatementHeader`` as the use-case response
    class and ``texts`` (default: empty list) as the page texts.
    """
    resp = ResponseIX()
    resp.provenance = ProvenanceData(
        fields=fields,
        quality_metrics={},
        segment_count=10,
        granularity="line",
    )
    resp.context = _InternalContext(
        texts=texts or [],
        use_case_response=BankStatementHeader,
    )
    return resp


async def _run_step(
    fields: dict[str, FieldProvenance], texts: list[str]
) -> ProvenanceData:
    """Run ``ReliabilityStep.process`` and return the resulting provenance.

    The same ``texts`` feed both the request context and the response
    context (as separate list copies), mirroring how every test below sets
    up its inputs.
    """
    resp = _response_with_provenance(fields=fields, texts=list(texts))
    resp = await ReliabilityStep().process(_make_request(texts=list(texts)), resp)
    # Narrow the Optional once here so the tests can dereference safely.
    assert resp.provenance is not None
    return resp.provenance


class TestValidate:
    """``validate`` gates the step on the request flag and on provenance data."""

    async def test_skipped_when_provenance_off(self) -> None:
        step = ReliabilityStep()
        req = _make_request(include_provenance=False)
        resp = _response_with_provenance(fields={})
        assert await step.validate(req, resp) is False

    async def test_skipped_when_no_provenance_data(self) -> None:
        step = ReliabilityStep()
        req = _make_request(include_provenance=True)
        resp = ResponseIX()
        assert await step.validate(req, resp) is False

    async def test_runs_when_provenance_data_present(self) -> None:
        step = ReliabilityStep()
        req = _make_request(include_provenance=True)
        resp = _response_with_provenance(fields={})
        assert await step.validate(req, resp) is True


class TestProcessFlags:
    """Per-field-type expectations for the two reliability flags."""

    async def test_string_field_verified_and_text_agreement(self) -> None:
        fp = FieldProvenance(
            field_name="bank_name",
            field_path="result.bank_name",
            value="DKB",
            sources=[_src("p1_l0", "DKB")],
        )
        provenance = await _run_step(
            {"result.bank_name": fp}, ["DKB statement content"]
        )
        out = provenance.fields["result.bank_name"]
        assert out.provenance_verified is True
        assert out.text_agreement is True

    async def test_literal_field_flags_none(self) -> None:
        fp = FieldProvenance(
            field_name="account_type",
            field_path="result.account_type",
            value="checking",
            sources=[_src("p1_l0", "anything")],
        )
        provenance = await _run_step({"result.account_type": fp}, ["some text"])
        out = provenance.fields["result.account_type"]
        assert out.provenance_verified is None
        assert out.text_agreement is None

    async def test_none_value_flags_none(self) -> None:
        fp = FieldProvenance(
            field_name="account_iban",
            field_path="result.account_iban",
            value=None,
            sources=[_src("p1_l0", "IBAN blah")],
        )
        provenance = await _run_step({"result.account_iban": fp}, ["text"])
        out = provenance.fields["result.account_iban"]
        assert out.provenance_verified is None
        assert out.text_agreement is None

    async def test_short_value_text_agreement_skipped(self) -> None:
        # Opening balance value < 10 → short numeric skip rule.
        fp = FieldProvenance(
            field_name="opening_balance",
            field_path="result.opening_balance",
            value=Decimal("5.00"),
            sources=[_src("p1_l0", "balance 5.00")],
        )
        provenance = await _run_step(
            {"result.opening_balance": fp}, ["balance 5.00"]
        )
        out = provenance.fields["result.opening_balance"]
        assert out.provenance_verified is True  # bbox cite still runs
        assert out.text_agreement is None  # short-value skip

    async def test_date_field_parses_both_sides(self) -> None:
        # The cited snippet and the page text spell the same date in
        # different formats (31.03.2026 vs 2026-03-31).
        fp = FieldProvenance(
            field_name="statement_date",
            field_path="result.statement_date",
            value=date(2026, 3, 31),
            sources=[_src("p1_l0", "Statement date 31.03.2026")],
        )
        provenance = await _run_step(
            {"result.statement_date": fp}, ["Statement date 2026-03-31"]
        )
        out = provenance.fields["result.statement_date"]
        assert out.provenance_verified is True
        assert out.text_agreement is True

    async def test_iban_field_whitespace_ignored(self) -> None:
        fp = FieldProvenance(
            field_name="account_iban",
            field_path="result.account_iban",
            value="DE89370400440532013000",
            sources=[_src("p1_l0", "IBAN DE89 3704 0044 0532 0130 00")],
        )
        provenance = await _run_step(
            {"result.account_iban": fp}, ["IBAN DE89 3704 0044 0532 0130 00"]
        )
        out = provenance.fields["result.account_iban"]
        assert out.provenance_verified is True
        assert out.text_agreement is True

    async def test_disagreeing_snippet_sets_false(self) -> None:
        fp = FieldProvenance(
            field_name="bank_name",
            field_path="result.bank_name",
            value="DKB",
            sources=[_src("p1_l0", "Commerzbank")],
        )
        provenance = await _run_step(
            {"result.bank_name": fp}, ["Commerzbank header"]
        )
        out = provenance.fields["result.bank_name"]
        assert out.provenance_verified is False
        assert out.text_agreement is False


class TestCounters:
    """Summary counters written to ``quality_metrics``."""

    async def test_quality_metrics_counters_written(self) -> None:
        fp_ok = FieldProvenance(
            field_name="bank_name",
            field_path="result.bank_name",
            value="DKB",
            sources=[_src("p1_l0", "DKB")],
        )
        fp_bad = FieldProvenance(
            field_name="currency",
            field_path="result.currency",
            value="EUR",
            sources=[_src("p1_l1", "nothing to see")],
        )
        fp_literal = FieldProvenance(
            field_name="account_type",
            field_path="result.account_type",
            value="checking",
            sources=[_src("p1_l2", "anything")],
        )
        provenance = await _run_step(
            {
                "result.bank_name": fp_ok,
                "result.currency": fp_bad,
                "result.account_type": fp_literal,
            },
            ["DKB statement"],
        )

        qm = provenance.quality_metrics
        # Only bank_name is expected to count towards either counter: it is
        # the one field that is both verified and in text agreement.
        assert qm["verified_fields"] == 1
        assert qm["text_agreement_fields"] == 1