diff --git a/src/ix/contracts/__init__.py b/src/ix/contracts/__init__.py index 2f57a45..9c92bc1 100644 --- a/src/ix/contracts/__init__.py +++ b/src/ix/contracts/__init__.py @@ -6,6 +6,14 @@ Re-exports the public symbols from sibling modules so call sites can write from __future__ import annotations +from ix.contracts.job import CallbackStatus, Job, JobStatus +from ix.contracts.provenance import ( + BoundingBox, + ExtractionSource, + FieldProvenance, + ProvenanceData, + SegmentCitation, +) from ix.contracts.request import ( Context, FileRef, @@ -15,13 +23,37 @@ from ix.contracts.request import ( ProvenanceOptions, RequestIX, ) +from ix.contracts.response import ( + IXResult, + Line, + Metadata, + OCRDetails, + OCRResult, + Page, + ResponseIX, +) __all__ = [ + "BoundingBox", + "CallbackStatus", "Context", + "ExtractionSource", + "FieldProvenance", "FileRef", "GenAIOptions", + "IXResult", + "Job", + "JobStatus", + "Line", + "Metadata", + "OCRDetails", "OCROptions", + "OCRResult", "Options", + "Page", + "ProvenanceData", "ProvenanceOptions", "RequestIX", + "ResponseIX", + "SegmentCitation", ] diff --git a/src/ix/contracts/job.py b/src/ix/contracts/job.py new file mode 100644 index 0000000..92addf6 --- /dev/null +++ b/src/ix/contracts/job.py @@ -0,0 +1,46 @@ +"""Job envelope stored in ``ix_jobs`` and returned by REST. + +Mirrors spec §3 ("Job envelope") and §4 ("Job store"). The lifecycle +enum is a ``Literal`` so Pydantic rejects unknown values at parse time. +``callback_status`` is nullable until the worker attempts delivery (or +skips delivery when there's no ``callback_url``). +""" + +from __future__ import annotations + +from datetime import datetime +from typing import Literal +from uuid import UUID + +from pydantic import BaseModel, ConfigDict + +from ix.contracts.request import RequestIX +from ix.contracts.response import ResponseIX + +JobStatus = Literal["pending", "running", "done", "error"] +CallbackStatus = Literal["pending", "delivered", "failed"] + + +class Job(BaseModel): + """Row of ``ix_jobs`` + its request/response bodies. + + The invariant ``status='done' iff response.error is None`` is enforced by + the worker, not here — callers occasionally hydrate a stale or in-flight + row and we don't want the Pydantic validator to reject it. + """ + + model_config = ConfigDict(extra="forbid") + + job_id: UUID + ix_id: str + client_id: str + request_id: str + status: JobStatus + request: RequestIX + response: ResponseIX | None = None + callback_url: str | None = None + callback_status: CallbackStatus | None = None + attempts: int = 0 + created_at: datetime + started_at: datetime | None = None + finished_at: datetime | None = None diff --git a/src/ix/contracts/provenance.py b/src/ix/contracts/provenance.py new file mode 100644 index 0000000..db65e88 --- /dev/null +++ b/src/ix/contracts/provenance.py @@ -0,0 +1,89 @@ +"""Provenance contracts — per-field segment citations + reliability flags. + +These models represent *outputs* attached to :class:`~ix.contracts.response.ResponseIX`. +The MVP adds two new fields to :class:`FieldProvenance` beyond the reference +spec: ``provenance_verified`` and ``text_agreement``. Both are written by the +new :class:`ReliabilityStep` and are the primary reliability signals that +callers (mammon first) use to decide trust. +""" + +from __future__ import annotations + +from typing import Any + +from pydantic import BaseModel, ConfigDict, Field + + +class BoundingBox(BaseModel): + """Eight-coordinate polygon, normalised to 0-1 against page dimensions. + + ``coordinates`` order: ``[x1, y1, x2, y2, x3, y3, x4, y4]`` — top-left, + top-right, bottom-right, bottom-left (same as Surya's polygon output). + """ + + model_config = ConfigDict(extra="forbid") + + coordinates: list[float] + + +class SegmentCitation(BaseModel): + """LLM-emitted citation: one field → the segments it came from. + + This is part of the dynamic ``ProvenanceWrappedResponse`` the GenAIStep + asks the model to return when provenance is on (see spec §9.2). + """ + + model_config = ConfigDict(extra="forbid") + + field_path: str + value_segment_ids: list[str] = Field(default_factory=list) + context_segment_ids: list[str] = Field(default_factory=list) + + +class ExtractionSource(BaseModel): + """One resolved source for a field — maps a segment ID to its on-page anchor.""" + + model_config = ConfigDict(extra="forbid") + + page_number: int + file_index: int | None = None + bounding_box: BoundingBox | None = None + text_snippet: str + relevance_score: float = 1.0 + segment_id: str | None = None + + +class FieldProvenance(BaseModel): + """Per-field provenance + MVP reliability flags. + + ``provenance_verified``: True when at least one cited segment's + ``text_snippet`` agrees with the extracted value after normalisation; + False when every cite disagrees; None when the field type makes the check + meaningless (``Literal``, ``None``/unset). + + ``text_agreement``: True when the value also appears in + ``RequestIX.context.texts`` after normalisation; False when the texts + disagree; None when no texts were provided, or when the short-value skip + rule applies, or when the type is ``Literal``/``None``. + """ + + model_config = ConfigDict(extra="forbid") + + field_name: str + field_path: str + value: Any = None + sources: list[ExtractionSource] = Field(default_factory=list) + confidence: float | None = None + provenance_verified: bool | None = None + text_agreement: bool | None = None + + +class ProvenanceData(BaseModel): + """Aggregate provenance payload attached to :class:`ResponseIX.provenance`.""" + + model_config = ConfigDict(extra="forbid") + + fields: dict[str, FieldProvenance] = Field(default_factory=dict) + quality_metrics: dict[str, Any] = Field(default_factory=dict) + segment_count: int | None = None + granularity: str | None = None diff --git a/src/ix/contracts/response.py b/src/ix/contracts/response.py new file mode 100644 index 0000000..65329fc --- /dev/null +++ b/src/ix/contracts/response.py @@ -0,0 +1,124 @@ +"""Outgoing response contracts — :class:`ResponseIX` + nested result structures. + +Mirrors MVP spec §3 / §9.3. The only subtle piece is ``ResponseIX.context``: +it is an *internal* mutable accumulator used by pipeline steps (pages, files, +texts, use-case classes, segment index) and MUST NOT be serialised to the +caller. We enforce this with ``Field(exclude=True)`` on an opaque +:class:`_InternalContext` sub-model. + +Strictness note: unlike :class:`RequestIX`, ResponseIX tolerates extra keys +via ``extra="allow"`` on the internal-context carrier so worker-side code can +stash arbitrary step-level state without growing the schema. Public +response fields still use ``extra="forbid"``. +""" + +from __future__ import annotations + +from typing import Any + +from pydantic import BaseModel, ConfigDict, Field + +from ix.contracts.provenance import ProvenanceData + + +class Line(BaseModel): + """One OCR line with its raw 8-coord polygon.""" + + model_config = ConfigDict(extra="forbid") + + text: str | None = None + bounding_box: list[float] = Field(default_factory=list) + + +class Page(BaseModel): + """One OCR page. ``width``/``height`` are in points (pixels for raster images).""" + + model_config = ConfigDict(extra="forbid") + + page_no: int + width: float + height: float + angle: float = 0.0 + unit: str | None = None + lines: list[Line] = Field(default_factory=list) + + +class OCRDetails(BaseModel): + """OCR structural output.""" + + model_config = ConfigDict(extra="forbid") + + text: str | None = None + pages: list[Page] = Field(default_factory=list) + + +class OCRResult(BaseModel): + """Wraps :class:`OCRDetails` with arbitrary adapter metadata.""" + + model_config = ConfigDict(extra="forbid") + + result: OCRDetails = Field(default_factory=OCRDetails) + meta_data: dict[str, Any] = Field(default_factory=dict) + + +class IXResult(BaseModel): + """LLM extraction payload + usage/model metadata.""" + + model_config = ConfigDict(extra="forbid") + + result: dict[str, Any] = Field(default_factory=dict) + result_confidence: dict[str, Any] = Field(default_factory=dict) + meta_data: dict[str, Any] = Field(default_factory=dict) + + +class Metadata(BaseModel): + """Pipeline-level telemetry — populated by the orchestrator.""" + + model_config = ConfigDict(extra="forbid") + + timings: list[dict[str, Any]] = Field(default_factory=list) + processed_by: str | None = None + use_case_truncated: bool | None = None + + +class _InternalContext(BaseModel): + """Internal mutable accumulator — NEVER serialised. + + Holds per-request state the pipeline steps pass to each other: downloaded + file handles, flat page list, extracted text scratchpad, the loaded + use-case ``Request`` / response-schema classes, and the built segment + index. Kept ``extra="allow"`` so adapters/steps can stash arbitrary state + without churning this contract. Always excluded from ``model_dump``. + """ + + model_config = ConfigDict(extra="allow", arbitrary_types_allowed=True) + + pages: list[Any] = Field(default_factory=list) + files: list[Any] = Field(default_factory=list) + texts: list[str] = Field(default_factory=list) + use_case_request: Any = None + use_case_response: Any = None + segment_index: Any = None + + +class ResponseIX(BaseModel): + """Top-level response shape returned through the job store. + + ``context`` is internal-only — ``Field(exclude=True)`` keeps it out of + serialised JSON. Callers see ``use_case`` … ``metadata`` and nothing else. + """ + + model_config = ConfigDict(extra="forbid", arbitrary_types_allowed=True) + + use_case: str | None = None + use_case_name: str | None = None + ix_client_id: str | None = None + request_id: str | None = None + ix_id: str | None = None + error: str | None = None + warning: list[str] = Field(default_factory=list) + ix_result: IXResult = Field(default_factory=IXResult) + ocr_result: OCRResult = Field(default_factory=OCRResult) + provenance: ProvenanceData | None = None + metadata: Metadata = Field(default_factory=Metadata) + context: _InternalContext | None = Field(default=None, exclude=True) diff --git a/tests/unit/test_contracts.py b/tests/unit/test_contracts.py index d1b94e0..8f5d1a1 100644 --- a/tests/unit/test_contracts.py +++ b/tests/unit/test_contracts.py @@ -3,18 +3,33 @@ from __future__ import annotations import json +from datetime import UTC, datetime +from uuid import uuid4 import pytest from pydantic import ValidationError from ix.contracts import ( + BoundingBox, Context, + ExtractionSource, + FieldProvenance, FileRef, GenAIOptions, + IXResult, + Job, + Line, + Metadata, + OCRDetails, OCROptions, + OCRResult, Options, + Page, + ProvenanceData, ProvenanceOptions, RequestIX, + ResponseIX, + SegmentCitation, ) @@ -166,3 +181,210 @@ class TestRequestIX: def test_missing_required_fields(self) -> None: with pytest.raises(ValidationError): RequestIX.model_validate({"use_case": "x"}) + + +class TestOCRResult: + def test_minimal_defaults(self) -> None: + result = OCRResult() + assert result.result.text is None + assert result.result.pages == [] + assert result.meta_data == {} + + def test_full_page_roundtrip(self) -> None: + page = Page( + page_no=1, + width=612.0, + height=792.0, + lines=[Line(text="hello", bounding_box=[0, 0, 10, 0, 10, 20, 0, 20])], + ) + ocr = OCRResult(result=OCRDetails(text="hello", pages=[page])) + dumped = ocr.model_dump() + assert dumped["result"]["pages"][0]["lines"][0]["text"] == "hello" + assert dumped["result"]["pages"][0]["lines"][0]["bounding_box"] == [ + 0, + 0, + 10, + 0, + 10, + 20, + 0, + 20, + ] + + +class TestProvenance: + def test_field_provenance_new_flags(self) -> None: + # The MVP adds `provenance_verified` + `text_agreement` on top of the + # reference spec. Both are nullable bool. + fp = FieldProvenance( + field_name="bank_name", + field_path="result.bank_name", + value="UBS AG", + sources=[ + ExtractionSource( + page_number=1, + file_index=0, + bounding_box=BoundingBox(coordinates=[0.1, 0.1, 0.9, 0.1, 0.9, 0.2, 0.1, 0.2]), + text_snippet="UBS AG", + segment_id="p1_l0", + ) + ], + provenance_verified=True, + text_agreement=None, + ) + assert fp.provenance_verified is True + assert fp.text_agreement is None + + def test_field_provenance_flags_default_to_none(self) -> None: + fp = FieldProvenance(field_name="x", field_path="result.x") + assert fp.provenance_verified is None + assert fp.text_agreement is None + + def test_quality_metrics_accepts_all_keys(self) -> None: + # quality_metrics is a free-form dict; we just check the MVP-listed keys + # all round-trip as written. + prov = ProvenanceData( + fields={}, + quality_metrics={ + "fields_with_provenance": 8, + "total_fields": 10, + "coverage_rate": 0.8, + "invalid_references": 2, + "verified_fields": 6, + "text_agreement_fields": 5, + }, + ) + rt = ProvenanceData.model_validate(prov.model_dump()) + assert rt.quality_metrics["verified_fields"] == 6 + assert rt.quality_metrics["text_agreement_fields"] == 5 + assert rt.quality_metrics["coverage_rate"] == 0.8 + + def test_segment_citation_basic(self) -> None: + sc = SegmentCitation( + field_path="result.invoice_number", + value_segment_ids=["p1_l4"], + context_segment_ids=["p1_l3"], + ) + assert sc.value_segment_ids == ["p1_l4"] + + +class TestResponseIX: + def test_defaults(self) -> None: + r = ResponseIX() + assert r.error is None + assert r.warning == [] + assert isinstance(r.ix_result, IXResult) + assert isinstance(r.ocr_result, OCRResult) + assert isinstance(r.metadata, Metadata) + assert r.provenance is None + assert r.context is None + + def test_context_excluded_from_dump(self) -> None: + # ResponseIX.context is INTERNAL — must never show up in serialised form. + r = ResponseIX() + # Push something into context via the internal model. + from ix.contracts.response import _InternalContext + + r.context = _InternalContext(texts=["scratch"]) + dumped = r.model_dump() + assert "context" not in dumped + + dumped_json = r.model_dump_json() + assert "context" not in dumped_json + assert '"texts"' not in dumped_json # was only inside context + + def test_full_roundtrip_preserves_public_shape(self) -> None: + r = ResponseIX( + use_case="bank_statement_header", + use_case_name="Bank Statement Header", + ix_client_id="mammon", + request_id="req-1", + ix_id="abc123def4567890", + ix_result=IXResult(result={"bank_name": "UBS"}), + ocr_result=OCRResult(result=OCRDetails(text="UBS", pages=[])), + provenance=ProvenanceData( + fields={ + "result.bank_name": FieldProvenance( + field_name="bank_name", + field_path="result.bank_name", + value="UBS", + provenance_verified=True, + text_agreement=True, + ) + }, + quality_metrics={"verified_fields": 1, "text_agreement_fields": 1}, + ), + metadata=Metadata(timings=[{"step": "SetupStep", "seconds": 0.01}]), + ) + dumped = r.model_dump() + rt = ResponseIX.model_validate(dumped) + assert rt.provenance is not None + assert rt.provenance.fields["result.bank_name"].provenance_verified is True + assert rt.metadata.timings[0]["step"] == "SetupStep" + + +class TestJob: + def test_basic_construction(self) -> None: + req = RequestIX( + use_case="bank_statement_header", + ix_client_id="mammon", + request_id="r1", + context=Context(files=["file:///x.pdf"]), + ) + job = Job( + job_id=uuid4(), + ix_id="abcd1234abcd1234", + client_id="mammon", + request_id="r1", + status="pending", + request=req, + created_at=datetime.now(UTC), + ) + assert job.status == "pending" + assert job.callback_status is None + assert job.attempts == 0 + + def test_invalid_status_rejected(self) -> None: + req = RequestIX( + use_case="bank_statement_header", + ix_client_id="mammon", + request_id="r1", + context=Context(files=["file:///x.pdf"]), + ) + with pytest.raises(ValidationError): + Job( + job_id=uuid4(), + ix_id="abcd", + client_id="mammon", + request_id="r1", + status="weird", # not in the Literal + request=req, + created_at=datetime.now(UTC), + ) + + def test_full_terminal_done(self) -> None: + req = RequestIX( + use_case="bank_statement_header", + ix_client_id="mammon", + request_id="r1", + context=Context(files=["file:///x.pdf"]), + ) + resp = ResponseIX(use_case="bank_statement_header") + job = Job( + job_id=uuid4(), + ix_id="abcd1234abcd1234", + client_id="mammon", + request_id="r1", + status="done", + request=req, + response=resp, + callback_url="https://cb", + callback_status="delivered", + attempts=1, + created_at=datetime.now(UTC), + started_at=datetime.now(UTC), + finished_at=datetime.now(UTC), + ) + dumped = job.model_dump() + # Context must not appear anywhere in the serialised job. + assert "context" not in dumped["response"]