infoxtractor/tests/unit/test_genai_step.py
Dirk Riemann abee9cea7b
All checks were successful
tests / test (push) Successful in 1m14s
tests / test (pull_request) Successful in 1m10s
feat(pipeline): GenAIStep — LLM call + provenance mapping (spec §6.3, §7, §9.2)
Assembles the prompt, picks the structured-output schema, calls the
injected GenAIClient, and maps any emitted segment_citations into
response.provenance. Reliability flags stay None here; ReliabilityStep
fills them in Task 2.7.

- System prompt = use_case.system_prompt plus, when provenance is on, the
  verbatim citation instruction from spec §9.2.
- User text = SegmentIndex.to_prompt_text() output ([p1_l0]-style segment
  tags) when provenance is on; otherwise the plain OCR flat text and the
  texts, joined.
- Response schema = UseCaseResponse directly, or a runtime
  create_model("ProvenanceWrappedResponse", result=(UCR, ...),
  segment_citations=(list[SegmentCitation], Field(default_factory=list)))
  when provenance is on.
- Model = request override -> use-case default.
- Failure modes: httpx / connection / timeout errors -> IX_002_000;
  pydantic.ValidationError -> IX_002_001.
- Writes ix_result.result + ix_result.meta_data (model_name +
  token_usage); builds response.provenance via
  map_segment_refs_to_provenance when provenance is on.

17 unit tests in tests/unit/test_genai_step.py cover validate
(ocr_only skip, empty -> IX_001_000, text-only, ocr-text path), process
happy path, system-prompt shape with/without citation instruction, user
text tagged vs. plain, response schema plain vs. wrapped, provenance
mapping, error mapping (IX_002_000 + IX_002_001), and model selection
(request override + use-case default).

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
2026-04-18 11:18:44 +02:00

378 lines
14 KiB
Python

"""Tests for :class:`ix.pipeline.genai_step.GenAIStep` (spec §6.3, §7, §9.2)."""
from __future__ import annotations
from typing import Any
import httpx
import pytest
from pydantic import BaseModel, ValidationError
from ix.contracts import (
Context,
GenAIOptions,
Line,
OCRDetails,
OCROptions,
OCRResult,
Options,
Page,
ProvenanceData,
ProvenanceOptions,
RequestIX,
ResponseIX,
SegmentCitation,
)
from ix.contracts.response import _InternalContext
from ix.errors import IXErrorCode, IXException
from ix.genai import FakeGenAIClient, GenAIInvocationResult, GenAIUsage
from ix.pipeline.genai_step import GenAIStep
from ix.segmentation import PageMetadata, SegmentIndex
from ix.use_cases.bank_statement_header import BankStatementHeader
from ix.use_cases.bank_statement_header import Request as BankReq
def _make_request(
    *,
    use_ocr: bool = True,
    ocr_only: bool = False,
    include_provenance: bool = True,
    model_name: str | None = None,
) -> RequestIX:
    """Build a ``bank_statement_header`` request with an empty context.

    Only the option flags vary between tests: ``use_ocr`` / ``ocr_only`` feed
    the OCR options, ``include_provenance`` feeds the provenance options
    (``max_sources_per_field`` is pinned to 5), and ``model_name`` is passed
    through as ``gen_ai_model_name``.
    """
    empty_context = Context(files=[], texts=[])
    ocr_options = OCROptions(use_ocr=use_ocr, ocr_only=ocr_only)
    genai_options = GenAIOptions(gen_ai_model_name=model_name)
    provenance_options = ProvenanceOptions(
        include_provenance=include_provenance,
        max_sources_per_field=5,
    )
    options = Options(
        ocr=ocr_options,
        gen_ai=genai_options,
        provenance=provenance_options,
    )
    return RequestIX(
        use_case="bank_statement_header",
        ix_client_id="test",
        request_id="r-1",
        context=empty_context,
        options=options,
    )
def _ocr_with_lines(lines: list[str]) -> OCRResult:
    """Wrap ``lines`` in a single-page OCR result.

    The flat text is the lines joined with newlines. Each line gets an
    eight-number bounding box whose vertical values are ``i * 10`` and
    ``i * 10 + 5`` for the i-th line (presumably x/y corner pairs; confirm
    against the ``Line`` contract).
    """
    ocr_lines = []
    for idx, content in enumerate(lines):
        top = idx * 10
        bottom = idx * 10 + 5
        ocr_lines.append(
            Line(
                text=content,
                bounding_box=[0, top, 10, top, 10, bottom, 0, bottom],
            )
        )
    page = Page(page_no=1, width=100.0, height=200.0, lines=ocr_lines)
    details = OCRDetails(text="\n".join(lines), pages=[page])
    return OCRResult(result=details)
