All checks were successful
tests / test (push) Successful in 1m44s
Co-authored-by: Dirk Riemann <ditori@gmail.com> Co-committed-by: Dirk Riemann <ditori@gmail.com>
435 lines
15 KiB
Python
435 lines
15 KiB
Python
"""Pydantic contracts — RequestIX and its nested option structures (spec §3)."""
|
|
|
|
from __future__ import annotations
|
|
|
|
import json
|
|
from datetime import UTC, datetime
|
|
from uuid import uuid4
|
|
|
|
import pytest
|
|
from pydantic import ValidationError
|
|
|
|
from ix.contracts import (
|
|
BoundingBox,
|
|
Context,
|
|
ExtractionSource,
|
|
FieldProvenance,
|
|
FileRef,
|
|
GenAIOptions,
|
|
IXResult,
|
|
Job,
|
|
Line,
|
|
Metadata,
|
|
OCRDetails,
|
|
OCROptions,
|
|
OCRResult,
|
|
Options,
|
|
Page,
|
|
ProvenanceData,
|
|
ProvenanceOptions,
|
|
RequestIX,
|
|
ResponseIX,
|
|
SegmentCitation,
|
|
)
|
|
from ix.contracts.request import InlineUseCase, UseCaseFieldDef
|
|
|
|
|
|
class TestFileRef:
    """FileRef: a URL plus optional request headers, size cap and display name."""

    def test_minimal(self) -> None:
        ref = FileRef(url="https://example.com/x.pdf")
        assert ref.url == "https://example.com/x.pdf"
        # Optional fields fall back to "no headers" / "no size cap".
        assert ref.headers == {}
        assert ref.max_bytes is None

    def test_with_headers_and_max_bytes(self) -> None:
        ref = FileRef(
            url="https://paperless/x.pdf",
            headers={"Authorization": "Token abc"},
            max_bytes=1_000_000,
        )
        assert ref.headers == {"Authorization": "Token abc"}
        assert ref.max_bytes == 1_000_000

    def test_display_name_defaults_to_none(self) -> None:
        assert FileRef(url="file:///tmp/ix/ui/abc.pdf").display_name is None

    def test_display_name_roundtrip(self) -> None:
        original = FileRef(
            url="file:///tmp/ix/ui/abc.pdf",
            display_name="my statement.pdf",
        )
        assert original.display_name == "my statement.pdf"

        restored = FileRef.model_validate_json(original.model_dump_json())
        assert restored.display_name == "my statement.pdf"

        # Stored jobs written before display_name existed carry no such key;
        # they must still validate, with the field falling back to None.
        legacy = FileRef.model_validate({"url": "file:///x.pdf"})
        assert legacy.display_name is None
|
|
|
|
|
|
class TestOptionDefaults:
    """Each option block must default to the values fixed by the spec."""

    def test_ocr_defaults_match_spec(self) -> None:
        ocr = OCROptions()
        assert ocr.use_ocr is True
        assert ocr.ocr_only is False
        assert ocr.include_ocr_text is False
        assert ocr.include_geometries is False
        assert ocr.service == "surya"

    def test_genai_defaults_match_spec(self) -> None:
        assert GenAIOptions().gen_ai_model_name is None

    def test_provenance_defaults_match_spec(self) -> None:
        prov = ProvenanceOptions()
        assert prov.include_provenance is True
        assert prov.max_sources_per_field == 10

    def test_options_default_nests_each_block(self) -> None:
        options = Options()
        expected = (
            ("ocr", OCROptions),
            ("gen_ai", GenAIOptions),
            ("provenance", ProvenanceOptions),
        )
        for attr_name, block_type in expected:
            assert isinstance(getattr(options, attr_name), block_type), attr_name
