"""Tests for :class:`ix.pipeline.genai_step.GenAIStep` (spec §6.3, §7, §9.2).""" from __future__ import annotations from typing import Any import httpx import pytest from pydantic import BaseModel, ValidationError from ix.contracts import ( Context, GenAIOptions, Line, OCRDetails, OCROptions, OCRResult, Options, Page, ProvenanceData, ProvenanceOptions, RequestIX, ResponseIX, SegmentCitation, ) from ix.contracts.response import _InternalContext from ix.errors import IXErrorCode, IXException from ix.genai import FakeGenAIClient, GenAIInvocationResult, GenAIUsage from ix.pipeline.genai_step import GenAIStep from ix.segmentation import PageMetadata, SegmentIndex from ix.use_cases.bank_statement_header import BankStatementHeader from ix.use_cases.bank_statement_header import Request as BankReq def _make_request( *, use_ocr: bool = True, ocr_only: bool = False, include_provenance: bool = True, model_name: str | None = None, ) -> RequestIX: return RequestIX( use_case="bank_statement_header", ix_client_id="test", request_id="r-1", context=Context(files=[], texts=[]), options=Options( ocr=OCROptions(use_ocr=use_ocr, ocr_only=ocr_only), gen_ai=GenAIOptions(gen_ai_model_name=model_name), provenance=ProvenanceOptions( include_provenance=include_provenance, max_sources_per_field=5, ), ), ) def _ocr_with_lines(lines: list[str]) -> OCRResult: return OCRResult( result=OCRDetails( text="\n".join(lines), pages=[ Page( page_no=1, width=100.0, height=200.0, lines=[ Line(text=t, bounding_box=[0, i * 10, 10, i * 10, 10, i * 10 + 5, 0, i * 10 + 5]) for i, t in enumerate(lines) ], ) ], ) ) def _response_with_segment_index( lines: list[str], texts: list[str] | None = None ) -> ResponseIX: ocr = _ocr_with_lines(lines) resp = ResponseIX(ocr_result=ocr) seg_idx = SegmentIndex.build( ocr_result=ocr, granularity="line", pages_metadata=[PageMetadata(file_index=0)], ) resp.context = _InternalContext( use_case_request=BankReq(), use_case_response=BankStatementHeader, segment_index=seg_idx, texts=texts or [], pages=ocr.result.pages, page_metadata=[PageMetadata(file_index=0)], ) return resp class CapturingClient: """Records the request_kwargs + response_schema handed to invoke().""" def __init__(self, parsed: Any) -> None: self._parsed = parsed self.request_kwargs: dict[str, Any] | None = None self.response_schema: type[BaseModel] | None = None async def invoke( self, request_kwargs: dict[str, Any], response_schema: type[BaseModel], ) -> GenAIInvocationResult: self.request_kwargs = request_kwargs self.response_schema = response_schema return GenAIInvocationResult( parsed=self._parsed, usage=GenAIUsage(prompt_tokens=5, completion_tokens=7), model_name="captured-model", ) class TestValidate: async def test_ocr_only_skips(self) -> None: step = GenAIStep( genai_client=FakeGenAIClient(parsed=BankStatementHeader(bank_name="x", currency="EUR")) ) req = _make_request(ocr_only=True) resp = _response_with_segment_index(lines=["hello"]) assert await step.validate(req, resp) is False async def test_empty_context_raises_IX_001_000(self) -> None: step = GenAIStep( genai_client=FakeGenAIClient(parsed=BankStatementHeader(bank_name="x", currency="EUR")) ) req = _make_request() resp = ResponseIX(ocr_result=OCRResult(result=OCRDetails(text=""))) resp.context = _InternalContext( use_case_request=BankReq(), use_case_response=BankStatementHeader, texts=[], ) with pytest.raises(IXException) as ei: await step.validate(req, resp) assert ei.value.code is IXErrorCode.IX_001_000 async def test_runs_when_texts_only(self) -> None: step = GenAIStep( genai_client=FakeGenAIClient(parsed=BankStatementHeader(bank_name="x", currency="EUR")) ) req = _make_request() resp = ResponseIX(ocr_result=OCRResult(result=OCRDetails(text=""))) resp.context = _InternalContext( use_case_request=BankReq(), use_case_response=BankStatementHeader, texts=["some paperless text"], ) assert await step.validate(req, resp) is True async def test_runs_when_ocr_text_present(self) -> None: step = GenAIStep( genai_client=FakeGenAIClient(parsed=BankStatementHeader(bank_name="x", currency="EUR")) ) req = _make_request() resp = _response_with_segment_index(lines=["hello"]) assert await step.validate(req, resp) is True class TestProcessBasic: async def test_writes_ix_result_and_meta(self) -> None: parsed = BankStatementHeader(bank_name="DKB", currency="EUR") client = CapturingClient(parsed=parsed) step = GenAIStep(genai_client=client) req = _make_request(include_provenance=False) resp = _response_with_segment_index(lines=["hello"]) resp = await step.process(req, resp) assert resp.ix_result.result["bank_name"] == "DKB" assert resp.ix_result.result["currency"] == "EUR" assert resp.ix_result.meta_data["model_name"] == "captured-model" assert resp.ix_result.meta_data["token_usage"]["prompt_tokens"] == 5 assert resp.ix_result.meta_data["token_usage"]["completion_tokens"] == 7 class TestSystemPromptAssembly: async def test_citation_instruction_appended_when_provenance_on(self) -> None: parsed_wrapped: Any = _WrappedResponse( result=BankStatementHeader(bank_name="DKB", currency="EUR"), segment_citations=[], ) client = CapturingClient(parsed=parsed_wrapped) step = GenAIStep(genai_client=client) req = _make_request(include_provenance=True) resp = _response_with_segment_index(lines=["hello"]) await step.process(req, resp) messages = client.request_kwargs["messages"] # type: ignore[index] system = messages[0]["content"] # Use-case system prompt is always there. assert "extract header metadata" in system # Citation instruction added. assert "segment_citations" in system assert "value_segment_ids" in system async def test_citation_instruction_absent_when_provenance_off(self) -> None: parsed = BankStatementHeader(bank_name="DKB", currency="EUR") client = CapturingClient(parsed=parsed) step = GenAIStep(genai_client=client) req = _make_request(include_provenance=False) resp = _response_with_segment_index(lines=["hello"]) await step.process(req, resp) messages = client.request_kwargs["messages"] # type: ignore[index] system = messages[0]["content"] assert "segment_citations" not in system class TestUserTextFormat: async def test_tagged_prompt_when_provenance_on(self) -> None: parsed_wrapped: Any = _WrappedResponse( result=BankStatementHeader(bank_name="DKB", currency="EUR"), segment_citations=[], ) client = CapturingClient(parsed=parsed_wrapped) step = GenAIStep(genai_client=client) req = _make_request(include_provenance=True) resp = _response_with_segment_index(lines=["alpha line", "beta line"]) await step.process(req, resp) user_content = client.request_kwargs["messages"][1]["content"] # type: ignore[index] assert "[p1_l0] alpha line" in user_content assert "[p1_l1] beta line" in user_content async def test_plain_prompt_when_provenance_off(self) -> None: parsed = BankStatementHeader(bank_name="DKB", currency="EUR") client = CapturingClient(parsed=parsed) step = GenAIStep(genai_client=client) req = _make_request(include_provenance=False) resp = _response_with_segment_index(lines=["alpha line", "beta line"]) await step.process(req, resp) user_content = client.request_kwargs["messages"][1]["content"] # type: ignore[index] assert "[p1_l0]" not in user_content assert "alpha line" in user_content assert "beta line" in user_content class TestResponseSchemaChoice: async def test_plain_schema_when_provenance_off(self) -> None: parsed = BankStatementHeader(bank_name="DKB", currency="EUR") client = CapturingClient(parsed=parsed) step = GenAIStep(genai_client=client) req = _make_request(include_provenance=False) resp = _response_with_segment_index(lines=["hello"]) await step.process(req, resp) assert client.response_schema is BankStatementHeader async def test_wrapped_schema_when_provenance_on(self) -> None: parsed_wrapped: Any = _WrappedResponse( result=BankStatementHeader(bank_name="DKB", currency="EUR"), segment_citations=[], ) client = CapturingClient(parsed=parsed_wrapped) step = GenAIStep(genai_client=client) req = _make_request(include_provenance=True) resp = _response_with_segment_index(lines=["hello"]) await step.process(req, resp) schema = client.response_schema assert schema is not None field_names = set(schema.model_fields.keys()) assert field_names == {"result", "segment_citations"} class TestProvenanceMapping: async def test_provenance_populated_from_citations(self) -> None: parsed_wrapped: Any = _WrappedResponse( result=BankStatementHeader(bank_name="DKB", currency="EUR"), segment_citations=[ SegmentCitation( field_path="result.bank_name", value_segment_ids=["p1_l0"], context_segment_ids=[], ), ], ) client = CapturingClient(parsed=parsed_wrapped) step = GenAIStep(genai_client=client) req = _make_request(include_provenance=True) resp = _response_with_segment_index(lines=["DKB"]) resp = await step.process(req, resp) assert isinstance(resp.provenance, ProvenanceData) fields = resp.provenance.fields assert "result.bank_name" in fields fp = fields["result.bank_name"] assert fp.value == "DKB" assert len(fp.sources) == 1 assert fp.sources[0].segment_id == "p1_l0" # Reliability flags are NOT set here — ReliabilityStep does that. assert fp.provenance_verified is None assert fp.text_agreement is None class TestErrorHandling: async def test_network_error_maps_to_IX_002_000(self) -> None: err = httpx.ConnectError("refused") client = FakeGenAIClient( parsed=BankStatementHeader(bank_name="x", currency="EUR"), raise_on_call=err, ) step = GenAIStep(genai_client=client) req = _make_request(include_provenance=False) resp = _response_with_segment_index(lines=["hello"]) with pytest.raises(IXException) as ei: await step.process(req, resp) assert ei.value.code is IXErrorCode.IX_002_000 async def test_timeout_maps_to_IX_002_000(self) -> None: err = httpx.ReadTimeout("slow") client = FakeGenAIClient( parsed=BankStatementHeader(bank_name="x", currency="EUR"), raise_on_call=err, ) step = GenAIStep(genai_client=client) req = _make_request(include_provenance=False) resp = _response_with_segment_index(lines=["hello"]) with pytest.raises(IXException) as ei: await step.process(req, resp) assert ei.value.code is IXErrorCode.IX_002_000 async def test_validation_error_maps_to_IX_002_001(self) -> None: class _M(BaseModel): x: int try: _M(x="not-an-int") # type: ignore[arg-type] except ValidationError as err: raise_err = err client = FakeGenAIClient( parsed=BankStatementHeader(bank_name="x", currency="EUR"), raise_on_call=raise_err, ) step = GenAIStep(genai_client=client) req = _make_request(include_provenance=False) resp = _response_with_segment_index(lines=["hello"]) with pytest.raises(IXException) as ei: await step.process(req, resp) assert ei.value.code is IXErrorCode.IX_002_001 class TestModelSelection: async def test_request_model_override_wins(self) -> None: parsed = BankStatementHeader(bank_name="DKB", currency="EUR") client = CapturingClient(parsed=parsed) step = GenAIStep(genai_client=client) req = _make_request(include_provenance=False, model_name="explicit-model") resp = _response_with_segment_index(lines=["hello"]) await step.process(req, resp) assert client.request_kwargs["model"] == "explicit-model" # type: ignore[index] async def test_falls_back_to_use_case_default(self) -> None: parsed = BankStatementHeader(bank_name="DKB", currency="EUR") client = CapturingClient(parsed=parsed) step = GenAIStep(genai_client=client) req = _make_request(include_provenance=False) resp = _response_with_segment_index(lines=["hello"]) await step.process(req, resp) # use-case default is qwen3:14b assert client.request_kwargs["model"] == "qwen3:14b" # type: ignore[index] # ---------------------------------------------------------------------------- # Helpers class _WrappedResponse(BaseModel): """Stand-in for the runtime-created ProvenanceWrappedResponse.""" result: BankStatementHeader segment_citations: list[SegmentCitation] = []