def _response_with_segment_index(
    lines: list[str], texts: list[str] | None = None
) -> ResponseIX:
    """Return a response whose internal context is ready for ``GenAIStep``.

    The OCR result is built from ``lines``; a line-granularity
    ``SegmentIndex`` is built from that OCR result and stored on the context
    together with the bank-statement use-case request/response types, the
    OCR pages, and ``texts`` (an empty list when ``texts`` is falsy).
    """
    ocr_result = _ocr_with_lines(lines)
    response = ResponseIX(ocr_result=ocr_result)
    segment_index = SegmentIndex.build(
        ocr_result=ocr_result,
        granularity="line",
        pages_metadata=[PageMetadata(file_index=0)],
    )
    context_texts = texts if texts else []
    response.context = _InternalContext(
        use_case_request=BankReq(),
        use_case_response=BankStatementHeader,
        segment_index=segment_index,
        texts=context_texts,
        pages=ocr_result.result.pages,
        page_metadata=[PageMetadata(file_index=0)],
    )
    return response
class CapturingClient:
    """Test double that remembers the arguments of the last ``invoke()`` call.

    ``request_kwargs`` and ``response_schema`` stay ``None`` until ``invoke``
    has been awaited; afterwards they hold exactly what the caller passed.
    """

    def __init__(self, parsed: Any) -> None:
        # Nothing captured yet.
        self.request_kwargs: dict[str, Any] | None = None
        self.response_schema: type[BaseModel] | None = None
        # Canned payload handed back as ``parsed`` on every invocation.
        self._parsed = parsed

    async def invoke(
        self,
        request_kwargs: dict[str, Any],
        response_schema: type[BaseModel],
    ) -> GenAIInvocationResult:
        """Record both arguments and return a fixed invocation result.

        The result carries the canned ``parsed`` value, a usage of 5 prompt /
        7 completion tokens, and the model name ``"captured-model"``.
        """
        self.request_kwargs, self.response_schema = request_kwargs, response_schema
        fixed_usage = GenAIUsage(prompt_tokens=5, completion_tokens=7)
        return GenAIInvocationResult(
            parsed=self._parsed,
            usage=fixed_usage,
            model_name="captured-model",
        )
class TestValidate:
    """``GenAIStep.validate``: when the step runs, skips, or rejects the input."""

    @staticmethod
    def _new_step() -> GenAIStep:
        """Return a step wired to a fake client with a placeholder payload."""
        placeholder = BankStatementHeader(bank_name="x", currency="EUR")
        return GenAIStep(genai_client=FakeGenAIClient(parsed=placeholder))

    @staticmethod
    def _response_without_ocr_text(texts: list[str]) -> ResponseIX:
        """Return a response with empty OCR text and the given context texts."""
        response = ResponseIX(ocr_result=OCRResult(result=OCRDetails(text="")))
        response.context = _InternalContext(
            use_case_request=BankReq(),
            use_case_response=BankStatementHeader,
            texts=texts,
        )
        return response

    async def test_ocr_only_skips(self) -> None:
        step = self._new_step()
        request = _make_request(ocr_only=True)
        response = _response_with_segment_index(lines=["hello"])

        should_run = await step.validate(request, response)

        assert should_run is False

    async def test_empty_context_raises_IX_001_000(self) -> None:
        step = self._new_step()
        request = _make_request()
        response = self._response_without_ocr_text([])

        with pytest.raises(IXException) as exc_info:
            await step.validate(request, response)

        assert exc_info.value.code is IXErrorCode.IX_001_000

    async def test_runs_when_texts_only(self) -> None:
        step = self._new_step()
        request = _make_request()
        response = self._response_without_ocr_text(["some paperless text"])

        should_run = await step.validate(request, response)

        assert should_run is True

    async def test_runs_when_ocr_text_present(self) -> None:
        step = self._new_step()
        request = _make_request()
        response = _response_with_segment_index(lines=["hello"])

        should_run = await step.validate(request, response)

        assert should_run is True
class TestProcessBasic:
    """``GenAIStep.process`` happy path with provenance switched off."""

    async def test_writes_ix_result_and_meta(self) -> None:
        header = BankStatementHeader(bank_name="DKB", currency="EUR")
        step = GenAIStep(genai_client=CapturingClient(parsed=header))
        request = _make_request(include_provenance=False)
        response = _response_with_segment_index(lines=["hello"])

        response = await step.process(request, response)

        # The parsed payload is exposed as a mapping on ix_result.result.
        extracted = response.ix_result.result
        assert extracted["bank_name"] == "DKB"
        assert extracted["currency"] == "EUR"

        # Model name and token usage come from the CapturingClient result.
        meta = response.ix_result.meta_data
        assert meta["model_name"] == "captured-model"
        usage = meta["token_usage"]
        assert usage["prompt_tokens"] == 5
        assert usage["completion_tokens"] == 7