|
|
|
|
|
|
class TestContextFiles:
    """Context.files accepts bare URL strings and FileRef-shaped dicts."""

    def test_plain_string_entry_preserved_as_str(self) -> None:
        url = "https://example.com/a.pdf"
        ctx = Context(files=[url])
        assert ctx.files == [url]
        # A plain string must stay a str, not be promoted to FileRef.
        assert isinstance(ctx.files[0], str)

    def test_dict_entry_parsed_as_fileref(self) -> None:
        ctx = Context(files=[{"url": "https://x/a.pdf", "headers": {"H": "v"}}])
        assert len(ctx.files) == 1
        (only,) = ctx.files
        assert isinstance(only, FileRef)
        assert only.url == "https://x/a.pdf"
        assert only.headers == {"H": "v"}

    def test_mixed_entries(self) -> None:
        ctx = Context(
            files=[
                "file:///tmp/x.pdf",
                {"url": "https://paperless/y.pdf", "headers": {"Authorization": "Token t"}},
            ],
            texts=["extra ocr text"],
        )
        plain, ref = ctx.files[0], ctx.files[1]
        assert isinstance(plain, str)
        assert isinstance(ref, FileRef)
        assert ctx.texts == ["extra ocr text"]

    def test_empty_defaults(self) -> None:
        ctx = Context()
        assert (ctx.files, ctx.texts) == ([], [])
|
|
|
|
|
|
class TestRequestIX:
    """RequestIX: required/unknown fields, nested options and JSON round-trips."""

    def _minimal_payload(self) -> dict:
        """Return a fresh copy of the smallest payload RequestIX accepts."""
        return {
            "use_case": "bank_statement_header",
            "ix_client_id": "mammon",
            "request_id": "req-1",
            "context": {"files": ["https://example/x.pdf"]},
        }

    def test_minimal_valid(self) -> None:
        r = RequestIX(**self._minimal_payload())
        assert r.use_case == "bank_statement_header"
        assert r.ix_id is None
        assert r.callback_url is None
        assert r.options.provenance.include_provenance is True

    def test_roundtrip_json_mixed_files(self) -> None:
        payload = {
            "use_case": "bank_statement_header",
            "ix_client_id": "mammon",
            "request_id": "req-42",
            "context": {
                "files": [
                    "file:///tmp/x.pdf",
                    {
                        "url": "https://paperless/y.pdf",
                        "headers": {"Authorization": "Token t"},
                        "max_bytes": 2_000_000,
                    },
                ],
                "texts": ["paperless ocr text"],
            },
            "options": {
                "ocr": {"include_ocr_text": True},
                "gen_ai": {"gen_ai_model_name": "gpt-oss:20b"},
                "provenance": {"max_sources_per_field": 5},
            },
            "callback_url": "https://mammon/ix-callback",
        }
        r = RequestIX.model_validate(payload)
        assert isinstance(r.context.files[0], str)
        assert isinstance(r.context.files[1], FileRef)
        assert r.context.files[1].headers == {"Authorization": "Token t"}
        assert r.options.ocr.include_ocr_text is True
        assert r.options.gen_ai.gen_ai_model_name == "gpt-oss:20b"
        assert r.options.provenance.max_sources_per_field == 5
        assert r.callback_url == "https://mammon/ix-callback"

        # Round-trip through JSON and back: FileRef dicts survive as FileRef,
        # including every optional field that was set.
        dumped = r.model_dump_json()
        r2 = RequestIX.model_validate_json(dumped)
        assert isinstance(r2.context.files[1], FileRef)
        assert r2.context.files[1].headers == {"Authorization": "Token t"}
        assert r2.context.files[1].max_bytes == 2_000_000

        # The dump is a JSON object, not merely something json can parse.
        assert isinstance(json.loads(dumped), dict)

    def test_unknown_fields_rejected(self) -> None:
        payload = self._minimal_payload()
        payload["not_a_field"] = "x"
        with pytest.raises(ValidationError):
            RequestIX.model_validate(payload)

    def test_ix_id_optional_and_documented(self) -> None:
        # The docstring contract: caller MUST NOT set; transport assigns.
        # Here we only assert the field exists and defaults to None — the
        # "MUST NOT set" is a convention enforced at the transport layer.
        r = RequestIX(**self._minimal_payload())
        assert r.ix_id is None
        # Normalise a missing docstring to "" before any string operation, so
        # a None __doc__ fails the assertion instead of raising AttributeError
        # on .lower().
        doc = RequestIX.__doc__ or ""
        assert "transport" in doc.lower() or "MUST NOT" in doc

    def test_missing_required_fields(self) -> None:
        with pytest.raises(ValidationError):
            RequestIX.model_validate({"use_case": "x"})

    def test_use_case_inline_defaults_to_none(self) -> None:
        r = RequestIX(**self._minimal_payload())
        assert r.use_case_inline is None

    def test_use_case_inline_roundtrip(self) -> None:
        payload = self._minimal_payload()
        payload["use_case_inline"] = {
            "use_case_name": "adhoc",
            "system_prompt": "extract stuff",
            "fields": [
                {"name": "a", "type": "str", "required": True},
                {"name": "b", "type": "int"},
            ],
        }
        r = RequestIX.model_validate(payload)
        assert r.use_case_inline is not None
        assert isinstance(r.use_case_inline, InlineUseCase)
        assert r.use_case_inline.use_case_name == "adhoc"
        assert len(r.use_case_inline.fields) == 2
        assert isinstance(r.use_case_inline.fields[0], UseCaseFieldDef)

        # Round-trip through JSON
        dumped = r.model_dump_json()
        r2 = RequestIX.model_validate_json(dumped)
        assert r2.use_case_inline is not None
        assert r2.use_case_inline.fields[1].type == "int"
