Compare commits
No commits in common. "6d9c239e82a01eb934b2e1f09415c6bcbc2a4261" and "acb2d55ce35dc49fc843623381a90620182caafd" have entirely different histories.
6d9c239e82
...
acb2d55ce3
2 changed files with 0 additions and 594 deletions
|
|
@ -1,216 +0,0 @@
|
||||||
"""GenAIStep — assemble prompt, call LLM, map provenance (spec §6.3, §7, §9.2).
|
|
||||||
|
|
||||||
Runs after :class:`~ix.pipeline.ocr_step.OCRStep`. Builds the chat-style
|
|
||||||
``request_kwargs`` (messages + model name), picks the structured-output
|
|
||||||
schema (plain ``UseCaseResponse`` or a runtime
|
|
||||||
``ProvenanceWrappedResponse(result=..., segment_citations=...)`` when
|
|
||||||
provenance is on), hands both to the injected :class:`GenAIClient`, and
|
|
||||||
writes the parsed payload onto ``response_ix.ix_result``.
|
|
||||||
|
|
||||||
When provenance is on, the LLM-emitted ``segment_citations`` flow into
|
|
||||||
:func:`~ix.provenance.map_segment_refs_to_provenance` to build
|
|
||||||
``response_ix.provenance``. The per-field reliability flags
|
|
||||||
(``provenance_verified`` / ``text_agreement``) stay ``None`` here — they
|
|
||||||
land in :class:`~ix.pipeline.reliability_step.ReliabilityStep`.
|
|
||||||
|
|
||||||
Failure modes:
|
|
||||||
|
|
||||||
* Network / timeout / non-2xx surfaced by the client → ``IX_002_000``.
|
|
||||||
* :class:`pydantic.ValidationError` (structured output didn't match the
|
|
||||||
schema) → ``IX_002_001``.
|
|
||||||
"""
|
|
||||||
|
|
||||||
from __future__ import annotations
|
|
||||||
|
|
||||||
from typing import Any, cast
|
|
||||||
|
|
||||||
import httpx
|
|
||||||
from pydantic import BaseModel, Field, ValidationError, create_model
|
|
||||||
|
|
||||||
from ix.contracts import RequestIX, ResponseIX, SegmentCitation
|
|
||||||
from ix.errors import IXErrorCode, IXException
|
|
||||||
from ix.genai.client import GenAIClient
|
|
||||||
from ix.pipeline.step import Step
|
|
||||||
from ix.provenance import map_segment_refs_to_provenance
|
|
||||||
from ix.segmentation import SegmentIndex
|
|
||||||
|
|
||||||
# Verbatim from spec §9.2 (core-pipeline spec) — inserted after the
|
|
||||||
# use-case system prompt when provenance is on.
|
|
||||||
_CITATION_INSTRUCTION = (
|
|
||||||
"For each extracted field, you must also populate the `segment_citations` list.\n"
|
|
||||||
"Each entry maps one field to the document segments that were its source.\n"
|
|
||||||
"Set `field_path` to the dot-separated JSON path of the field "
|
|
||||||
"(e.g. 'result.invoice_number').\n"
|
|
||||||
"Use two separate segment ID lists:\n"
|
|
||||||
"- `value_segment_ids`: segment IDs whose text directly contains the extracted "
|
|
||||||
"value (e.g. ['p1_l4'] for the line containing 'INV-001').\n"
|
|
||||||
"- `context_segment_ids`: segment IDs for surrounding label or anchor text that "
|
|
||||||
"helped you identify the field but does not contain the value itself "
|
|
||||||
"(e.g. ['p1_l3'] for a label like 'Invoice Number:'). Leave empty if there is "
|
|
||||||
"no distinct label.\n"
|
|
||||||
"Only use segment IDs that appear in the document text.\n"
|
|
||||||
"Omit fields for which you cannot identify a source segment."
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
class GenAIStep(Step):
|
|
||||||
"""LLM extraction + (optional) provenance mapping."""
|
|
||||||
|
|
||||||
def __init__(self, genai_client: GenAIClient) -> None:
|
|
||||||
self._client = genai_client
|
|
||||||
|
|
||||||
async def validate(self, request_ix: RequestIX, response_ix: ResponseIX) -> bool:
|
|
||||||
if request_ix.options.ocr.ocr_only:
|
|
||||||
return False
|
|
||||||
|
|
||||||
ctx = response_ix.context
|
|
||||||
ocr_text = (
|
|
||||||
response_ix.ocr_result.result.text
|
|
||||||
if response_ix.ocr_result is not None
|
|
||||||
else None
|
|
||||||
)
|
|
||||||
texts = list(getattr(ctx, "texts", []) or []) if ctx is not None else []
|
|
||||||
|
|
||||||
if not (ocr_text and ocr_text.strip()) and not texts:
|
|
||||||
raise IXException(IXErrorCode.IX_001_000)
|
|
||||||
return True
|
|
||||||
|
|
||||||
async def process(
|
|
||||||
self, request_ix: RequestIX, response_ix: ResponseIX
|
|
||||||
) -> ResponseIX:
|
|
||||||
ctx = response_ix.context
|
|
||||||
assert ctx is not None, "SetupStep must populate response_ix.context"
|
|
||||||
use_case_request: Any = getattr(ctx, "use_case_request", None)
|
|
||||||
use_case_response_cls: type[BaseModel] = getattr(ctx, "use_case_response", None)
|
|
||||||
assert use_case_request is not None and use_case_response_cls is not None
|
|
||||||
|
|
||||||
opts = request_ix.options
|
|
||||||
provenance_on = opts.provenance.include_provenance
|
|
||||||
|
|
||||||
# 1. System prompt — use-case default + optional citation instruction.
|
|
||||||
system_prompt = use_case_request.system_prompt
|
|
||||||
if provenance_on:
|
|
||||||
system_prompt = f"{system_prompt}\n\n{_CITATION_INSTRUCTION}"
|
|
||||||
|
|
||||||
# 2. User text — segment-tagged when provenance is on, else plain OCR + texts.
|
|
||||||
user_text = self._build_user_text(response_ix, provenance_on)
|
|
||||||
|
|
||||||
# 3. Response schema — plain or wrapped.
|
|
||||||
response_schema = self._resolve_response_schema(
|
|
||||||
use_case_response_cls, provenance_on
|
|
||||||
)
|
|
||||||
|
|
||||||
# 4. Model selection — request override → use-case default.
|
|
||||||
model_name = (
|
|
||||||
opts.gen_ai.gen_ai_model_name
|
|
||||||
or getattr(use_case_request, "default_model", None)
|
|
||||||
)
|
|
||||||
|
|
||||||
request_kwargs = {
|
|
||||||
"model": model_name,
|
|
||||||
"messages": [
|
|
||||||
{"role": "system", "content": system_prompt},
|
|
||||||
{"role": "user", "content": user_text},
|
|
||||||
],
|
|
||||||
}
|
|
||||||
|
|
||||||
# 5. Call the backend, translate errors.
|
|
||||||
try:
|
|
||||||
result = await self._client.invoke(
|
|
||||||
request_kwargs=request_kwargs,
|
|
||||||
response_schema=response_schema,
|
|
||||||
)
|
|
||||||
except ValidationError as exc:
|
|
||||||
raise IXException(
|
|
||||||
IXErrorCode.IX_002_001,
|
|
||||||
detail=f"{use_case_response_cls.__name__}: {exc}",
|
|
||||||
) from exc
|
|
||||||
except (httpx.HTTPError, ConnectionError, TimeoutError) as exc:
|
|
||||||
raise IXException(
|
|
||||||
IXErrorCode.IX_002_000,
|
|
||||||
detail=f"{model_name}: {exc.__class__.__name__}: {exc}",
|
|
||||||
) from exc
|
|
||||||
except IXException:
|
|
||||||
raise
|
|
||||||
|
|
||||||
# 6. Split parsed output; write result + meta.
|
|
||||||
if provenance_on:
|
|
||||||
wrapped = result.parsed
|
|
||||||
extraction: BaseModel = wrapped.result
|
|
||||||
segment_citations: list[SegmentCitation] = list(
|
|
||||||
getattr(wrapped, "segment_citations", []) or []
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
extraction = result.parsed
|
|
||||||
segment_citations = []
|
|
||||||
|
|
||||||
response_ix.ix_result.result = extraction.model_dump(mode="json")
|
|
||||||
response_ix.ix_result.meta_data = {
|
|
||||||
"model_name": result.model_name,
|
|
||||||
"token_usage": {
|
|
||||||
"prompt_tokens": result.usage.prompt_tokens,
|
|
||||||
"completion_tokens": result.usage.completion_tokens,
|
|
||||||
},
|
|
||||||
}
|
|
||||||
|
|
||||||
# 7. Provenance mapping — only the structural assembly. Reliability
|
|
||||||
# flags get written in ReliabilityStep.
|
|
||||||
if provenance_on:
|
|
||||||
seg_idx = cast(SegmentIndex, getattr(ctx, "segment_index", None))
|
|
||||||
if seg_idx is None:
|
|
||||||
# No OCR was run (text-only request); skip provenance.
|
|
||||||
response_ix.provenance = None
|
|
||||||
else:
|
|
||||||
response_ix.provenance = map_segment_refs_to_provenance(
|
|
||||||
extraction_result={"result": response_ix.ix_result.result},
|
|
||||||
segment_citations=segment_citations,
|
|
||||||
segment_index=seg_idx,
|
|
||||||
max_sources_per_field=opts.provenance.max_sources_per_field,
|
|
||||||
min_confidence=0.0,
|
|
||||||
include_bounding_boxes=True,
|
|
||||||
source_type="value_and_context",
|
|
||||||
)
|
|
||||||
|
|
||||||
return response_ix
|
|
||||||
|
|
||||||
def _build_user_text(self, response_ix: ResponseIX, provenance_on: bool) -> str:
|
|
||||||
ctx = response_ix.context
|
|
||||||
assert ctx is not None
|
|
||||||
texts: list[str] = list(getattr(ctx, "texts", []) or [])
|
|
||||||
seg_idx: SegmentIndex | None = getattr(ctx, "segment_index", None)
|
|
||||||
|
|
||||||
if provenance_on and seg_idx is not None:
|
|
||||||
return seg_idx.to_prompt_text(context_texts=texts)
|
|
||||||
|
|
||||||
# Plain concat — OCR flat text + any extra paperless-style texts.
|
|
||||||
parts: list[str] = []
|
|
||||||
ocr_text = (
|
|
||||||
response_ix.ocr_result.result.text
|
|
||||||
if response_ix.ocr_result is not None
|
|
||||||
else None
|
|
||||||
)
|
|
||||||
if ocr_text:
|
|
||||||
parts.append(ocr_text)
|
|
||||||
parts.extend(texts)
|
|
||||||
return "\n\n".join(p for p in parts if p)
|
|
||||||
|
|
||||||
def _resolve_response_schema(
|
|
||||||
self,
|
|
||||||
use_case_response_cls: type[BaseModel],
|
|
||||||
provenance_on: bool,
|
|
||||||
) -> type[BaseModel]:
|
|
||||||
if not provenance_on:
|
|
||||||
return use_case_response_cls
|
|
||||||
# Dynamic wrapper — one per call is fine; Pydantic caches the
|
|
||||||
# generated JSON schema internally.
|
|
||||||
return create_model(
|
|
||||||
"ProvenanceWrappedResponse",
|
|
||||||
result=(use_case_response_cls, ...),
|
|
||||||
segment_citations=(
|
|
||||||
list[SegmentCitation],
|
|
||||||
Field(default_factory=list),
|
|
||||||
),
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
__all__ = ["GenAIStep"]
|
|
||||||
|
|
@ -1,378 +0,0 @@
|
||||||
"""Tests for :class:`ix.pipeline.genai_step.GenAIStep` (spec §6.3, §7, §9.2)."""
|
|
||||||
|
|
||||||
from __future__ import annotations
|
|
||||||
|
|
||||||
from typing import Any
|
|
||||||
|
|
||||||
import httpx
|
|
||||||
import pytest
|
|
||||||
from pydantic import BaseModel, ValidationError
|
|
||||||
|
|
||||||
from ix.contracts import (
|
|
||||||
Context,
|
|
||||||
GenAIOptions,
|
|
||||||
Line,
|
|
||||||
OCRDetails,
|
|
||||||
OCROptions,
|
|
||||||
OCRResult,
|
|
||||||
Options,
|
|
||||||
Page,
|
|
||||||
ProvenanceData,
|
|
||||||
ProvenanceOptions,
|
|
||||||
RequestIX,
|
|
||||||
ResponseIX,
|
|
||||||
SegmentCitation,
|
|
||||||
)
|
|
||||||
from ix.contracts.response import _InternalContext
|
|
||||||
from ix.errors import IXErrorCode, IXException
|
|
||||||
from ix.genai import FakeGenAIClient, GenAIInvocationResult, GenAIUsage
|
|
||||||
from ix.pipeline.genai_step import GenAIStep
|
|
||||||
from ix.segmentation import PageMetadata, SegmentIndex
|
|
||||||
from ix.use_cases.bank_statement_header import BankStatementHeader
|
|
||||||
from ix.use_cases.bank_statement_header import Request as BankReq
|
|
||||||
|
|
||||||
|
|
||||||
def _make_request(
|
|
||||||
*,
|
|
||||||
use_ocr: bool = True,
|
|
||||||
ocr_only: bool = False,
|
|
||||||
include_provenance: bool = True,
|
|
||||||
model_name: str | None = None,
|
|
||||||
) -> RequestIX:
|
|
||||||
return RequestIX(
|
|
||||||
use_case="bank_statement_header",
|
|
||||||
ix_client_id="test",
|
|
||||||
request_id="r-1",
|
|
||||||
context=Context(files=[], texts=[]),
|
|
||||||
options=Options(
|
|
||||||
ocr=OCROptions(use_ocr=use_ocr, ocr_only=ocr_only),
|
|
||||||
gen_ai=GenAIOptions(gen_ai_model_name=model_name),
|
|
||||||
provenance=ProvenanceOptions(
|
|
||||||
include_provenance=include_provenance,
|
|
||||||
max_sources_per_field=5,
|
|
||||||
),
|
|
||||||
),
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
def _ocr_with_lines(lines: list[str]) -> OCRResult:
|
|
||||||
return OCRResult(
|
|
||||||
result=OCRDetails(
|
|
||||||
text="\n".join(lines),
|
|
||||||
pages=[
|
|
||||||
Page(
|
|
||||||
page_no=1,
|
|
||||||
width=100.0,
|
|
||||||
height=200.0,
|
|
||||||
lines=[
|
|
||||||
Line(text=t, bounding_box=[0, i * 10, 10, i * 10, 10, i * 10 + 5, 0, i * 10 + 5])
|
|
||||||
for i, t in enumerate(lines)
|
|
||||||
],
|
|
||||||
)
|
|
||||||
],
|
|
||||||
)
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
def _response_with_segment_index(
|
|
||||||
lines: list[str], texts: list[str] | None = None
|
|
||||||
) -> ResponseIX:
|
|
||||||
ocr = _ocr_with_lines(lines)
|
|
||||||
resp = ResponseIX(ocr_result=ocr)
|
|
||||||
seg_idx = SegmentIndex.build(
|
|
||||||
ocr_result=ocr,
|
|
||||||
granularity="line",
|
|
||||||
pages_metadata=[PageMetadata(file_index=0)],
|
|
||||||
)
|
|
||||||
resp.context = _InternalContext(
|
|
||||||
use_case_request=BankReq(),
|
|
||||||
use_case_response=BankStatementHeader,
|
|
||||||
segment_index=seg_idx,
|
|
||||||
texts=texts or [],
|
|
||||||
pages=ocr.result.pages,
|
|
||||||
page_metadata=[PageMetadata(file_index=0)],
|
|
||||||
)
|
|
||||||
return resp
|
|
||||||
|
|
||||||
|
|
||||||
class CapturingClient:
|
|
||||||
"""Records the request_kwargs + response_schema handed to invoke()."""
|
|
||||||
|
|
||||||
def __init__(self, parsed: Any) -> None:
|
|
||||||
self._parsed = parsed
|
|
||||||
self.request_kwargs: dict[str, Any] | None = None
|
|
||||||
self.response_schema: type[BaseModel] | None = None
|
|
||||||
|
|
||||||
async def invoke(
|
|
||||||
self,
|
|
||||||
request_kwargs: dict[str, Any],
|
|
||||||
response_schema: type[BaseModel],
|
|
||||||
) -> GenAIInvocationResult:
|
|
||||||
self.request_kwargs = request_kwargs
|
|
||||||
self.response_schema = response_schema
|
|
||||||
return GenAIInvocationResult(
|
|
||||||
parsed=self._parsed,
|
|
||||||
usage=GenAIUsage(prompt_tokens=5, completion_tokens=7),
|
|
||||||
model_name="captured-model",
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
class TestValidate:
|
|
||||||
async def test_ocr_only_skips(self) -> None:
|
|
||||||
step = GenAIStep(
|
|
||||||
genai_client=FakeGenAIClient(parsed=BankStatementHeader(bank_name="x", currency="EUR"))
|
|
||||||
)
|
|
||||||
req = _make_request(ocr_only=True)
|
|
||||||
resp = _response_with_segment_index(lines=["hello"])
|
|
||||||
assert await step.validate(req, resp) is False
|
|
||||||
|
|
||||||
async def test_empty_context_raises_IX_001_000(self) -> None:
|
|
||||||
step = GenAIStep(
|
|
||||||
genai_client=FakeGenAIClient(parsed=BankStatementHeader(bank_name="x", currency="EUR"))
|
|
||||||
)
|
|
||||||
req = _make_request()
|
|
||||||
resp = ResponseIX(ocr_result=OCRResult(result=OCRDetails(text="")))
|
|
||||||
resp.context = _InternalContext(
|
|
||||||
use_case_request=BankReq(),
|
|
||||||
use_case_response=BankStatementHeader,
|
|
||||||
texts=[],
|
|
||||||
)
|
|
||||||
with pytest.raises(IXException) as ei:
|
|
||||||
await step.validate(req, resp)
|
|
||||||
assert ei.value.code is IXErrorCode.IX_001_000
|
|
||||||
|
|
||||||
async def test_runs_when_texts_only(self) -> None:
|
|
||||||
step = GenAIStep(
|
|
||||||
genai_client=FakeGenAIClient(parsed=BankStatementHeader(bank_name="x", currency="EUR"))
|
|
||||||
)
|
|
||||||
req = _make_request()
|
|
||||||
resp = ResponseIX(ocr_result=OCRResult(result=OCRDetails(text="")))
|
|
||||||
resp.context = _InternalContext(
|
|
||||||
use_case_request=BankReq(),
|
|
||||||
use_case_response=BankStatementHeader,
|
|
||||||
texts=["some paperless text"],
|
|
||||||
)
|
|
||||||
assert await step.validate(req, resp) is True
|
|
||||||
|
|
||||||
async def test_runs_when_ocr_text_present(self) -> None:
|
|
||||||
step = GenAIStep(
|
|
||||||
genai_client=FakeGenAIClient(parsed=BankStatementHeader(bank_name="x", currency="EUR"))
|
|
||||||
)
|
|
||||||
req = _make_request()
|
|
||||||
resp = _response_with_segment_index(lines=["hello"])
|
|
||||||
assert await step.validate(req, resp) is True
|
|
||||||
|
|
||||||
|
|
||||||
class TestProcessBasic:
|
|
||||||
async def test_writes_ix_result_and_meta(self) -> None:
|
|
||||||
parsed = BankStatementHeader(bank_name="DKB", currency="EUR")
|
|
||||||
client = CapturingClient(parsed=parsed)
|
|
||||||
step = GenAIStep(genai_client=client)
|
|
||||||
req = _make_request(include_provenance=False)
|
|
||||||
resp = _response_with_segment_index(lines=["hello"])
|
|
||||||
resp = await step.process(req, resp)
|
|
||||||
|
|
||||||
assert resp.ix_result.result["bank_name"] == "DKB"
|
|
||||||
assert resp.ix_result.result["currency"] == "EUR"
|
|
||||||
assert resp.ix_result.meta_data["model_name"] == "captured-model"
|
|
||||||
assert resp.ix_result.meta_data["token_usage"]["prompt_tokens"] == 5
|
|
||||||
assert resp.ix_result.meta_data["token_usage"]["completion_tokens"] == 7
|
|
||||||
|
|
||||||
|
|
||||||
class TestSystemPromptAssembly:
|
|
||||||
async def test_citation_instruction_appended_when_provenance_on(self) -> None:
|
|
||||||
parsed_wrapped: Any = _WrappedResponse(
|
|
||||||
result=BankStatementHeader(bank_name="DKB", currency="EUR"),
|
|
||||||
segment_citations=[],
|
|
||||||
)
|
|
||||||
client = CapturingClient(parsed=parsed_wrapped)
|
|
||||||
step = GenAIStep(genai_client=client)
|
|
||||||
req = _make_request(include_provenance=True)
|
|
||||||
resp = _response_with_segment_index(lines=["hello"])
|
|
||||||
await step.process(req, resp)
|
|
||||||
|
|
||||||
messages = client.request_kwargs["messages"] # type: ignore[index]
|
|
||||||
system = messages[0]["content"]
|
|
||||||
# Use-case system prompt is always there.
|
|
||||||
assert "extract header metadata" in system
|
|
||||||
# Citation instruction added.
|
|
||||||
assert "segment_citations" in system
|
|
||||||
assert "value_segment_ids" in system
|
|
||||||
|
|
||||||
async def test_citation_instruction_absent_when_provenance_off(self) -> None:
|
|
||||||
parsed = BankStatementHeader(bank_name="DKB", currency="EUR")
|
|
||||||
client = CapturingClient(parsed=parsed)
|
|
||||||
step = GenAIStep(genai_client=client)
|
|
||||||
req = _make_request(include_provenance=False)
|
|
||||||
resp = _response_with_segment_index(lines=["hello"])
|
|
||||||
await step.process(req, resp)
|
|
||||||
|
|
||||||
messages = client.request_kwargs["messages"] # type: ignore[index]
|
|
||||||
system = messages[0]["content"]
|
|
||||||
assert "segment_citations" not in system
|
|
||||||
|
|
||||||
|
|
||||||
class TestUserTextFormat:
|
|
||||||
async def test_tagged_prompt_when_provenance_on(self) -> None:
|
|
||||||
parsed_wrapped: Any = _WrappedResponse(
|
|
||||||
result=BankStatementHeader(bank_name="DKB", currency="EUR"),
|
|
||||||
segment_citations=[],
|
|
||||||
)
|
|
||||||
client = CapturingClient(parsed=parsed_wrapped)
|
|
||||||
step = GenAIStep(genai_client=client)
|
|
||||||
req = _make_request(include_provenance=True)
|
|
||||||
resp = _response_with_segment_index(lines=["alpha line", "beta line"])
|
|
||||||
await step.process(req, resp)
|
|
||||||
|
|
||||||
user_content = client.request_kwargs["messages"][1]["content"] # type: ignore[index]
|
|
||||||
assert "[p1_l0] alpha line" in user_content
|
|
||||||
assert "[p1_l1] beta line" in user_content
|
|
||||||
|
|
||||||
async def test_plain_prompt_when_provenance_off(self) -> None:
|
|
||||||
parsed = BankStatementHeader(bank_name="DKB", currency="EUR")
|
|
||||||
client = CapturingClient(parsed=parsed)
|
|
||||||
step = GenAIStep(genai_client=client)
|
|
||||||
req = _make_request(include_provenance=False)
|
|
||||||
resp = _response_with_segment_index(lines=["alpha line", "beta line"])
|
|
||||||
await step.process(req, resp)
|
|
||||||
|
|
||||||
user_content = client.request_kwargs["messages"][1]["content"] # type: ignore[index]
|
|
||||||
assert "[p1_l0]" not in user_content
|
|
||||||
assert "alpha line" in user_content
|
|
||||||
assert "beta line" in user_content
|
|
||||||
|
|
||||||
|
|
||||||
class TestResponseSchemaChoice:
|
|
||||||
async def test_plain_schema_when_provenance_off(self) -> None:
|
|
||||||
parsed = BankStatementHeader(bank_name="DKB", currency="EUR")
|
|
||||||
client = CapturingClient(parsed=parsed)
|
|
||||||
step = GenAIStep(genai_client=client)
|
|
||||||
req = _make_request(include_provenance=False)
|
|
||||||
resp = _response_with_segment_index(lines=["hello"])
|
|
||||||
await step.process(req, resp)
|
|
||||||
assert client.response_schema is BankStatementHeader
|
|
||||||
|
|
||||||
async def test_wrapped_schema_when_provenance_on(self) -> None:
|
|
||||||
parsed_wrapped: Any = _WrappedResponse(
|
|
||||||
result=BankStatementHeader(bank_name="DKB", currency="EUR"),
|
|
||||||
segment_citations=[],
|
|
||||||
)
|
|
||||||
client = CapturingClient(parsed=parsed_wrapped)
|
|
||||||
step = GenAIStep(genai_client=client)
|
|
||||||
req = _make_request(include_provenance=True)
|
|
||||||
resp = _response_with_segment_index(lines=["hello"])
|
|
||||||
await step.process(req, resp)
|
|
||||||
schema = client.response_schema
|
|
||||||
assert schema is not None
|
|
||||||
field_names = set(schema.model_fields.keys())
|
|
||||||
assert field_names == {"result", "segment_citations"}
|
|
||||||
|
|
||||||
|
|
||||||
class TestProvenanceMapping:
|
|
||||||
async def test_provenance_populated_from_citations(self) -> None:
|
|
||||||
parsed_wrapped: Any = _WrappedResponse(
|
|
||||||
result=BankStatementHeader(bank_name="DKB", currency="EUR"),
|
|
||||||
segment_citations=[
|
|
||||||
SegmentCitation(
|
|
||||||
field_path="result.bank_name",
|
|
||||||
value_segment_ids=["p1_l0"],
|
|
||||||
context_segment_ids=[],
|
|
||||||
),
|
|
||||||
],
|
|
||||||
)
|
|
||||||
client = CapturingClient(parsed=parsed_wrapped)
|
|
||||||
step = GenAIStep(genai_client=client)
|
|
||||||
req = _make_request(include_provenance=True)
|
|
||||||
resp = _response_with_segment_index(lines=["DKB"])
|
|
||||||
resp = await step.process(req, resp)
|
|
||||||
|
|
||||||
assert isinstance(resp.provenance, ProvenanceData)
|
|
||||||
fields = resp.provenance.fields
|
|
||||||
assert "result.bank_name" in fields
|
|
||||||
fp = fields["result.bank_name"]
|
|
||||||
assert fp.value == "DKB"
|
|
||||||
assert len(fp.sources) == 1
|
|
||||||
assert fp.sources[0].segment_id == "p1_l0"
|
|
||||||
# Reliability flags are NOT set here — ReliabilityStep does that.
|
|
||||||
assert fp.provenance_verified is None
|
|
||||||
assert fp.text_agreement is None
|
|
||||||
|
|
||||||
|
|
||||||
class TestErrorHandling:
|
|
||||||
async def test_network_error_maps_to_IX_002_000(self) -> None:
|
|
||||||
err = httpx.ConnectError("refused")
|
|
||||||
client = FakeGenAIClient(
|
|
||||||
parsed=BankStatementHeader(bank_name="x", currency="EUR"),
|
|
||||||
raise_on_call=err,
|
|
||||||
)
|
|
||||||
step = GenAIStep(genai_client=client)
|
|
||||||
req = _make_request(include_provenance=False)
|
|
||||||
resp = _response_with_segment_index(lines=["hello"])
|
|
||||||
with pytest.raises(IXException) as ei:
|
|
||||||
await step.process(req, resp)
|
|
||||||
assert ei.value.code is IXErrorCode.IX_002_000
|
|
||||||
|
|
||||||
async def test_timeout_maps_to_IX_002_000(self) -> None:
|
|
||||||
err = httpx.ReadTimeout("slow")
|
|
||||||
client = FakeGenAIClient(
|
|
||||||
parsed=BankStatementHeader(bank_name="x", currency="EUR"),
|
|
||||||
raise_on_call=err,
|
|
||||||
)
|
|
||||||
step = GenAIStep(genai_client=client)
|
|
||||||
req = _make_request(include_provenance=False)
|
|
||||||
resp = _response_with_segment_index(lines=["hello"])
|
|
||||||
with pytest.raises(IXException) as ei:
|
|
||||||
await step.process(req, resp)
|
|
||||||
assert ei.value.code is IXErrorCode.IX_002_000
|
|
||||||
|
|
||||||
async def test_validation_error_maps_to_IX_002_001(self) -> None:
|
|
||||||
class _M(BaseModel):
|
|
||||||
x: int
|
|
||||||
|
|
||||||
try:
|
|
||||||
_M(x="not-an-int") # type: ignore[arg-type]
|
|
||||||
except ValidationError as err:
|
|
||||||
raise_err = err
|
|
||||||
|
|
||||||
client = FakeGenAIClient(
|
|
||||||
parsed=BankStatementHeader(bank_name="x", currency="EUR"),
|
|
||||||
raise_on_call=raise_err,
|
|
||||||
)
|
|
||||||
step = GenAIStep(genai_client=client)
|
|
||||||
req = _make_request(include_provenance=False)
|
|
||||||
resp = _response_with_segment_index(lines=["hello"])
|
|
||||||
with pytest.raises(IXException) as ei:
|
|
||||||
await step.process(req, resp)
|
|
||||||
assert ei.value.code is IXErrorCode.IX_002_001
|
|
||||||
|
|
||||||
|
|
||||||
class TestModelSelection:
|
|
||||||
async def test_request_model_override_wins(self) -> None:
|
|
||||||
parsed = BankStatementHeader(bank_name="DKB", currency="EUR")
|
|
||||||
client = CapturingClient(parsed=parsed)
|
|
||||||
step = GenAIStep(genai_client=client)
|
|
||||||
req = _make_request(include_provenance=False, model_name="explicit-model")
|
|
||||||
resp = _response_with_segment_index(lines=["hello"])
|
|
||||||
await step.process(req, resp)
|
|
||||||
assert client.request_kwargs["model"] == "explicit-model" # type: ignore[index]
|
|
||||||
|
|
||||||
async def test_falls_back_to_use_case_default(self) -> None:
|
|
||||||
parsed = BankStatementHeader(bank_name="DKB", currency="EUR")
|
|
||||||
client = CapturingClient(parsed=parsed)
|
|
||||||
step = GenAIStep(genai_client=client)
|
|
||||||
req = _make_request(include_provenance=False)
|
|
||||||
resp = _response_with_segment_index(lines=["hello"])
|
|
||||||
await step.process(req, resp)
|
|
||||||
# use-case default is gpt-oss:20b
|
|
||||||
assert client.request_kwargs["model"] == "gpt-oss:20b" # type: ignore[index]
|
|
||||||
|
|
||||||
|
|
||||||
# ----------------------------------------------------------------------------
|
|
||||||
# Helpers
|
|
||||||
|
|
||||||
|
|
||||||
class _WrappedResponse(BaseModel):
|
|
||||||
"""Stand-in for the runtime-created ProvenanceWrappedResponse."""
|
|
||||||
|
|
||||||
result: BankStatementHeader
|
|
||||||
segment_citations: list[SegmentCitation] = []
|
|
||||||
Loading…
Reference in a new issue