diff --git a/src/ix/pipeline/genai_step.py b/src/ix/pipeline/genai_step.py new file mode 100644 index 0000000..a5c061d --- /dev/null +++ b/src/ix/pipeline/genai_step.py @@ -0,0 +1,216 @@ +"""GenAIStep — assemble prompt, call LLM, map provenance (spec §6.3, §7, §9.2). + +Runs after :class:`~ix.pipeline.ocr_step.OCRStep`. Builds the chat-style +``request_kwargs`` (messages + model name), picks the structured-output +schema (plain ``UseCaseResponse`` or a runtime +``ProvenanceWrappedResponse(result=..., segment_citations=...)`` when +provenance is on), hands both to the injected :class:`GenAIClient`, and +writes the parsed payload onto ``response_ix.ix_result``. + +When provenance is on, the LLM-emitted ``segment_citations`` flow into +:func:`~ix.provenance.map_segment_refs_to_provenance` to build +``response_ix.provenance``. The per-field reliability flags +(``provenance_verified`` / ``text_agreement``) stay ``None`` here — they +land in :class:`~ix.pipeline.reliability_step.ReliabilityStep`. + +Failure modes: + +* Network / timeout / non-2xx surfaced by the client → ``IX_002_000``. +* :class:`pydantic.ValidationError` (structured output didn't match the + schema) → ``IX_002_001``. +""" + +from __future__ import annotations + +from typing import Any, cast + +import httpx +from pydantic import BaseModel, Field, ValidationError, create_model + +from ix.contracts import RequestIX, ResponseIX, SegmentCitation +from ix.errors import IXErrorCode, IXException +from ix.genai.client import GenAIClient +from ix.pipeline.step import Step +from ix.provenance import map_segment_refs_to_provenance +from ix.segmentation import SegmentIndex + +# Verbatim from spec §9.2 (core-pipeline spec) — inserted after the +# use-case system prompt when provenance is on. +_CITATION_INSTRUCTION = ( + "For each extracted field, you must also populate the `segment_citations` list.\n" + "Each entry maps one field to the document segments that were its source.\n" + "Set `field_path` to the dot-separated JSON path of the field " + "(e.g. 'result.invoice_number').\n" + "Use two separate segment ID lists:\n" + "- `value_segment_ids`: segment IDs whose text directly contains the extracted " + "value (e.g. ['p1_l4'] for the line containing 'INV-001').\n" + "- `context_segment_ids`: segment IDs for surrounding label or anchor text that " + "helped you identify the field but does not contain the value itself " + "(e.g. ['p1_l3'] for a label like 'Invoice Number:'). Leave empty if there is " + "no distinct label.\n" + "Only use segment IDs that appear in the document text.\n" + "Omit fields for which you cannot identify a source segment." +) + + +class GenAIStep(Step): + """LLM extraction + (optional) provenance mapping.""" + + def __init__(self, genai_client: GenAIClient) -> None: + self._client = genai_client + + async def validate(self, request_ix: RequestIX, response_ix: ResponseIX) -> bool: + if request_ix.options.ocr.ocr_only: + return False + + ctx = response_ix.context + ocr_text = ( + response_ix.ocr_result.result.text + if response_ix.ocr_result is not None + else None + ) + texts = list(getattr(ctx, "texts", []) or []) if ctx is not None else [] + + if not (ocr_text and ocr_text.strip()) and not texts: + raise IXException(IXErrorCode.IX_001_000) + return True + + async def process( + self, request_ix: RequestIX, response_ix: ResponseIX + ) -> ResponseIX: + ctx = response_ix.context + assert ctx is not None, "SetupStep must populate response_ix.context" + use_case_request: Any = getattr(ctx, "use_case_request", None) + use_case_response_cls: type[BaseModel] = getattr(ctx, "use_case_response", None) + assert use_case_request is not None and use_case_response_cls is not None + + opts = request_ix.options + provenance_on = opts.provenance.include_provenance + + # 1. System prompt — use-case default + optional citation instruction. + system_prompt = use_case_request.system_prompt + if provenance_on: + system_prompt = f"{system_prompt}\n\n{_CITATION_INSTRUCTION}" + + # 2. User text — segment-tagged when provenance is on, else plain OCR + texts. + user_text = self._build_user_text(response_ix, provenance_on) + + # 3. Response schema — plain or wrapped. + response_schema = self._resolve_response_schema( + use_case_response_cls, provenance_on + ) + + # 4. Model selection — request override → use-case default. + model_name = ( + opts.gen_ai.gen_ai_model_name + or getattr(use_case_request, "default_model", None) + ) + + request_kwargs = { + "model": model_name, + "messages": [ + {"role": "system", "content": system_prompt}, + {"role": "user", "content": user_text}, + ], + } + + # 5. Call the backend, translate errors. + try: + result = await self._client.invoke( + request_kwargs=request_kwargs, + response_schema=response_schema, + ) + except ValidationError as exc: + raise IXException( + IXErrorCode.IX_002_001, + detail=f"{use_case_response_cls.__name__}: {exc}", + ) from exc + except (httpx.HTTPError, ConnectionError, TimeoutError) as exc: + raise IXException( + IXErrorCode.IX_002_000, + detail=f"{model_name}: {exc.__class__.__name__}: {exc}", + ) from exc + except IXException: + raise + + # 6. Split parsed output; write result + meta. + if provenance_on: + wrapped = result.parsed + extraction: BaseModel = wrapped.result + segment_citations: list[SegmentCitation] = list( + getattr(wrapped, "segment_citations", []) or [] + ) + else: + extraction = result.parsed + segment_citations = [] + + response_ix.ix_result.result = extraction.model_dump(mode="json") + response_ix.ix_result.meta_data = { + "model_name": result.model_name, + "token_usage": { + "prompt_tokens": result.usage.prompt_tokens, + "completion_tokens": result.usage.completion_tokens, + }, + } + + # 7. Provenance mapping — only the structural assembly. Reliability + # flags get written in ReliabilityStep. + if provenance_on: + seg_idx = cast(SegmentIndex, getattr(ctx, "segment_index", None)) + if seg_idx is None: + # No OCR was run (text-only request); skip provenance. + response_ix.provenance = None + else: + response_ix.provenance = map_segment_refs_to_provenance( + extraction_result={"result": response_ix.ix_result.result}, + segment_citations=segment_citations, + segment_index=seg_idx, + max_sources_per_field=opts.provenance.max_sources_per_field, + min_confidence=0.0, + include_bounding_boxes=True, + source_type="value_and_context", + ) + + return response_ix + + def _build_user_text(self, response_ix: ResponseIX, provenance_on: bool) -> str: + ctx = response_ix.context + assert ctx is not None + texts: list[str] = list(getattr(ctx, "texts", []) or []) + seg_idx: SegmentIndex | None = getattr(ctx, "segment_index", None) + + if provenance_on and seg_idx is not None: + return seg_idx.to_prompt_text(context_texts=texts) + + # Plain concat — OCR flat text + any extra paperless-style texts. + parts: list[str] = [] + ocr_text = ( + response_ix.ocr_result.result.text + if response_ix.ocr_result is not None + else None + ) + if ocr_text: + parts.append(ocr_text) + parts.extend(texts) + return "\n\n".join(p for p in parts if p) + + def _resolve_response_schema( + self, + use_case_response_cls: type[BaseModel], + provenance_on: bool, + ) -> type[BaseModel]: + if not provenance_on: + return use_case_response_cls + # Dynamic wrapper — one per call is fine; Pydantic caches the + # generated JSON schema internally. + return create_model( + "ProvenanceWrappedResponse", + result=(use_case_response_cls, ...), + segment_citations=( + list[SegmentCitation], + Field(default_factory=list), + ), + ) + + +__all__ = ["GenAIStep"] diff --git a/tests/unit/test_genai_step.py b/tests/unit/test_genai_step.py new file mode 100644 index 0000000..ec959f2 --- /dev/null +++ b/tests/unit/test_genai_step.py @@ -0,0 +1,378 @@ +"""Tests for :class:`ix.pipeline.genai_step.GenAIStep` (spec §6.3, §7, §9.2).""" + +from __future__ import annotations + +from typing import Any + +import httpx +import pytest +from pydantic import BaseModel, ValidationError + +from ix.contracts import ( + Context, + GenAIOptions, + Line, + OCRDetails, + OCROptions, + OCRResult, + Options, + Page, + ProvenanceData, + ProvenanceOptions, + RequestIX, + ResponseIX, + SegmentCitation, +) +from ix.contracts.response import _InternalContext +from ix.errors import IXErrorCode, IXException +from ix.genai import FakeGenAIClient, GenAIInvocationResult, GenAIUsage +from ix.pipeline.genai_step import GenAIStep +from ix.segmentation import PageMetadata, SegmentIndex +from ix.use_cases.bank_statement_header import BankStatementHeader +from ix.use_cases.bank_statement_header import Request as BankReq + + +def _make_request( + *, + use_ocr: bool = True, + ocr_only: bool = False, + include_provenance: bool = True, + model_name: str | None = None, +) -> RequestIX: + return RequestIX( + use_case="bank_statement_header", + ix_client_id="test", + request_id="r-1", + context=Context(files=[], texts=[]), + options=Options( + ocr=OCROptions(use_ocr=use_ocr, ocr_only=ocr_only), + gen_ai=GenAIOptions(gen_ai_model_name=model_name), + provenance=ProvenanceOptions( + include_provenance=include_provenance, + max_sources_per_field=5, + ), + ), + ) + + +def _ocr_with_lines(lines: list[str]) -> OCRResult: + return OCRResult( + result=OCRDetails( + text="\n".join(lines), + pages=[ + Page( + page_no=1, + width=100.0, + height=200.0, + lines=[ + Line(text=t, bounding_box=[0, i * 10, 10, i * 10, 10, i * 10 + 5, 0, i * 10 + 5]) + for i, t in enumerate(lines) + ], + ) + ], + ) + ) + + +def _response_with_segment_index( + lines: list[str], texts: list[str] | None = None +) -> ResponseIX: + ocr = _ocr_with_lines(lines) + resp = ResponseIX(ocr_result=ocr) + seg_idx = SegmentIndex.build( + ocr_result=ocr, + granularity="line", + pages_metadata=[PageMetadata(file_index=0)], + ) + resp.context = _InternalContext( + use_case_request=BankReq(), + use_case_response=BankStatementHeader, + segment_index=seg_idx, + texts=texts or [], + pages=ocr.result.pages, + page_metadata=[PageMetadata(file_index=0)], + ) + return resp + + +class CapturingClient: + """Records the request_kwargs + response_schema handed to invoke().""" + + def __init__(self, parsed: Any) -> None: + self._parsed = parsed + self.request_kwargs: dict[str, Any] | None = None + self.response_schema: type[BaseModel] | None = None + + async def invoke( + self, + request_kwargs: dict[str, Any], + response_schema: type[BaseModel], + ) -> GenAIInvocationResult: + self.request_kwargs = request_kwargs + self.response_schema = response_schema + return GenAIInvocationResult( + parsed=self._parsed, + usage=GenAIUsage(prompt_tokens=5, completion_tokens=7), + model_name="captured-model", + ) + + +class TestValidate: + async def test_ocr_only_skips(self) -> None: + step = GenAIStep( + genai_client=FakeGenAIClient(parsed=BankStatementHeader(bank_name="x", currency="EUR")) + ) + req = _make_request(ocr_only=True) + resp = _response_with_segment_index(lines=["hello"]) + assert await step.validate(req, resp) is False + + async def test_empty_context_raises_IX_001_000(self) -> None: + step = GenAIStep( + genai_client=FakeGenAIClient(parsed=BankStatementHeader(bank_name="x", currency="EUR")) + ) + req = _make_request() + resp = ResponseIX(ocr_result=OCRResult(result=OCRDetails(text=""))) + resp.context = _InternalContext( + use_case_request=BankReq(), + use_case_response=BankStatementHeader, + texts=[], + ) + with pytest.raises(IXException) as ei: + await step.validate(req, resp) + assert ei.value.code is IXErrorCode.IX_001_000 + + async def test_runs_when_texts_only(self) -> None: + step = GenAIStep( + genai_client=FakeGenAIClient(parsed=BankStatementHeader(bank_name="x", currency="EUR")) + ) + req = _make_request() + resp = ResponseIX(ocr_result=OCRResult(result=OCRDetails(text=""))) + resp.context = _InternalContext( + use_case_request=BankReq(), + use_case_response=BankStatementHeader, + texts=["some paperless text"], + ) + assert await step.validate(req, resp) is True + + async def test_runs_when_ocr_text_present(self) -> None: + step = GenAIStep( + genai_client=FakeGenAIClient(parsed=BankStatementHeader(bank_name="x", currency="EUR")) + ) + req = _make_request() + resp = _response_with_segment_index(lines=["hello"]) + assert await step.validate(req, resp) is True + + +class TestProcessBasic: + async def test_writes_ix_result_and_meta(self) -> None: + parsed = BankStatementHeader(bank_name="DKB", currency="EUR") + client = CapturingClient(parsed=parsed) + step = GenAIStep(genai_client=client) + req = _make_request(include_provenance=False) + resp = _response_with_segment_index(lines=["hello"]) + resp = await step.process(req, resp) + + assert resp.ix_result.result["bank_name"] == "DKB" + assert resp.ix_result.result["currency"] == "EUR" + assert resp.ix_result.meta_data["model_name"] == "captured-model" + assert resp.ix_result.meta_data["token_usage"]["prompt_tokens"] == 5 + assert resp.ix_result.meta_data["token_usage"]["completion_tokens"] == 7 + + +class TestSystemPromptAssembly: + async def test_citation_instruction_appended_when_provenance_on(self) -> None: + parsed_wrapped: Any = _WrappedResponse( + result=BankStatementHeader(bank_name="DKB", currency="EUR"), + segment_citations=[], + ) + client = CapturingClient(parsed=parsed_wrapped) + step = GenAIStep(genai_client=client) + req = _make_request(include_provenance=True) + resp = _response_with_segment_index(lines=["hello"]) + await step.process(req, resp) + + messages = client.request_kwargs["messages"] # type: ignore[index] + system = messages[0]["content"] + # Use-case system prompt is always there. + assert "extract header metadata" in system + # Citation instruction added. + assert "segment_citations" in system + assert "value_segment_ids" in system + + async def test_citation_instruction_absent_when_provenance_off(self) -> None: + parsed = BankStatementHeader(bank_name="DKB", currency="EUR") + client = CapturingClient(parsed=parsed) + step = GenAIStep(genai_client=client) + req = _make_request(include_provenance=False) + resp = _response_with_segment_index(lines=["hello"]) + await step.process(req, resp) + + messages = client.request_kwargs["messages"] # type: ignore[index] + system = messages[0]["content"] + assert "segment_citations" not in system + + +class TestUserTextFormat: + async def test_tagged_prompt_when_provenance_on(self) -> None: + parsed_wrapped: Any = _WrappedResponse( + result=BankStatementHeader(bank_name="DKB", currency="EUR"), + segment_citations=[], + ) + client = CapturingClient(parsed=parsed_wrapped) + step = GenAIStep(genai_client=client) + req = _make_request(include_provenance=True) + resp = _response_with_segment_index(lines=["alpha line", "beta line"]) + await step.process(req, resp) + + user_content = client.request_kwargs["messages"][1]["content"] # type: ignore[index] + assert "[p1_l0] alpha line" in user_content + assert "[p1_l1] beta line" in user_content + + async def test_plain_prompt_when_provenance_off(self) -> None: + parsed = BankStatementHeader(bank_name="DKB", currency="EUR") + client = CapturingClient(parsed=parsed) + step = GenAIStep(genai_client=client) + req = _make_request(include_provenance=False) + resp = _response_with_segment_index(lines=["alpha line", "beta line"]) + await step.process(req, resp) + + user_content = client.request_kwargs["messages"][1]["content"] # type: ignore[index] + assert "[p1_l0]" not in user_content + assert "alpha line" in user_content + assert "beta line" in user_content + + +class TestResponseSchemaChoice: + async def test_plain_schema_when_provenance_off(self) -> None: + parsed = BankStatementHeader(bank_name="DKB", currency="EUR") + client = CapturingClient(parsed=parsed) + step = GenAIStep(genai_client=client) + req = _make_request(include_provenance=False) + resp = _response_with_segment_index(lines=["hello"]) + await step.process(req, resp) + assert client.response_schema is BankStatementHeader + + async def test_wrapped_schema_when_provenance_on(self) -> None: + parsed_wrapped: Any = _WrappedResponse( + result=BankStatementHeader(bank_name="DKB", currency="EUR"), + segment_citations=[], + ) + client = CapturingClient(parsed=parsed_wrapped) + step = GenAIStep(genai_client=client) + req = _make_request(include_provenance=True) + resp = _response_with_segment_index(lines=["hello"]) + await step.process(req, resp) + schema = client.response_schema + assert schema is not None + field_names = set(schema.model_fields.keys()) + assert field_names == {"result", "segment_citations"} + + +class TestProvenanceMapping: + async def test_provenance_populated_from_citations(self) -> None: + parsed_wrapped: Any = _WrappedResponse( + result=BankStatementHeader(bank_name="DKB", currency="EUR"), + segment_citations=[ + SegmentCitation( + field_path="result.bank_name", + value_segment_ids=["p1_l0"], + context_segment_ids=[], + ), + ], + ) + client = CapturingClient(parsed=parsed_wrapped) + step = GenAIStep(genai_client=client) + req = _make_request(include_provenance=True) + resp = _response_with_segment_index(lines=["DKB"]) + resp = await step.process(req, resp) + + assert isinstance(resp.provenance, ProvenanceData) + fields = resp.provenance.fields + assert "result.bank_name" in fields + fp = fields["result.bank_name"] + assert fp.value == "DKB" + assert len(fp.sources) == 1 + assert fp.sources[0].segment_id == "p1_l0" + # Reliability flags are NOT set here — ReliabilityStep does that. + assert fp.provenance_verified is None + assert fp.text_agreement is None + + +class TestErrorHandling: + async def test_network_error_maps_to_IX_002_000(self) -> None: + err = httpx.ConnectError("refused") + client = FakeGenAIClient( + parsed=BankStatementHeader(bank_name="x", currency="EUR"), + raise_on_call=err, + ) + step = GenAIStep(genai_client=client) + req = _make_request(include_provenance=False) + resp = _response_with_segment_index(lines=["hello"]) + with pytest.raises(IXException) as ei: + await step.process(req, resp) + assert ei.value.code is IXErrorCode.IX_002_000 + + async def test_timeout_maps_to_IX_002_000(self) -> None: + err = httpx.ReadTimeout("slow") + client = FakeGenAIClient( + parsed=BankStatementHeader(bank_name="x", currency="EUR"), + raise_on_call=err, + ) + step = GenAIStep(genai_client=client) + req = _make_request(include_provenance=False) + resp = _response_with_segment_index(lines=["hello"]) + with pytest.raises(IXException) as ei: + await step.process(req, resp) + assert ei.value.code is IXErrorCode.IX_002_000 + + async def test_validation_error_maps_to_IX_002_001(self) -> None: + class _M(BaseModel): + x: int + + try: + _M(x="not-an-int") # type: ignore[arg-type] + except ValidationError as err: + raise_err = err + + client = FakeGenAIClient( + parsed=BankStatementHeader(bank_name="x", currency="EUR"), + raise_on_call=raise_err, + ) + step = GenAIStep(genai_client=client) + req = _make_request(include_provenance=False) + resp = _response_with_segment_index(lines=["hello"]) + with pytest.raises(IXException) as ei: + await step.process(req, resp) + assert ei.value.code is IXErrorCode.IX_002_001 + + +class TestModelSelection: + async def test_request_model_override_wins(self) -> None: + parsed = BankStatementHeader(bank_name="DKB", currency="EUR") + client = CapturingClient(parsed=parsed) + step = GenAIStep(genai_client=client) + req = _make_request(include_provenance=False, model_name="explicit-model") + resp = _response_with_segment_index(lines=["hello"]) + await step.process(req, resp) + assert client.request_kwargs["model"] == "explicit-model" # type: ignore[index] + + async def test_falls_back_to_use_case_default(self) -> None: + parsed = BankStatementHeader(bank_name="DKB", currency="EUR") + client = CapturingClient(parsed=parsed) + step = GenAIStep(genai_client=client) + req = _make_request(include_provenance=False) + resp = _response_with_segment_index(lines=["hello"]) + await step.process(req, resp) + # use-case default is gpt-oss:20b + assert client.request_kwargs["model"] == "gpt-oss:20b" # type: ignore[index] + + +# ---------------------------------------------------------------------------- +# Helpers + + +class _WrappedResponse(BaseModel): + """Stand-in for the runtime-created ProvenanceWrappedResponse.""" + + result: BankStatementHeader + segment_citations: list[SegmentCitation] = []