|
|
|
|
|
|
class TestOCRResult:
    """OCRResult defaults and (de)serialisation of nested pages and lines."""

    def test_minimal_defaults(self) -> None:
        result = OCRResult()
        assert result.result.text is None
        assert result.result.pages == []
        assert result.meta_data == {}

    def test_full_page_roundtrip(self) -> None:
        page = Page(
            page_no=1,
            width=612.0,
            height=792.0,
            lines=[Line(text="hello", bounding_box=[0, 0, 10, 0, 10, 20, 0, 20])],
        )
        ocr = OCRResult(result=OCRDetails(text="hello", pages=[page]))

        dumped = ocr.model_dump()
        first_line = dumped["result"]["pages"][0]["lines"][0]
        assert first_line["text"] == "hello"
        assert first_line["bounding_box"] == [0, 0, 10, 0, 10, 20, 0, 20]

        # Complete the round-trip the test name promises: the dump must
        # validate back into a model that serialises identically.
        restored = OCRResult.model_validate(dumped)
        assert restored.model_dump() == dumped
|
|
|
|
|
|
class TestProvenance:
    """Provenance contracts: per-field provenance, quality metrics, citations."""

    def test_field_provenance_new_flags(self) -> None:
        # The MVP adds `provenance_verified` + `text_agreement` on top of the
        # reference spec. Both are nullable bool.
        fp = FieldProvenance(
            field_name="bank_name",
            field_path="result.bank_name",
            value="UBS AG",
            sources=[
                ExtractionSource(
                    page_number=1,
                    file_index=0,
                    bounding_box=BoundingBox(coordinates=[0.1, 0.1, 0.9, 0.1, 0.9, 0.2, 0.1, 0.2]),
                    text_snippet="UBS AG",
                    segment_id="p1_l0",
                )
            ],
            provenance_verified=True,
            text_agreement=None,
        )
        assert fp.provenance_verified is True
        assert fp.text_agreement is None
        assert len(fp.sources) == 1
        assert fp.sources[0].segment_id == "p1_l0"

    def test_field_provenance_flags_default_to_none(self) -> None:
        fp = FieldProvenance(field_name="x", field_path="result.x")
        assert fp.provenance_verified is None
        assert fp.text_agreement is None

    def test_quality_metrics_accepts_all_keys(self) -> None:
        # quality_metrics is a free-form dict; every MVP-listed key must
        # round-trip as written, not just a sample of them.
        metrics = {
            "fields_with_provenance": 8,
            "total_fields": 10,
            "coverage_rate": 0.8,
            "invalid_references": 2,
            "verified_fields": 6,
            "text_agreement_fields": 5,
        }
        prov = ProvenanceData(fields={}, quality_metrics=metrics)
        rt = ProvenanceData.model_validate(prov.model_dump())
        for key, value in metrics.items():
            assert rt.quality_metrics[key] == value, key

    def test_segment_citation_basic(self) -> None:
        sc = SegmentCitation(
            field_path="result.invoice_number",
            value_segment_ids=["p1_l4"],
            context_segment_ids=["p1_l3"],
        )
        assert sc.field_path == "result.invoice_number"
        assert sc.value_segment_ids == ["p1_l4"]
        assert sc.context_segment_ids == ["p1_l3"]