class TestSystemPromptAssembly:
    """System prompt content with and without the citation instruction."""

    @staticmethod
    async def _captured_system_prompt(parsed: Any, *, include_provenance: bool) -> Any:
        """Run ``process`` once and return the content of the first message.

        Asserts that the client was actually invoked before indexing into the
        captured kwargs, so a step that never calls the client fails with a
        readable assertion instead of a ``TypeError`` on ``None``.
        """
        client = CapturingClient(parsed=parsed)
        step = GenAIStep(genai_client=client)
        req = _make_request(include_provenance=include_provenance)
        resp = _response_with_segment_index(lines=["hello"])
        await step.process(req, resp)

        kwargs = client.request_kwargs
        assert kwargs is not None, "GenAIStep.process never invoked the client"
        # These tests treat messages[0] as the system message.
        return kwargs["messages"][0]["content"]

    async def test_citation_instruction_appended_when_provenance_on(self) -> None:
        parsed_wrapped: Any = _WrappedResponse(
            result=BankStatementHeader(bank_name="DKB", currency="EUR"),
            segment_citations=[],
        )
        system = await self._captured_system_prompt(
            parsed_wrapped, include_provenance=True
        )
        # Use-case system prompt is always there.
        assert "extract header metadata" in system
        # Citation instruction added.
        assert "segment_citations" in system
        assert "value_segment_ids" in system

    async def test_citation_instruction_absent_when_provenance_off(self) -> None:
        parsed = BankStatementHeader(bank_name="DKB", currency="EUR")
        system = await self._captured_system_prompt(parsed, include_provenance=False)
        assert "segment_citations" not in system
class TestUserTextFormat:
    """User message text: segment-tagged with provenance on, plain otherwise."""

    @staticmethod
    async def _captured_user_text(parsed: Any, *, include_provenance: bool) -> Any:
        """Run ``process`` on a two-line document and return the second message's content.

        Asserts that the client was actually invoked before indexing into the
        captured kwargs, so a step that never calls the client fails with a
        readable assertion instead of a ``TypeError`` on ``None``.
        """
        client = CapturingClient(parsed=parsed)
        step = GenAIStep(genai_client=client)
        req = _make_request(include_provenance=include_provenance)
        resp = _response_with_segment_index(lines=["alpha line", "beta line"])
        await step.process(req, resp)

        kwargs = client.request_kwargs
        assert kwargs is not None, "GenAIStep.process never invoked the client"
        # These tests treat messages[1] as the user message.
        return kwargs["messages"][1]["content"]

    async def test_tagged_prompt_when_provenance_on(self) -> None:
        parsed_wrapped: Any = _WrappedResponse(
            result=BankStatementHeader(bank_name="DKB", currency="EUR"),
            segment_citations=[],
        )
        user_content = await self._captured_user_text(
            parsed_wrapped, include_provenance=True
        )
        assert "[p1_l0] alpha line" in user_content
        assert "[p1_l1] beta line" in user_content

    async def test_plain_prompt_when_provenance_off(self) -> None:
        parsed = BankStatementHeader(bank_name="DKB", currency="EUR")
        user_content = await self._captured_user_text(parsed, include_provenance=False)
        assert "[p1_l0]" not in user_content
        assert "alpha line" in user_content
        assert "beta line" in user_content
class TestResponseSchemaChoice:
    """Which structured-output schema ``process`` hands to the client."""

    async def test_plain_schema_when_provenance_off(self) -> None:
        header = BankStatementHeader(bank_name="DKB", currency="EUR")
        client = CapturingClient(parsed=header)
        request = _make_request(include_provenance=False)
        response = _response_with_segment_index(lines=["hello"])

        await GenAIStep(genai_client=client).process(request, response)

        # Without provenance the use-case response model is passed as-is.
        assert client.response_schema is BankStatementHeader

    async def test_wrapped_schema_when_provenance_on(self) -> None:
        wrapped: Any = _WrappedResponse(
            result=BankStatementHeader(bank_name="DKB", currency="EUR"),
            segment_citations=[],
        )
        client = CapturingClient(parsed=wrapped)
        request = _make_request(include_provenance=True)
        response = _response_with_segment_index(lines=["hello"])

        await GenAIStep(genai_client=client).process(request, response)

        captured_schema = client.response_schema
        assert captured_schema is not None
        # The wrapper exposes exactly the payload plus the citation list.
        assert set(captured_schema.model_fields) == {"result", "segment_citations"}
