feat(pipeline): ReliabilityStep — writes reliability flags (spec §6)
Thin wrapper around ix.provenance.apply_reliability_flags. Validate skips entirely when include_provenance is off OR when no provenance data was built (text-only request, etc.). Process reads context.texts + context.use_case_response and lets the verifier mutate the FieldProvenance entries + fill quality_metrics counters in place. 11 unit tests in tests/unit/test_reliability_step.py cover: validate skips on flag off / missing provenance, runs otherwise; per-type flag behaviour (string verified + text_agreement, Literal -> None, None value -> None, short numeric -> text_agreement None, date with both sides parsed, IBAN whitespace-insensitive, disagreement -> False); quality_metrics verified_fields / text_agreement_fields counters. Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
This commit is contained in:
parent
6d9c239e82
commit
132f110463
2 changed files with 306 additions and 0 deletions
56
src/ix/pipeline/reliability_step.py
Normal file
56
src/ix/pipeline/reliability_step.py
Normal file
|
|
@ -0,0 +1,56 @@
|
|||
"""ReliabilityStep — writes provenance_verified + text_agreement (spec §6).
|
||||
|
||||
Runs after :class:`~ix.pipeline.genai_step.GenAIStep`. Skips entirely
|
||||
when provenance is off OR when no provenance data was built (OCR-skipped
|
||||
text-only request, for example). Otherwise delegates to
|
||||
:func:`~ix.provenance.apply_reliability_flags`, which mutates each
|
||||
:class:`~ix.contracts.FieldProvenance` in place and fills the two
|
||||
summary counters (``verified_fields``, ``text_agreement_fields``) on
|
||||
``quality_metrics``.
|
||||
|
||||
No own dispatch logic — everything interesting lives in the normalisers
|
||||
+ verifier modules and is unit-tested there.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import cast
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
from ix.contracts import RequestIX, ResponseIX
|
||||
from ix.pipeline.step import Step
|
||||
from ix.provenance import apply_reliability_flags
|
||||
|
||||
|
||||
class ReliabilityStep(Step):
    """Fills per-field reliability flags on ``response.provenance``.

    The step itself only decides whether to run and collects the inputs;
    the flagging is delegated to
    :func:`~ix.provenance.apply_reliability_flags`.
    """

    async def validate(self, request_ix: RequestIX, response_ix: ResponseIX) -> bool:
        """Return ``True`` only when there is provenance data to flag.

        The step is skipped when the request turned provenance off, or when
        the response carries no provenance payload.
        """
        if not request_ix.options.provenance.include_provenance:
            return False
        return response_ix.provenance is not None

    async def process(
        self, request_ix: RequestIX, response_ix: ResponseIX
    ) -> ResponseIX:
        """Run the reliability verifier over ``response_ix.provenance``.

        Reads ``texts`` and ``use_case_response`` from ``response_ix.context``
        and hands them, together with the provenance payload, to
        ``apply_reliability_flags``.

        Returns:
            The same ``response_ix`` object that was passed in.

        Raises:
            RuntimeError: if ``response_ix.provenance`` is ``None``, i.e. the
                step was invoked although ``validate()`` would have skipped it.
        """
        # Explicit raise rather than ``assert``: assertions are removed under
        # ``python -O``, which would let a ``None`` payload reach the verifier.
        if response_ix.provenance is None:
            raise RuntimeError(
                "ReliabilityStep.process() requires response_ix.provenance; "
                "validate() must return True before process() is called"
            )

        ctx = response_ix.context
        # ``getattr`` with a default also covers ``ctx is None``, so no
        # separate branch is needed; a falsy ``texts`` becomes an empty list.
        texts: list[str] = list(getattr(ctx, "texts", None) or [])
        # NOTE(review): the cast only satisfies the type checker. If the
        # context is missing or has no ``use_case_response`` this is ``None``
        # at runtime — confirm ``apply_reliability_flags`` tolerates that.
        use_case_response_cls = cast(
            type[BaseModel],
            getattr(ctx, "use_case_response", None),
        )

        apply_reliability_flags(
            provenance_data=response_ix.provenance,
            use_case_response=use_case_response_cls,
            texts=texts,
        )
        return response_ix
|
||||
|
||||
|
||||
__all__ = ["ReliabilityStep"]
|
||||
250
tests/unit/test_reliability_step.py
Normal file
250
tests/unit/test_reliability_step.py
Normal file
|
|
@ -0,0 +1,250 @@
|
|||
"""Tests for :class:`ix.pipeline.reliability_step.ReliabilityStep` (spec §6)."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from datetime import date
|
||||
from decimal import Decimal
|
||||
|
||||
from ix.contracts import (
|
||||
BoundingBox,
|
||||
Context,
|
||||
ExtractionSource,
|
||||
FieldProvenance,
|
||||
OCROptions,
|
||||
Options,
|
||||
ProvenanceData,
|
||||
ProvenanceOptions,
|
||||
RequestIX,
|
||||
ResponseIX,
|
||||
)
|
||||
from ix.contracts.response import _InternalContext
|
||||
from ix.pipeline.reliability_step import ReliabilityStep
|
||||
from ix.use_cases.bank_statement_header import BankStatementHeader
|
||||
|
||||
|
||||
def _src(
    segment_id: str,
    text: str,
    page: int = 1,
    bbox: list[float] | None = None,
) -> ExtractionSource:
    """Build an ``ExtractionSource`` fixture citing ``text`` on ``page``.

    ``file_index`` is always 0 and ``relevance_score`` always 1.0. When
    ``bbox`` is missing or empty, a fixed eight-value coordinate list is used.
    """
    coordinates = bbox if bbox else [0, 0, 1, 0, 1, 1, 0, 1]
    box = BoundingBox(coordinates=coordinates)
    return ExtractionSource(
        segment_id=segment_id,
        text_snippet=text,
        page_number=page,
        file_index=0,
        relevance_score=1.0,
        bounding_box=box,
    )
