Assembles the prompt, picks the structured-output schema, calls the
injected GenAIClient, and maps any emitted segment_citations into
response.provenance. Reliability flags stay None here; ReliabilityStep
fills them in Task 2.7.
- System prompt = use_case.system_prompt + (provenance-on) the verbatim
citation instruction from spec §9.2.
- User text = SegmentIndex.to_prompt_text() output (segment-tagged, [p1_l0] style) when
provenance is on and a segment index exists, else plain OCR flat text + texts joined.
- Response schema = UseCaseResponse directly, or a runtime
create_model("ProvenanceWrappedResponse", result=(UCR, ...),
segment_citations=(list[SegmentCitation], Field(default_factory=list)))
when provenance is on.
- Model = request override -> use-case default.
- Failure modes: httpx / connection / timeout errors -> IX_002_000;
pydantic.ValidationError -> IX_002_001.
- Writes ix_result.result + ix_result.meta_data (model_name +
token_usage); builds response.provenance via
map_segment_refs_to_provenance when provenance is on.
17 unit tests in tests/unit/test_genai_step.py cover validate
(ocr_only skip, empty -> IX_001_000, text-only, ocr-text path), process
happy path, system-prompt shape with/without citation instruction, user
text tagged vs. plain, response schema plain vs. wrapped, provenance
mapping, error mapping (IX_002_000 + IX_002_001), and model selection
(request override + use-case default).
Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
216 lines
8.3 KiB
Python
216 lines
8.3 KiB
Python
"""GenAIStep — assemble prompt, call LLM, map provenance (spec §6.3, §7, §9.2).
|
|
|
|
Runs after :class:`~ix.pipeline.ocr_step.OCRStep`. Builds the chat-style
|
|
``request_kwargs`` (messages + model name), picks the structured-output
|
|
schema (plain ``UseCaseResponse`` or a runtime
|
|
``ProvenanceWrappedResponse(result=..., segment_citations=...)`` when
|
|
provenance is on), hands both to the injected :class:`GenAIClient`, and
|
|
writes the parsed payload onto ``response_ix.ix_result``.
|
|
|
|
When provenance is on, the LLM-emitted ``segment_citations`` flow into
|
|
:func:`~ix.provenance.map_segment_refs_to_provenance` to build
|
|
``response_ix.provenance``. The per-field reliability flags
|
|
(``provenance_verified`` / ``text_agreement``) stay ``None`` here — they
|
|
land in :class:`~ix.pipeline.reliability_step.ReliabilityStep`.
|
|
|
|
Failure modes:
|
|
|
|
* Network / timeout / non-2xx surfaced by the client → ``IX_002_000``.
|
|
* :class:`pydantic.ValidationError` (structured output didn't match the
|
|
schema) → ``IX_002_001``.
|
|
"""
|
|
|
|
from __future__ import annotations
|
|
|
|
from typing import Any, cast
|
|
|
|
import httpx
|
|
from pydantic import BaseModel, Field, ValidationError, create_model
|
|
|
|
from ix.contracts import RequestIX, ResponseIX, SegmentCitation
|
|
from ix.errors import IXErrorCode, IXException
|
|
from ix.genai.client import GenAIClient
|
|
from ix.pipeline.step import Step
|
|
from ix.provenance import map_segment_refs_to_provenance
|
|
from ix.segmentation import SegmentIndex
|
|
|
|
# Citation instruction, verbatim from spec §9.2 (core-pipeline spec). It is
# appended after the use-case system prompt whenever provenance is on.
# Each list entry below is one line of the final prompt text; the entries are
# joined with a single newline and the text has no trailing newline.
_CITATION_INSTRUCTION = "\n".join(
    [
        "For each extracted field, you must also populate the `segment_citations` list.",
        "Each entry maps one field to the document segments that were its source.",
        (
            "Set `field_path` to the dot-separated JSON path of the field "
            "(e.g. 'result.invoice_number')."
        ),
        "Use two separate segment ID lists:",
        (
            "- `value_segment_ids`: segment IDs whose text directly contains the extracted "
            "value (e.g. ['p1_l4'] for the line containing 'INV-001')."
        ),
        (
            "- `context_segment_ids`: segment IDs for surrounding label or anchor text that "
            "helped you identify the field but does not contain the value itself "
            "(e.g. ['p1_l3'] for a label like 'Invoice Number:'). Leave empty if there is "
            "no distinct label."
        ),
        "Only use segment IDs that appear in the document text.",
        "Omit fields for which you cannot identify a source segment.",
    ]
)
|
|
|
|
|
|
class GenAIStep(Step):
    """LLM extraction + (optional) provenance mapping.

    Builds the chat request (system message, user message, model name), picks
    the structured-output schema, calls the injected client and writes the
    parsed result plus model/token metadata onto ``response_ix.ix_result``.
    When provenance is requested and a segment index is available on the
    context, the emitted segment citations are mapped into
    ``response_ix.provenance``.
    """

    def __init__(self, genai_client: GenAIClient) -> None:
        """Keep a reference to the client used for the LLM call."""
        self._client = genai_client

    async def validate(self, request_ix: RequestIX, response_ix: ResponseIX) -> bool:
        """Decide whether this step should run for the given request.

        Returns:
            ``False`` when the request is OCR-only (the step is skipped),
            ``True`` when there is text to send to the model.

        Raises:
            IXException: ``IX_001_000`` when neither the OCR result nor the
                context ``texts`` contain any non-whitespace text.
        """
        if request_ix.options.ocr.ocr_only:
            return False

        ctx = response_ix.context
        texts = list(getattr(ctx, "texts", []) or []) if ctx is not None else []
        ocr_text = self._ocr_flat_text(response_ix)

        has_ocr_text = bool(ocr_text and ocr_text.strip())
        # Apply the same "blank means empty" rule to the extra texts as to the
        # OCR text; otherwise a list such as [""] would pass validation and an
        # empty user message would be sent to the model.
        has_texts = any(t and t.strip() for t in texts)

        if not has_ocr_text and not has_texts:
            raise IXException(IXErrorCode.IX_001_000)
        return True

    async def process(
        self, request_ix: RequestIX, response_ix: ResponseIX
    ) -> ResponseIX:
        """Call the model and write result, metadata and provenance.

        Returns:
            The same ``response_ix`` instance, mutated in place.

        Raises:
            IXException: ``IX_002_001`` when the client raises a pydantic
                ``ValidationError``; ``IX_002_000`` when it raises
                ``httpx.HTTPError``, ``ConnectionError`` or ``TimeoutError``.
                Any other exception from the client propagates unchanged.
            AssertionError: when the context or its use-case request/response
                objects are missing.
        """
        ctx = response_ix.context
        assert ctx is not None, "SetupStep must populate response_ix.context"
        use_case_request: Any = getattr(ctx, "use_case_request", None)
        use_case_response_cls: type[BaseModel] | None = getattr(
            ctx, "use_case_response", None
        )
        assert use_case_request is not None and use_case_response_cls is not None

        opts = request_ix.options
        provenance_on = opts.provenance.include_provenance

        # 1. System prompt — use-case default + optional citation instruction.
        system_prompt = use_case_request.system_prompt
        if provenance_on:
            system_prompt = f"{system_prompt}\n\n{_CITATION_INSTRUCTION}"

        # 2. User text — segment-tagged when provenance is on, else plain OCR + texts.
        user_text = self._build_user_text(response_ix, provenance_on)

        # 3. Response schema — plain or wrapped.
        response_schema = self._resolve_response_schema(
            use_case_response_cls, provenance_on
        )

        # 4. Model selection — request override → use-case default.
        model_name = opts.gen_ai.gen_ai_model_name or getattr(
            use_case_request, "default_model", None
        )

        request_kwargs = {
            "model": model_name,
            "messages": [
                {"role": "system", "content": system_prompt},
                {"role": "user", "content": user_text},
            ],
        }

        # 5. Call the backend, translate errors. Exceptions not listed here
        #    (including IXException raised by the client) propagate as-is.
        try:
            result = await self._client.invoke(
                request_kwargs=request_kwargs,
                response_schema=response_schema,
            )
        except ValidationError as exc:
            raise IXException(
                IXErrorCode.IX_002_001,
                detail=f"{use_case_response_cls.__name__}: {exc}",
            ) from exc
        except (httpx.HTTPError, ConnectionError, TimeoutError) as exc:
            raise IXException(
                IXErrorCode.IX_002_000,
                detail=f"{model_name}: {exc.__class__.__name__}: {exc}",
            ) from exc

        # 6. Split parsed output; write result + meta.
        if provenance_on:
            wrapped = result.parsed
            extraction: BaseModel = wrapped.result
            segment_citations: list[SegmentCitation] = list(
                getattr(wrapped, "segment_citations", []) or []
            )
        else:
            extraction = result.parsed
            segment_citations = []

        response_ix.ix_result.result = extraction.model_dump(mode="json")
        response_ix.ix_result.meta_data = {
            "model_name": result.model_name,
            "token_usage": {
                "prompt_tokens": result.usage.prompt_tokens,
                "completion_tokens": result.usage.completion_tokens,
            },
        }

        # 7. Provenance mapping — only the structural assembly. Reliability
        #    flags get written in ReliabilityStep.
        if provenance_on:
            # The attribute may be absent or None, so the cast target must
            # admit None for the check below to be meaningful to type checkers.
            seg_idx = cast(
                "SegmentIndex | None", getattr(ctx, "segment_index", None)
            )
            if seg_idx is None:
                # No OCR was run (text-only request); skip provenance.
                response_ix.provenance = None
            else:
                response_ix.provenance = map_segment_refs_to_provenance(
                    extraction_result={"result": response_ix.ix_result.result},
                    segment_citations=segment_citations,
                    segment_index=seg_idx,
                    max_sources_per_field=opts.provenance.max_sources_per_field,
                    min_confidence=0.0,
                    include_bounding_boxes=True,
                    source_type="value_and_context",
                )

        return response_ix

    @staticmethod
    def _ocr_flat_text(response_ix: ResponseIX) -> str | None:
        """Return the flat OCR text, or ``None`` when there is no OCR result."""
        if response_ix.ocr_result is None:
            return None
        return response_ix.ocr_result.result.text

    def _build_user_text(self, response_ix: ResponseIX, provenance_on: bool) -> str:
        """Build the user message sent to the model.

        With provenance on and a segment index on the context, the text comes
        from ``SegmentIndex.to_prompt_text`` (given the context texts).
        Otherwise the OCR flat text and the context texts are joined with
        blank lines, skipping empty entries.
        """
        ctx = response_ix.context
        assert ctx is not None
        texts: list[str] = list(getattr(ctx, "texts", []) or [])
        seg_idx: SegmentIndex | None = getattr(ctx, "segment_index", None)

        if provenance_on and seg_idx is not None:
            return seg_idx.to_prompt_text(context_texts=texts)

        # Plain concat — OCR flat text + any extra paperless-style texts.
        parts: list[str] = []
        ocr_text = self._ocr_flat_text(response_ix)
        if ocr_text:
            parts.append(ocr_text)
        parts.extend(texts)
        return "\n\n".join(p for p in parts if p)

    def _resolve_response_schema(
        self,
        use_case_response_cls: type[BaseModel],
        provenance_on: bool,
    ) -> type[BaseModel]:
        """Return the structured-output schema for the call.

        Without provenance this is the use-case response class itself. With
        provenance it is a ``ProvenanceWrappedResponse`` model created at call
        time, holding the use-case response under ``result`` and a
        ``segment_citations`` list that defaults to empty.
        """
        if not provenance_on:
            return use_case_response_cls
        # Built per call because the type of ``result`` depends on the use case.
        return create_model(
            "ProvenanceWrappedResponse",
            result=(use_case_response_cls, ...),
            segment_citations=(
                list[SegmentCitation],
                Field(default_factory=list),
            ),
        )
|
|
|
|
|
|
# Public API of this module: only the step class is exported.
__all__ = ["GenAIStep"]
|