|
|
|
|
|
|
class TestResponseIX:
    """ResponseIX defaults, the hidden internal context, and a public round-trip."""

    def test_defaults(self) -> None:
        resp = ResponseIX()
        assert resp.error is None
        assert resp.warning == []
        assert resp.provenance is None
        assert resp.context is None
        nested = (
            ("ix_result", IXResult),
            ("ocr_result", OCRResult),
            ("metadata", Metadata),
        )
        for attr_name, block_type in nested:
            assert isinstance(getattr(resp, attr_name), block_type), attr_name

    def test_context_excluded_from_dump(self) -> None:
        # ResponseIX.context is internal scratch space: neither model_dump()
        # nor model_dump_json() may ever expose it.
        from ix.contracts.response import _InternalContext

        resp = ResponseIX()
        resp.context = _InternalContext(texts=["scratch"])

        assert "context" not in resp.model_dump()

        as_json = resp.model_dump_json()
        assert "context" not in as_json
        assert '"texts"' not in as_json  # "texts" only ever lived inside context

    def test_full_roundtrip_preserves_public_shape(self) -> None:
        bank_field = FieldProvenance(
            field_name="bank_name",
            field_path="result.bank_name",
            value="UBS",
            provenance_verified=True,
            text_agreement=True,
        )
        provenance = ProvenanceData(
            fields={"result.bank_name": bank_field},
            quality_metrics={"verified_fields": 1, "text_agreement_fields": 1},
        )
        original = ResponseIX(
            use_case="bank_statement_header",
            use_case_name="Bank Statement Header",
            ix_client_id="mammon",
            request_id="req-1",
            ix_id="abc123def4567890",
            ix_result=IXResult(result={"bank_name": "UBS"}),
            ocr_result=OCRResult(result=OCRDetails(text="UBS", pages=[])),
            provenance=provenance,
            metadata=Metadata(timings=[{"step": "SetupStep", "seconds": 0.01}]),
        )

        restored = ResponseIX.model_validate(original.model_dump())
        assert restored.provenance is not None
        assert restored.provenance.fields["result.bank_name"].provenance_verified is True
        assert restored.metadata.timings[0]["step"] == "SetupStep"
|
|
|
|
|
|
class TestJob:
    """Job: the persisted envelope around a RequestIX and its ResponseIX."""

    def _request(self) -> RequestIX:
        """Build the minimal RequestIX that every Job test wraps."""
        return RequestIX(
            use_case="bank_statement_header",
            ix_client_id="mammon",
            request_id="r1",
            context=Context(files=["file:///x.pdf"]),
        )

    def test_basic_construction(self) -> None:
        job = Job(
            job_id=uuid4(),
            ix_id="abcd1234abcd1234",
            client_id="mammon",
            request_id="r1",
            status="pending",
            request=self._request(),
            created_at=datetime.now(UTC),
        )
        assert job.status == "pending"
        assert job.callback_status is None
        assert job.attempts == 0

    def test_invalid_status_rejected(self) -> None:
        # Every field except `status` is valid (same values as
        # test_basic_construction), so the error cannot come from elsewhere.
        with pytest.raises(ValidationError) as excinfo:
            Job(
                job_id=uuid4(),
                ix_id="abcd1234abcd1234",
                client_id="mammon",
                request_id="r1",
                status="weird",  # not in the Literal
                request=self._request(),
                created_at=datetime.now(UTC),
            )
        assert {err["loc"][0] for err in excinfo.value.errors()} == {"status"}

    def test_full_terminal_done(self) -> None:
        now = datetime.now(UTC)
        job = Job(
            job_id=uuid4(),
            ix_id="abcd1234abcd1234",
            client_id="mammon",
            request_id="r1",
            status="done",
            request=self._request(),
            response=ResponseIX(use_case="bank_statement_header"),
            callback_url="https://cb",
            callback_status="delivered",
            attempts=1,
            created_at=now,
            started_at=now,
            finished_at=now,
        )
        # The response's internal context must not leak into the serialised
        # job (request.context is a public field and is expected to appear).
        dumped = job.model_dump()
        assert "context" not in dumped["response"]
        assert "context" not in json.loads(job.model_dump_json())["response"]
|