|
||||
|
||||
|
||||
def _make_request(
    include_provenance: bool = True, texts: list[str] | None = None
) -> RequestIX:
    """Build a minimal ``bank_statement_header`` request for the tests.

    The request has no files; ``texts`` defaults to an empty list and
    ``include_provenance`` is forwarded to the provenance options.
    """
    request_context = Context(files=[], texts=texts if texts else [])
    provenance_options = ProvenanceOptions(include_provenance=include_provenance)
    request_options = Options(ocr=OCROptions(), provenance=provenance_options)
    return RequestIX(
        use_case="bank_statement_header",
        ix_client_id="test",
        request_id="r-1",
        context=request_context,
        options=request_options,
    )
|
||||
|
||||
|
||||
def _response_with_provenance(
    fields: dict[str, FieldProvenance],
    texts: list[str] | None = None,
) -> ResponseIX:
    """Build a ``ResponseIX`` that already carries provenance data.

    The provenance payload wraps ``fields`` with empty ``quality_metrics``;
    the internal context holds ``texts`` (empty list when omitted) and
    ``BankStatementHeader`` as the use-case response class.
    """
    provenance = ProvenanceData(
        fields=fields,
        quality_metrics={},
        segment_count=10,
        granularity="line",
    )
    internal_context = _InternalContext(
        texts=texts if texts else [],
        use_case_response=BankStatementHeader,
    )

    response = ResponseIX()
    response.provenance = provenance
    response.context = internal_context
    return response
|
||||
|
||||
|
||||
class TestValidate:
    """``ReliabilityStep.validate`` decides whether the step runs at all."""

    async def test_skipped_when_provenance_off(self) -> None:
        # Provenance data exists, but the request switched the feature off.
        request = _make_request(include_provenance=False)
        response = _response_with_provenance(fields={})
        should_run = await ReliabilityStep().validate(request, response)
        assert should_run is False

    async def test_skipped_when_no_provenance_data(self) -> None:
        # Feature is on, but a bare response has no provenance payload.
        request = _make_request(include_provenance=True)
        response = ResponseIX()
        should_run = await ReliabilityStep().validate(request, response)
        assert should_run is False

    async def test_runs_when_provenance_data_present(self) -> None:
        # Feature on and payload present, even with no fields in it.
        request = _make_request(include_provenance=True)
        response = _response_with_provenance(fields={})
        should_run = await ReliabilityStep().validate(request, response)
        assert should_run is True
