"""GenAIStep — assemble prompt, call LLM, map provenance (spec §6.3, §7, §9.2).

Runs after :class:`~ix.pipeline.ocr_step.OCRStep`. Builds the chat-style
``request_kwargs`` (messages + model name), picks the structured-output
schema (plain ``UseCaseResponse`` or a runtime
``ProvenanceWrappedResponse(result=..., segment_citations=...)`` when
provenance is on), hands both to the injected :class:`GenAIClient`, and
writes the parsed payload onto ``response_ix.ix_result``.

When provenance is on, the LLM-emitted ``segment_citations`` flow into
:func:`~ix.provenance.map_segment_refs_to_provenance` to build
``response_ix.provenance``. The per-field reliability flags
(``provenance_verified`` / ``text_agreement``) stay ``None`` here — they
land in :class:`~ix.pipeline.reliability_step.ReliabilityStep`.

Failure modes:

* Network / timeout / non-2xx surfaced by the client → ``IX_002_000``.
* :class:`pydantic.ValidationError` (structured output didn't match the
  schema) → ``IX_002_001``.
"""

from __future__ import annotations

from typing import Any, cast

import httpx
from pydantic import BaseModel, Field, ValidationError, create_model

from ix.contracts import RequestIX, ResponseIX, SegmentCitation
from ix.errors import IXErrorCode, IXException
from ix.genai.client import GenAIClient
from ix.pipeline.step import Step
from ix.provenance import map_segment_refs_to_provenance
from ix.segmentation import SegmentIndex

# Verbatim from spec §9.2 (core-pipeline spec) — inserted after the
# use-case system prompt when provenance is on.
_CITATION_INSTRUCTION = (
    "For each extracted field, you must also populate the `segment_citations` list.\n"
    "Each entry maps one field to the document segments that were its source.\n"
    "Set `field_path` to the dot-separated JSON path of the field "
    "(e.g. 'result.invoice_number').\n"
    "Use two separate segment ID lists:\n"
    "- `value_segment_ids`: segment IDs whose text directly contains the extracted "
    "value (e.g. ['p1_l4'] for the line containing 'INV-001').\n"
    "- `context_segment_ids`: segment IDs for surrounding label or anchor text that "
    "helped you identify the field but does not contain the value itself "
    "(e.g. ['p1_l3'] for a label like 'Invoice Number:'). Leave empty if there is "
    "no distinct label.\n"
    "Only use segment IDs that appear in the document text.\n"
    "Omit fields for which you cannot identify a source segment."
)

# Message used when a pipeline invariant (context populated upstream) is broken.
_MISSING_CONTEXT_MSG = "SetupStep must populate response_ix.context"