class TestProvenanceMapping:
    """Citations emitted by the model end up in ``response.provenance``."""

    async def test_provenance_populated_from_citations(self) -> None:
        bank_name_citation = SegmentCitation(
            field_path="result.bank_name",
            value_segment_ids=["p1_l0"],
            context_segment_ids=[],
        )
        wrapped: Any = _WrappedResponse(
            result=BankStatementHeader(bank_name="DKB", currency="EUR"),
            segment_citations=[bank_name_citation],
        )
        step = GenAIStep(genai_client=CapturingClient(parsed=wrapped))
        request = _make_request(include_provenance=True)
        response = _response_with_segment_index(lines=["DKB"])

        response = await step.process(request, response)

        provenance = response.provenance
        assert isinstance(provenance, ProvenanceData)
        assert "result.bank_name" in provenance.fields

        field_provenance = provenance.fields["result.bank_name"]
        assert field_provenance.value == "DKB"
        assert len(field_provenance.sources) == 1
        assert field_provenance.sources[0].segment_id == "p1_l0"
        # Reliability flags are NOT set here — ReliabilityStep does that.
        assert field_provenance.provenance_verified is None
        assert field_provenance.text_agreement is None
class TestErrorHandling:
    """Client failures are translated into ``IXException`` error codes."""

    @staticmethod
    async def _assert_maps_to(failure: Exception, expected: IXErrorCode) -> None:
        """Make the fake client raise ``failure`` and check the resulting code.

        ``process`` must raise ``IXException`` whose ``code`` is ``expected``.
        """
        client = FakeGenAIClient(
            parsed=BankStatementHeader(bank_name="x", currency="EUR"),
            raise_on_call=failure,
        )
        step = GenAIStep(genai_client=client)
        req = _make_request(include_provenance=False)
        resp = _response_with_segment_index(lines=["hello"])

        with pytest.raises(IXException) as ei:
            await step.process(req, resp)

        assert ei.value.code is expected

    async def test_network_error_maps_to_IX_002_000(self) -> None:
        await self._assert_maps_to(
            httpx.ConnectError("refused"), IXErrorCode.IX_002_000
        )

    async def test_timeout_maps_to_IX_002_000(self) -> None:
        await self._assert_maps_to(httpx.ReadTimeout("slow"), IXErrorCode.IX_002_000)

    async def test_validation_error_maps_to_IX_002_001(self) -> None:
        class _M(BaseModel):
            x: int

        # Obtain a genuine pydantic ValidationError. pytest.raises guarantees
        # the error object is bound (the test fails right here if validation
        # unexpectedly succeeds), and model_validate needs no type: ignore.
        with pytest.raises(ValidationError) as captured:
            _M.model_validate({"x": "not-an-int"})

        await self._assert_maps_to(captured.value, IXErrorCode.IX_002_001)
class TestModelSelection:
    """Model name sent to the client: request override first, then use-case default."""

    @staticmethod
    async def _captured_model(model_name: str | None) -> Any:
        """Run ``process`` once and return the ``model`` entry of the request kwargs.

        ``model_name`` is forwarded to ``_make_request``; ``None`` leaves the
        request without an override. Asserts that the client was actually
        invoked before indexing into the captured kwargs, so a step that never
        calls the client fails with a readable assertion instead of a
        ``TypeError`` on ``None``.
        """
        parsed = BankStatementHeader(bank_name="DKB", currency="EUR")
        client = CapturingClient(parsed=parsed)
        step = GenAIStep(genai_client=client)
        req = _make_request(include_provenance=False, model_name=model_name)
        resp = _response_with_segment_index(lines=["hello"])
        await step.process(req, resp)

        kwargs = client.request_kwargs
        assert kwargs is not None, "GenAIStep.process never invoked the client"
        return kwargs["model"]

    async def test_request_model_override_wins(self) -> None:
        assert await self._captured_model("explicit-model") == "explicit-model"

    async def test_falls_back_to_use_case_default(self) -> None:
        # use-case default is gpt-oss:20b
        assert await self._captured_model(None) == "gpt-oss:20b"
# ----------------------------------------------------------------------------
# Helpers
class _WrappedResponse(BaseModel):
    """Hand-written double for the provenance wrapper model.

    Mirrors the two fields the tests expect on the wrapped schema
    (``result`` and ``segment_citations``) so a canned parsed payload can be
    fed to the client under test.
    """

    # Use-case payload.
    result: BankStatementHeader
    # Citations emitted alongside the payload; empty by default.
    segment_citations: list[SegmentCitation] = []