|
||||
|
||||
|
||||
class TestProcessFlags:
    """Per-field-type outcomes of ``ReliabilityStep.process``."""

    @staticmethod
    async def _flags_for(fp: FieldProvenance, texts: list[str]) -> FieldProvenance:
        """Run the step on a single field and return its provenance entry.

        The field is registered under its own ``field_path``; equal ``texts``
        are supplied on both the request and the response context.
        """
        response = _response_with_provenance(
            fields={fp.field_path: fp},
            texts=texts,
        )
        request = _make_request(texts=list(texts))
        response = await ReliabilityStep().process(request, response)
        return response.provenance.fields[fp.field_path]

    async def test_string_field_verified_and_text_agreement(self) -> None:
        entry = FieldProvenance(
            field_name="bank_name",
            field_path="result.bank_name",
            value="DKB",
            sources=[_src("p1_l0", "DKB")],
        )
        out = await self._flags_for(entry, ["DKB statement content"])
        assert out.provenance_verified is True
        assert out.text_agreement is True

    async def test_literal_field_flags_none(self) -> None:
        entry = FieldProvenance(
            field_name="account_type",
            field_path="result.account_type",
            value="checking",
            sources=[_src("p1_l0", "anything")],
        )
        out = await self._flags_for(entry, ["some text"])
        assert out.provenance_verified is None
        assert out.text_agreement is None

    async def test_none_value_flags_none(self) -> None:
        entry = FieldProvenance(
            field_name="account_iban",
            field_path="result.account_iban",
            value=None,
            sources=[_src("p1_l0", "IBAN blah")],
        )
        out = await self._flags_for(entry, ["text"])
        assert out.provenance_verified is None
        assert out.text_agreement is None

    async def test_short_value_text_agreement_skipped(self) -> None:
        # Opening balance value < 10 → short numeric skip rule.
        entry = FieldProvenance(
            field_name="opening_balance",
            field_path="result.opening_balance",
            value=Decimal("5.00"),
            sources=[_src("p1_l0", "balance 5.00")],
        )
        out = await self._flags_for(entry, ["balance 5.00"])
        assert out.provenance_verified is True  # bbox cite still runs
        assert out.text_agreement is None  # short-value skip

    async def test_date_field_parses_both_sides(self) -> None:
        # The snippet and the request text spell the same date differently.
        entry = FieldProvenance(
            field_name="statement_date",
            field_path="result.statement_date",
            value=date(2026, 3, 31),
            sources=[_src("p1_l0", "Statement date 31.03.2026")],
        )
        out = await self._flags_for(entry, ["Statement date 2026-03-31"])
        assert out.provenance_verified is True
        assert out.text_agreement is True

    async def test_iban_field_whitespace_ignored(self) -> None:
        # The value has no spaces; snippet and text use grouped digits.
        entry = FieldProvenance(
            field_name="account_iban",
            field_path="result.account_iban",
            value="DE89370400440532013000",
            sources=[_src("p1_l0", "IBAN DE89 3704 0044 0532 0130 00")],
        )
        out = await self._flags_for(entry, ["IBAN DE89 3704 0044 0532 0130 00"])
        assert out.provenance_verified is True
        assert out.text_agreement is True

    async def test_disagreeing_snippet_sets_false(self) -> None:
        # Neither the cited snippet nor the text contains the extracted value.
        entry = FieldProvenance(
            field_name="bank_name",
            field_path="result.bank_name",
            value="DKB",
            sources=[_src("p1_l0", "Commerzbank")],
        )
        out = await self._flags_for(entry, ["Commerzbank header"])
        assert out.provenance_verified is False
        assert out.text_agreement is False
|
||||
|
||||
|
||||
class TestCounters:
    """Summary counters written to ``quality_metrics`` by ``process``."""

    async def test_quality_metrics_counters_written(self) -> None:
        # Three fields: one expected to pass both checks, one whose snippet
        # does not match, and one literal-typed field.
        matching = FieldProvenance(
            field_name="bank_name",
            field_path="result.bank_name",
            value="DKB",
            sources=[_src("p1_l0", "DKB")],
        )
        mismatching = FieldProvenance(
            field_name="currency",
            field_path="result.currency",
            value="EUR",
            sources=[_src("p1_l1", "nothing to see")],
        )
        literal = FieldProvenance(
            field_name="account_type",
            field_path="result.account_type",
            value="checking",
            sources=[_src("p1_l2", "anything")],
        )
        field_map = {
            "result.bank_name": matching,
            "result.currency": mismatching,
            "result.account_type": literal,
        }
        response = _response_with_provenance(
            fields=field_map,
            texts=["DKB statement"],
        )
        request = _make_request(texts=["DKB statement"])

        response = await ReliabilityStep().process(request, response)

        metrics = response.provenance.quality_metrics
        # Only bank_name counts towards either counter.
        assert metrics["verified_fields"] == 1
        assert metrics["text_agreement_fields"] == 1
|
||||
Loading…
Reference in a new issue