class GenAIStep(Step):
    """LLM extraction + (optional) provenance mapping.

    The step is constructed with a :class:`GenAIClient`; every LLM call in
    :meth:`process` goes through that client's ``invoke`` coroutine.
    """

    def __init__(self, genai_client: GenAIClient) -> None:
        """Store the injected client used for all LLM calls."""
        self._client = genai_client

    async def validate(self, request_ix: RequestIX, response_ix: ResponseIX) -> bool:
        """Decide whether the step should run for this request.

        Returns:
            ``False`` when ``request_ix.options.ocr.ocr_only`` is set,
            ``True`` otherwise.

        Raises:
            IXException: ``IX_001_000`` when there is neither non-blank OCR
                text nor any entry in the context's ``texts``.
        """
        if request_ix.options.ocr.ocr_only:
            return False

        ocr_text = (
            response_ix.ocr_result.result.text
            if response_ix.ocr_result is not None
            else None
        )
        # getattr on a ``None`` context simply yields the default, so no
        # separate ``ctx is None`` branch is needed.
        texts = list(getattr(response_ix.context, "texts", None) or [])

        has_ocr_text = bool(ocr_text and ocr_text.strip())
        if not has_ocr_text and not texts:
            raise IXException(IXErrorCode.IX_001_000)
        return True

    async def process(
        self, request_ix: RequestIX, response_ix: ResponseIX
    ) -> ResponseIX:
        """Call the LLM and write the extraction (and provenance) to ``response_ix``.

        Reads ``use_case_request`` / ``use_case_response`` (and, when
        provenance is on, ``segment_index``) from ``response_ix.context``,
        builds the chat request, invokes the client, then fills
        ``response_ix.ix_result.result``, ``response_ix.ix_result.meta_data``
        and — when provenance is on — ``response_ix.provenance``.

        Returns:
            The same ``response_ix`` instance, mutated in place.

        Raises:
            AssertionError: the context, ``use_case_request`` or
                ``use_case_response`` is missing (broken pipeline invariant).
            IXException: ``IX_002_001`` when the client raises
                :class:`pydantic.ValidationError`; ``IX_002_000`` when it
                raises ``httpx.HTTPError``, ``ConnectionError`` or
                ``TimeoutError``. An ``IXException`` raised by the client
                itself propagates unchanged.
        """
        ctx = response_ix.context
        # Explicit raises instead of ``assert`` so the invariants still hold
        # when Python runs with ``-O`` (which strips assert statements).
        if ctx is None:
            raise AssertionError(_MISSING_CONTEXT_MSG)

        use_case_request: Any = getattr(ctx, "use_case_request", None)
        use_case_response_cls: type[BaseModel] | None = getattr(
            ctx, "use_case_response", None
        )
        if use_case_request is None or use_case_response_cls is None:
            raise AssertionError(
                "context must carry use_case_request and use_case_response"
            )

        opts = request_ix.options
        provenance_on = opts.provenance.include_provenance

        # 1. System prompt — use-case default + optional citation instruction.
        system_prompt = use_case_request.system_prompt
        if provenance_on:
            system_prompt = f"{system_prompt}\n\n{_CITATION_INSTRUCTION}"

        # 2. User text — segment-tagged when provenance is on, else plain OCR + texts.
        user_text = self._build_user_text(response_ix, provenance_on)

        # 3. Response schema — plain or wrapped.
        response_schema = self._resolve_response_schema(
            use_case_response_cls, provenance_on
        )

        # 4. Model selection — request override → use-case default.
        model_name = (
            opts.gen_ai.gen_ai_model_name
            or getattr(use_case_request, "default_model", None)
        )

        request_kwargs: dict[str, Any] = {
            "model": model_name,
            "messages": [
                {"role": "system", "content": system_prompt},
                {"role": "user", "content": user_text},
            ],
        }

        # 5. Call the backend, translate errors. An IXException raised by the
        #    client matches neither handler and therefore propagates as-is.
        try:
            result = await self._client.invoke(
                request_kwargs=request_kwargs,
                response_schema=response_schema,
            )
        except ValidationError as exc:
            raise IXException(
                IXErrorCode.IX_002_001,
                detail=f"{use_case_response_cls.__name__}: {exc}",
            ) from exc
        except (httpx.HTTPError, ConnectionError, TimeoutError) as exc:
            raise IXException(
                IXErrorCode.IX_002_000,
                detail=f"{model_name}: {exc.__class__.__name__}: {exc}",
            ) from exc

        # 6. Split parsed output; write result + meta.
        if provenance_on:
            wrapped = result.parsed
            extraction: BaseModel = wrapped.result
            segment_citations: list[SegmentCitation] = list(
                getattr(wrapped, "segment_citations", []) or []
            )
        else:
            extraction = result.parsed
            segment_citations = []

        response_ix.ix_result.result = extraction.model_dump(mode="json")
        response_ix.ix_result.meta_data = {
            "model_name": result.model_name,
            "token_usage": {
                "prompt_tokens": result.usage.prompt_tokens,
                "completion_tokens": result.usage.completion_tokens,
            },
        }

        # 7. Provenance mapping — only the structural assembly. Reliability
        #    flags get written in ReliabilityStep.
        if provenance_on:
            # The index is optional, so the cast target says so; a string
            # target keeps the union from being evaluated at runtime.
            seg_idx = cast(
                "SegmentIndex | None", getattr(ctx, "segment_index", None)
            )
            if seg_idx is None:
                # No OCR was run (text-only request); skip provenance.
                response_ix.provenance = None
            else:
                response_ix.provenance = map_segment_refs_to_provenance(
                    extraction_result={"result": response_ix.ix_result.result},
                    segment_citations=segment_citations,
                    segment_index=seg_idx,
                    max_sources_per_field=opts.provenance.max_sources_per_field,
                    min_confidence=0.0,
                    include_bounding_boxes=True,
                    source_type="value_and_context",
                )

        return response_ix

    def _build_user_text(self, response_ix: ResponseIX, provenance_on: bool) -> str:
        """Return the user-message text sent to the LLM.

        When provenance is on and the context has a ``segment_index``, the
        text comes from ``segment_index.to_prompt_text(context_texts=texts)``.
        Otherwise it is the OCR flat text (if any) followed by the context
        ``texts``, with empty entries dropped and parts joined by a blank line.

        Raises:
            AssertionError: ``response_ix.context`` is ``None``.
        """
        ctx = response_ix.context
        # Explicit raise so the check is not stripped under ``python -O``.
        if ctx is None:
            raise AssertionError(_MISSING_CONTEXT_MSG)

        texts: list[str] = list(getattr(ctx, "texts", []) or [])
        seg_idx: SegmentIndex | None = getattr(ctx, "segment_index", None)

        if provenance_on and seg_idx is not None:
            return seg_idx.to_prompt_text(context_texts=texts)

        # Plain concat — OCR flat text + any extra paperless-style texts.
        parts: list[str] = []
        ocr_text = (
            response_ix.ocr_result.result.text
            if response_ix.ocr_result is not None
            else None
        )
        if ocr_text:
            parts.append(ocr_text)
        parts.extend(texts)
        return "\n\n".join(p for p in parts if p)

    def _resolve_response_schema(
        self,
        use_case_response_cls: type[BaseModel],
        provenance_on: bool,
    ) -> type[BaseModel]:
        """Return the structured-output schema handed to the client.

        With provenance off this is ``use_case_response_cls`` itself. With
        provenance on it is a dynamically created ``ProvenanceWrappedResponse``
        model with a required ``result`` field of that class and a
        ``segment_citations`` list that defaults to empty.
        """
        if not provenance_on:
            return use_case_response_cls
        # Dynamic wrapper — a fresh model class is created on every call.
        # NOTE(review): the original author judged one-per-call acceptable;
        # confirm the cost if this ever shows up in profiles.
        return create_model(
            "ProvenanceWrappedResponse",
            result=(use_case_response_cls, ...),
            segment_citations=(
                list[SegmentCitation],
                Field(default_factory=list),
            ),
        )


__all__ = ["GenAIStep"]